default_cache_ttl_hashed config and ?_hash= param

This commit is contained in:
Simon Willison 2019-03-17 15:41:34 -07:00
commit 8f003d9545
7 changed files with 72 additions and 21 deletions

View file

@ -84,9 +84,12 @@ CONFIG_OPTIONS = (
ConfigOption("allow_sql", True, """
Allow arbitrary SQL queries via ?sql= parameter
""".strip()),
ConfigOption("default_cache_ttl", 365 * 24 * 60 * 60, """
ConfigOption("default_cache_ttl", 5, """
Default HTTP cache TTL (used in Cache-Control: max-age= header)
""".strip()),
ConfigOption("default_cache_ttl_hashed", 365 * 24 * 60 * 60, """
Default HTTP cache TTL for hashed URL pages
""".strip()),
ConfigOption("cache_size_kb", 0, """
SQLite cache size in KB (0 == use SQLite default)
""".strip()),

View file

@ -208,8 +208,14 @@ def path_with_added_args(request, args, path=None):
def path_with_removed_args(request, args, path=None):
query_string = request.query_string
if path is None:
path = request.path
else:
if "?" in path:
bits = path.split("?", 1)
path, query_string = bits
# args can be a dict or a set
path = path or request.path
current = []
if isinstance(args, set):
def should_remove(key, value):
@ -218,7 +224,7 @@ def path_with_removed_args(request, args, path=None):
# Must match key AND value
def should_remove(key, value):
return args.get(key) == value
for key, value in urllib.parse.parse_qsl(request.query_string):
for key, value in urllib.parse.parse_qsl(query_string):
if not should_remove(key, value):
current.append((key, value))
query_string = urllib.parse.urlencode(current)

View file

@ -144,16 +144,18 @@ class BaseView(RenderMixin):
r.headers["Access-Control-Allow-Origin"] = "*"
return r
def redirect(self, request, path, forward_querystring=True):
def redirect(self, request, path, forward_querystring=True, remove_args=None):
if request.query_string and "?" not in path and forward_querystring:
path = "{}?{}".format(path, request.query_string)
if remove_args:
path = path_with_removed_args(request, remove_args, path=path)
r = response.redirect(path)
r.headers["Link"] = "<{}>; rel=preload".format(path)
if self.ds.cors:
r.headers["Access-Control-Allow-Origin"] = "*"
return r
def resolve_db_name(self, db_name, **kwargs):
def resolve_db_name(self, request, db_name, **kwargs):
databases = self.ds.inspect()
hash = None
name = None
@ -174,7 +176,9 @@ class BaseView(RenderMixin):
raise NotFound("Database not found: {}".format(name))
expected = info["hash"][:HASH_LENGTH]
if expected != hash:
correct_hash_provided = (expected == hash)
if not correct_hash_provided:
if "table_and_format" in kwargs:
table, _format = resolve_table_and_format(
table_and_format=urllib.parse.unquote_plus(
@ -202,10 +206,10 @@ class BaseView(RenderMixin):
if "as_db" in kwargs:
should_redirect += kwargs["as_db"]
if self.ds.config("hash_urls"):
return name, expected, should_redirect
if self.ds.config("hash_urls") or "_hash" in request.args:
return name, expected, correct_hash_provided, should_redirect
return name, expected, None
return name, expected, correct_hash_provided, None
def absolute_url(self, request, path):
url = urllib.parse.urljoin(request.url, path)
@ -217,11 +221,13 @@ class BaseView(RenderMixin):
assert NotImplemented
async def get(self, request, db_name, **kwargs):
database, hash, should_redirect = self.resolve_db_name(db_name, **kwargs)
database, hash, correct_hash_provided, should_redirect = self.resolve_db_name(
request, db_name, **kwargs
)
if should_redirect:
return self.redirect(request, should_redirect)
return self.redirect(request, should_redirect, remove_args={"_hash"})
return await self.view_get(request, database, hash, **kwargs)
return await self.view_get(request, database, hash, correct_hash_provided, **kwargs)
async def as_csv(self, request, database, hash, **kwargs):
stream = request.args.get("_stream")
@ -316,7 +322,7 @@ class BaseView(RenderMixin):
content_type=content_type
)
async def view_get(self, request, database, hash, **kwargs):
async def view_get(self, request, database, hash, correct_hash_provided, **kwargs):
# If ?_format= is provided, use that as the format
_format = request.args.get("_format", None)
if not _format:
@ -503,10 +509,13 @@ class BaseView(RenderMixin):
r = self.render(templates, **context)
r.status = status_code
# Set far-future cache expiry
if self.ds.cache_headers:
if self.ds.cache_headers and r.status == 200:
ttl = request.args.get("_ttl", None)
if ttl is None or not ttl.isdigit():
ttl = self.ds.config("default_cache_ttl")
if correct_hash_provided:
ttl = self.ds.config("default_cache_ttl_hashed")
else:
ttl = self.ds.config("default_cache_ttl")
else:
ttl = int(ttl)
if ttl == 0:

View file

@ -40,7 +40,7 @@ class DatabaseView(BaseView):
class DatabaseDownload(BaseView):
async def view_get(self, request, database, hash, **kwargs):
async def view_get(self, request, database, hash, correct_hash_present, **kwargs):
if not self.ds.config("allow_download"):
raise DatasetteError("Database download is forbidden", status=403)
filepath = self.ds.inspect()[database]["file"]

View file

@ -115,11 +115,21 @@ Enable/disable the ability for users to run custom SQL directly against a databa
default_cache_ttl
-----------------
Default HTTP caching max-age header in seconds, used for ``Cache-Control: max-age=X``. Can be over-ridden on a per-request basis using the ``?_ttl=`` querystring parameter. Set this to ``0`` to disable HTTP caching entirely. Defaults to 365 days (31536000 seconds).
Default HTTP caching max-age header in seconds, used for ``Cache-Control: max-age=X``. Can be over-ridden on a per-request basis using the ``?_ttl=`` querystring parameter. Set this to ``0`` to disable HTTP caching entirely. Defaults to 5 seconds.
::
datasette mydatabase.db --config default_cache_ttl:10
datasette mydatabase.db --config default_cache_ttl:60
default_cache_ttl_hashed
------------------------
Default HTTP caching max-age for responses served using using the :ref:`hashed-urls mechanism <config_hash_urls>`. Defaults to 365 days (31536000 seconds).
::
datasette mydatabase.db --config default_cache_ttl_hashed:10000
cache_size_kb
-------------
@ -180,6 +190,8 @@ HTTP but is served to the outside world via a proxy that enables HTTPS.
datasette mydatabase.db --config force_https_urls:1
.. _config_hash_urls:
hash_urls
---------

View file

@ -1050,7 +1050,8 @@ def test_config_json(app_client):
"allow_facet": True,
"suggest_facets": True,
"allow_sql": True,
"default_cache_ttl": 365 * 24 * 60 * 60,
"default_cache_ttl": 5,
"default_cache_ttl_hashed": 365 * 24 * 60 * 60,
"num_sql_threads": 3,
"cache_size_kb": 0,
"allow_csv_stream": True,
@ -1302,8 +1303,8 @@ def test_expand_label(app_client):
@pytest.mark.parametrize('path,expected_cache_control', [
("/fixtures/facetable.json", "max-age=31536000"),
("/fixtures/facetable.json?_ttl=invalid", "max-age=31536000"),
("/fixtures/facetable.json", "max-age=5"),
("/fixtures/facetable.json?_ttl=invalid", "max-age=5"),
("/fixtures/facetable.json?_ttl=10", "max-age=10"),
("/fixtures/facetable.json?_ttl=0", "no-cache"),
])
@ -1312,6 +1313,19 @@ def test_ttl_parameter(app_client, path, expected_cache_control):
assert expected_cache_control == response.headers['Cache-Control']
@pytest.mark.parametrize("path,expected_redirect", [
("/fixtures/facetable.json?_hash=1", "/fixtures-HASH/facetable.json"),
("/fixtures/facetable.json?city_id=1&_hash=1", "/fixtures-HASH/facetable.json?city_id=1"),
])
def test_hash_parameter(app_client, path, expected_redirect):
# First get the current hash for the fixtures database
current_hash = app_client.get("/-/inspect.json").json["fixtures"]["hash"][:7]
response = app_client.get(path, allow_redirects=False)
assert response.status == 302
location = response.headers["Location"]
assert expected_redirect.replace("HASH", current_hash) == location
test_json_columns_default_expected = [{
"intval": 1,
"strval": "s",

View file

@ -59,6 +59,13 @@ def test_path_with_removed_args(path, args, expected):
)
actual = utils.path_with_removed_args(request, args)
assert expected == actual
# Run the test again but this time use the path= argument
request = Request(
"/".encode('utf8'),
{}, '1.1', 'GET', None
)
actual = utils.path_with_removed_args(request, args, path=path)
assert expected == actual
@pytest.mark.parametrize('path,args,expected', [