Remove hashed URL mode

Also simplified how view class routing works.

Refs #1661
This commit is contained in:
Simon Willison 2022-03-18 17:12:03 -07:00 committed by GitHub
commit d4f60c2388
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 79 additions and 266 deletions

View file

@ -122,11 +122,11 @@ class BaseView:
async def delete(self, request, *args, **kwargs):
return Response.text("Method not allowed", status=405)
async def dispatch_request(self, request, *args, **kwargs):
async def dispatch_request(self, request):
if self.ds:
await self.ds.refresh_schemas()
handler = getattr(self, request.method.lower(), None)
return await handler(request, *args, **kwargs)
return await handler(request)
async def render(self, templates, request, context=None):
context = context or {}
@ -169,9 +169,7 @@ class BaseView:
def as_view(cls, *class_args, **class_kwargs):
async def view(request, send):
self = view.view_class(*class_args, **class_kwargs)
return await self.dispatch_request(
request, **request.scope["url_route"]["kwargs"]
)
return await self.dispatch_request(request)
view.view_class = cls
view.__doc__ = cls.__doc__
@ -200,90 +198,14 @@ class DataView(BaseView):
add_cors_headers(r.headers)
return r
async def data(self, request, database, hash, **kwargs):
async def data(self, request):
raise NotImplementedError
async def resolve_db_name(self, request, db_name, **kwargs):
hash = None
name = None
decoded_name = tilde_decode(db_name)
if decoded_name not in self.ds.databases and "-" in db_name:
# No matching DB found, maybe it's a name-hash?
name_bit, hash_bit = db_name.rsplit("-", 1)
if tilde_decode(name_bit) not in self.ds.databases:
raise NotFound(f"Database not found: {name}")
else:
name = tilde_decode(name_bit)
hash = hash_bit
else:
name = decoded_name
try:
db = self.ds.databases[name]
except KeyError:
raise NotFound(f"Database not found: {name}")
# Verify the hash
expected = "000"
if db.hash is not None:
expected = db.hash[:HASH_LENGTH]
correct_hash_provided = expected == hash
if not correct_hash_provided:
if "table_and_format" in kwargs:
async def async_table_exists(t):
return await db.table_exists(t)
table, _format = await resolve_table_and_format(
table_and_format=tilde_decode(kwargs["table_and_format"]),
table_exists=async_table_exists,
allowed_formats=self.ds.renderers.keys(),
)
kwargs["table"] = table
if _format:
kwargs["as_format"] = f".{_format}"
elif kwargs.get("table"):
kwargs["table"] = tilde_decode(kwargs["table"])
should_redirect = self.ds.urls.path(f"{name}-{expected}")
if kwargs.get("table"):
should_redirect += "/" + tilde_encode(kwargs["table"])
if kwargs.get("pk_path"):
should_redirect += "/" + kwargs["pk_path"]
if kwargs.get("as_format"):
should_redirect += kwargs["as_format"]
if kwargs.get("as_db"):
should_redirect += kwargs["as_db"]
if (
(self.ds.setting("hash_urls") or "_hash" in request.args)
and
# Redirect only if database is immutable
not self.ds.databases[name].is_mutable
):
return name, expected, correct_hash_provided, should_redirect
return name, expected, correct_hash_provided, None
def get_templates(self, database, table=None):
assert NotImplemented
async def get(self, request, db_name, **kwargs):
(
database,
hash,
correct_hash_provided,
should_redirect,
) = await self.resolve_db_name(request, db_name, **kwargs)
if should_redirect:
return self.redirect(request, should_redirect, remove_args={"_hash"})
return await self.view_get(
request, database, hash, correct_hash_provided, **kwargs
)
async def as_csv(self, request, database, hash, **kwargs):
async def as_csv(self, request, database):
kwargs = {}
stream = request.args.get("_stream")
# Do not calculate facets or counts:
extra_parameters = [
@ -313,9 +235,7 @@ class DataView(BaseView):
kwargs["_size"] = "max"
# Fetch the first page
try:
response_or_template_contexts = await self.data(
request, database, hash, **kwargs
)
response_or_template_contexts = await self.data(request)
if isinstance(response_or_template_contexts, Response):
return response_or_template_contexts
elif len(response_or_template_contexts) == 4:
@ -367,10 +287,11 @@ class DataView(BaseView):
next = None
while first or (next and stream):
try:
kwargs = {}
if next:
kwargs["_next"] = next
if not first:
data, _, _ = await self.data(request, database, hash, **kwargs)
data, _, _ = await self.data(request, **kwargs)
if first:
if request.args.get("_header") != "off":
await writer.writerow(headings)
@ -445,60 +366,39 @@ class DataView(BaseView):
if not trace:
content_type = "text/csv; charset=utf-8"
disposition = 'attachment; filename="{}.csv"'.format(
kwargs.get("table", database)
request.url_vars.get("table", database)
)
headers["content-disposition"] = disposition
return AsgiStream(stream_fn, headers=headers, content_type=content_type)
async def get_format(self, request, database, args):
"""Determine the format of the response from the request, from URL
parameters or from a file extension.
`args` is a dict of the path components parsed from the URL by the router.
"""
# If ?_format= is provided, use that as the format
_format = request.args.get("_format", None)
if not _format:
_format = (args.pop("as_format", None) or "").lstrip(".")
def get_format(self, request):
# Format is the bit from the path following the ., if one exists
last_path_component = request.path.split("/")[-1]
if "." in last_path_component:
return last_path_component.split(".")[-1]
else:
args.pop("as_format", None)
if "table_and_format" in args:
db = self.ds.databases[database]
return None
async def async_table_exists(t):
return await db.table_exists(t)
table, _ext_format = await resolve_table_and_format(
table_and_format=tilde_decode(args["table_and_format"]),
table_exists=async_table_exists,
allowed_formats=self.ds.renderers.keys(),
)
_format = _format or _ext_format
args["table"] = table
del args["table_and_format"]
elif "table" in args:
args["table"] = tilde_decode(args["table"])
return _format, args
async def view_get(self, request, database, hash, correct_hash_provided, **kwargs):
_format, kwargs = await self.get_format(request, database, kwargs)
async def get(self, request):
db_name = request.url_vars["db_name"]
database = tilde_decode(db_name)
_format = self.get_format(request)
data_kwargs = {}
if _format == "csv":
return await self.as_csv(request, database, hash, **kwargs)
return await self.as_csv(request, database)
if _format is None:
# HTML views default to expanding all foreign key labels
kwargs["default_labels"] = True
data_kwargs["default_labels"] = True
extra_template_data = {}
start = time.perf_counter()
status_code = None
templates = []
try:
response_or_template_contexts = await self.data(
request, database, hash, **kwargs
)
response_or_template_contexts = await self.data(request, **data_kwargs)
if isinstance(response_or_template_contexts, Response):
return response_or_template_contexts
# If it has four items, it includes an HTTP status code
@ -650,10 +550,7 @@ class DataView(BaseView):
ttl = request.args.get("_ttl", None)
if ttl is None or not ttl.isdigit():
if correct_hash_provided:
ttl = self.ds.setting("default_cache_ttl_hashed")
else:
ttl = self.ds.setting("default_cache_ttl")
ttl = self.ds.setting("default_cache_ttl")
return self.set_response_headers(r, ttl)