Skip to content

Commit

Permalink
Make Serializer.dumps() require a body parameter.
Browse files Browse the repository at this point in the history
When caching permanent redirects, if `body` is left as `None`, an infinite recursion occurs that causes caching to fail silently, so nothing is cached at all.

So instead, make `body` a required parameter, which can be empty (`''`) for cached redirects.
  • Loading branch information
Flameeyes committed Apr 13, 2020
1 parent 91b3e11 commit 6b45d11
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 27 deletions.
14 changes: 8 additions & 6 deletions cachecontrol/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import calendar
import time
import traceback
from email.utils import parsedate_tz

from requests.structures import CaseInsensitiveDict
Expand Down Expand Up @@ -280,7 +281,7 @@ def cache_response(self, request, response, body=None, status_codes=None):
cc = self.parse_cache_control(response_headers)

cache_url = self.cache_url(request.url)
logger.debug('Updating cache %r with response from "%s"', self.cache, cache_url)
logger.debug('Updating cache with response from "%s"', cache_url)

# Delete it from the cache if we happen to have it stored there
no_store = False
Expand Down Expand Up @@ -309,14 +310,14 @@ def cache_response(self, request, response, body=None, status_codes=None):
if self.cache_etags and "etag" in response_headers:
logger.debug("Caching due to etag")
self.cache.set(
cache_url, self.serializer.dumps(request, response, body=body)
cache_url, self.serializer.dumps(request, response, body)
)

# Add to the cache any permanent redirects. We do this before looking
# at the Date headers.
elif int(response.status) in PERMANENT_REDIRECT_STATUSES:
logger.debug("Caching permanent redirect")
self.cache.set(cache_url, self.serializer.dumps(request, response))
self.cache.set(cache_url, self.serializer.dumps(request, response, b''))

# Add to the cache if the response headers demand it. If there
# is no date header then we can't do anything about expiring
Expand All @@ -329,7 +330,7 @@ def cache_response(self, request, response, body=None, status_codes=None):
if "max-age" in cc and cc["max-age"] > 0:
logger.debug("Caching b/c date exists and max-age > 0")
self.cache.set(
cache_url, self.serializer.dumps(request, response, body=body)
cache_url, self.serializer.dumps(request, response, body)
)

# If the request can expire, it means we should cache it
Expand All @@ -338,7 +339,7 @@ def cache_response(self, request, response, body=None, status_codes=None):
if response_headers["expires"]:
logger.debug("Caching b/c of expires header")
self.cache.set(
cache_url, self.serializer.dumps(request, response, body=body)
cache_url, self.serializer.dumps(request, response, body)
)
else:
logger.debug("No combination of headers to cache.")
Expand Down Expand Up @@ -379,6 +380,7 @@ def update_cached_response(self, request, response):
cached_response.status = 200

# update our cache
self.cache.set(cache_url, self.serializer.dumps(request, cached_response))
body = cached_response.read(decode_content=False)
self.cache.set(cache_url, self.serializer.dumps(request, cached_response, body))

return cached_response
20 changes: 4 additions & 16 deletions cachecontrol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,14 @@ def _b64_decode_str(s):
return _b64_decode_bytes(s).decode("utf8")


_default_body_read = object()


class Serializer(object):

def dumps(self, request, response, body=None):
def dumps(self, request, response, body):
response_headers = CaseInsensitiveDict(response.headers)

if body is None:
body = response.read(decode_content=False)

# NOTE: 99% sure this is dead code. I'm only leaving it
# here b/c I don't have a test yet to prove
# it. Basically, before using
# `cachecontrol.filewrapper.CallbackFileWrapper`,
# this made an effort to reset the file handle. The
# `CallbackFileWrapper` short circuits this code by
# setting the body as the content is consumed, the
# result being a `body` argument is *always* passed
# into cache_response, and in turn,
# `Serializer.dump`.
response._fp = io.BytesIO(body)

# NOTE: This is all a bit weird, but it's really important that on
# Python 2.x these objects are unicode and not str, even when
# they contain only ascii. The problem here is that msgpack
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cache_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_no_cache_with_wrong_sized_body(self, cc):
# When the body is the wrong size, then we don't want to cache it
# because it is obviously broken.
resp = self.resp({"cache-control": "max-age=3600", "Content-Length": "5"})
cc.cache_response(self.req(), resp, body=b"0" * 10)
cc.cache_response(self.req(), resp, b"0" * 10)

assert not cc.cache.set.called

Expand All @@ -82,7 +82,7 @@ def test_cache_response_cache_max_age(self, cc):
resp = self.resp({"cache-control": "max-age=3600", "date": now})
req = self.req()
cc.cache_response(req, resp)
cc.serializer.dumps.assert_called_with(req, resp, body=None)
cc.serializer.dumps.assert_called_with(req, resp, None)
cc.cache.set.assert_called_with(self.url, ANY)

def test_cache_response_cache_max_age_with_invalid_value_not_cached(self, cc):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_read_latest_version_streamable(self, url):
original_resp = requests.get(url, stream=True)
req = original_resp.request

resp = self.serializer.loads(req, self.serializer.dumps(req, original_resp.raw))
resp = self.serializer.loads(req, self.serializer.dumps(req, original_resp.raw, original_resp.content))

assert resp.read()

Expand All @@ -99,7 +99,7 @@ def test_read_latest_version(self, url):
req = original_resp.request

resp = self.serializer.loads(
req, self.serializer.dumps(req, original_resp.raw, body=data)
req, self.serializer.dumps(req, original_resp.raw, data)
)

assert resp.read() == data
Expand All @@ -114,5 +114,5 @@ def test_no_vary_header(self, url):
original_resp.raw.headers["vary"] = "Foo"

assert self.serializer.loads(
req, self.serializer.dumps(req, original_resp.raw, body=data)
req, self.serializer.dumps(req, original_resp.raw, data)
)

0 comments on commit 6b45d11

Please sign in to comment.