From 711767bcd3c1e76a0861fe7f24069ff1c8efc97a Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Fri, 18 Mar 2022 21:03:08 -0700 Subject: [PATCH] Refactored URL routing to add tests, closes #1666 Refs #1660 --- datasette/app.py | 54 ++++++++++++++++++++----------------- datasette/utils/__init__.py | 8 ++++++ tests/test_routes.py | 34 +++++++++++++++++++++++ 3 files changed, 72 insertions(+), 24 deletions(-) create mode 100644 tests/test_routes.py diff --git a/datasette/app.py b/datasette/app.py index f52e3283..8987112c 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -60,6 +60,7 @@ from .utils import ( module_from_path, parse_metadata, resolve_env_secrets, + resolve_routes, to_css_class, ) from .utils.asgi import ( @@ -974,8 +975,7 @@ class Datasette: output.append(script) return output - def app(self): - """Returns an ASGI app function that serves the whole of Datasette""" + def _routes(self): routes = [] for routes_to_add in pm.hook.register_routes(datasette=self): @@ -1099,6 +1099,15 @@ class Datasette: + renderer_regex + r")?$", ) + return [ + # Compile any strings to regular expressions + ((re.compile(pattern) if isinstance(pattern, str) else pattern), view) + for pattern, view in routes + ] + + def app(self): + """Returns an ASGI app function that serves the whole of Datasette""" + routes = self._routes() self._register_custom_units() async def setup_db(): @@ -1129,12 +1138,7 @@ class Datasette: class DatasetteRouter: def __init__(self, datasette, routes): self.ds = datasette - routes = routes or [] - self.routes = [ - # Compile any strings to regular expressions - ((re.compile(pattern) if isinstance(pattern, str) else pattern), view) - for pattern, view in routes - ] + self.routes = routes or [] # Build a list of pages/blah/{name}.html matching expressions pattern_templates = [ filepath @@ -1187,22 +1191,24 @@ class DatasetteRouter: break scope_modifications["actor"] = actor or default_actor scope = dict(scope, **scope_modifications) - for regex, view in self.routes: - match = regex.match(path) - if match is not None: - new_scope = dict(scope, url_route={"kwargs": match.groupdict()}) - request.scope = new_scope - try: - response = await view(request, send) - if response: - self.ds._write_messages_to_response(request, response) - await response.asgi_send(send) - return - except NotFound as exception: - return await self.handle_404(request, send, exception) - except Exception as exception: - return await self.handle_500(request, send, exception) - return await self.handle_404(request, send) + + match, view = resolve_routes(self.routes, path) + + if match is None: + return await self.handle_404(request, send) + + new_scope = dict(scope, url_route={"kwargs": match.groupdict()}) + request.scope = new_scope + try: + response = await view(request, send) + if response: + self.ds._write_messages_to_response(request, response) + await response.asgi_send(send) + return + except NotFound as exception: + return await self.handle_404(request, send, exception) + except Exception as exception: + return await self.handle_500(request, send, exception) async def handle_404(self, request, send, exception=None): # If path contains % encoding, redirect to tilde encoding diff --git a/datasette/utils/__init__.py b/datasette/utils/__init__.py index bd591459..ccdf8ad4 100644 --- a/datasette/utils/__init__.py +++ b/datasette/utils/__init__.py @@ -1178,3 +1178,11 @@ def tilde_decode(s: str) -> str: s = s.replace("%", temp) decoded = urllib.parse.unquote(s.replace("~", "%")) return decoded.replace(temp, "%") + + +def resolve_routes(routes, path): + for regex, view in routes: + match = regex.match(path) + if match is not None: + return match, view + return None, None diff --git a/tests/test_routes.py b/tests/test_routes.py new file mode 100644 index 00000000..a1960f14 --- /dev/null +++ b/tests/test_routes.py @@ -0,0 +1,34 @@ +from datasette.app import Datasette +from datasette.utils import resolve_routes +import pytest + + +@pytest.fixture(scope="session") +def routes(): + ds = Datasette() + return ds._routes() + + +@pytest.mark.parametrize( + "path,expected", + ( + ("/", "IndexView"), + ("/foo", "DatabaseView"), + ("/foo.csv", "DatabaseView"), + ("/foo.json", "DatabaseView"), + ("/foo.humbug", "DatabaseView"), + ("/foo/humbug", "TableView"), + ("/foo/humbug.json", "TableView"), + ("/foo/humbug.blah", "TableView"), + ("/foo/humbug/1", "RowView"), + ("/foo/humbug/1.json", "RowView"), + ("/-/metadata.json", "JsonDataView"), + ("/-/metadata", "JsonDataView"), + ), +) +def test_routes(routes, path, expected): + match, view = resolve_routes(routes, path) + if expected is None: + assert match is None + else: + assert view.view_class.__name__ == expected