From a8bcafc1775c8a8655b365ae22a3d64f6361c74a Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 28 Jun 2020 13:45:17 -0700 Subject: [PATCH] Refactored out AsgiRouter, refs #870 --- datasette/app.py | 34 ++++++++++++++++++++++----- datasette/utils/asgi.py | 52 ----------------------------------------- datasette/views/base.py | 1 - 3 files changed, 28 insertions(+), 59 deletions(-) diff --git a/datasette/app.py b/datasette/app.py index 0437a75b..bff01bc1 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -25,7 +25,7 @@ from jinja2.environment import Template from jinja2.exceptions import TemplateNotFound import uvicorn -from .views.base import DatasetteError, ureg, AsgiRouter +from .views.base import DatasetteError, ureg from .views.database import DatabaseDownload, DatabaseView from .views.index import IndexView from .views.special import ( @@ -902,10 +902,23 @@ class Datasette: return asgi -class DatasetteRouter(AsgiRouter): +class DatasetteRouter: def __init__(self, datasette, routes): self.ds = datasette - super().__init__(routes) + routes = routes or [] + self.routes = [ + # Compile any strings to regular expressions + ((re.compile(pattern) if isinstance(pattern, str) else pattern), view) + for pattern, view in routes + ] + + async def __call__(self, scope, receive, send): + # Because we care about "foo/bar" v.s. "foo%2Fbar" we decode raw_path ourselves + path = scope["path"] + raw_path = scope.get("raw_path") + if raw_path: + path = raw_path.decode("ascii") + return await self.route_path(scope, receive, send, path) async def route_path(self, scope, receive, send, path): # Strip off base_url if present before routing @@ -933,9 +946,18 @@ class DatasetteRouter(AsgiRouter): if actor: break scope_modifications["actor"] = actor or default_actor - return await super().route_path( - dict(scope, **scope_modifications), receive, send, path - ) + scope = dict(scope, **scope_modifications) + for regex, view in self.routes: + match = regex.match(path) + if match is not None: + new_scope = dict(scope, url_route={"kwargs": match.groupdict()}) + try: + return await view(new_scope, receive, send) + except NotFound as exception: + return await self.handle_404(scope, receive, send, exception) + except Exception as exception: + return await self.handle_500(scope, receive, send, exception) + return await self.handle_404(scope, receive, send) async def handle_404(self, scope, receive, send, exception=None): # If URL has a trailing slash, redirect to URL without it diff --git a/datasette/utils/asgi.py b/datasette/utils/asgi.py index 5a152570..615bc0ab 100644 --- a/datasette/utils/asgi.py +++ b/datasette/utils/asgi.py @@ -118,58 +118,6 @@ class Request: return cls(scope, None) -class AsgiRouter: - def __init__(self, routes=None): - routes = routes or [] - self.routes = [ - # Compile any strings to regular expressions - ((re.compile(pattern) if isinstance(pattern, str) else pattern), view) - for pattern, view in routes - ] - - async def __call__(self, scope, receive, send): - # Because we care about "foo/bar" v.s. "foo%2Fbar" we decode raw_path ourselves - path = scope["path"] - raw_path = scope.get("raw_path") - if raw_path: - path = raw_path.decode("ascii") - return await self.route_path(scope, receive, send, path) - - async def route_path(self, scope, receive, send, path): - for regex, view in self.routes: - match = regex.match(path) - if match is not None: - new_scope = dict(scope, url_route={"kwargs": match.groupdict()}) - try: - return await view(new_scope, receive, send) - except NotFound as exception: - return await self.handle_404(scope, receive, send, exception) - except Exception as exception: - return await self.handle_500(scope, receive, send, exception) - return await self.handle_404(scope, receive, send) - - async def handle_404(self, scope, receive, send, exception=None): - await send( - { - "type": "http.response.start", - "status": 404, - "headers": [[b"content-type", b"text/html; charset=utf-8"]], - } - ) - await send({"type": "http.response.body", "body": b"

404

"}) - - async def handle_500(self, scope, receive, send, exception): - await send( - { - "type": "http.response.start", - "status": 404, - "headers": [[b"content-type", b"text/html; charset=utf-8"]], - } - ) - html = "

500

".format(escape(repr(exception))) - await send({"type": "http.response.body", "body": html.encode("utf-8")}) - - class AsgiLifespan: def __init__(self, app, on_startup=None, on_shutdown=None): self.app = app diff --git a/datasette/views/base.py b/datasette/views/base.py index f14e6d3a..821a6f0e 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -27,7 +27,6 @@ from datasette.utils import ( from datasette.utils.asgi import ( AsgiStream, AsgiWriter, - AsgiRouter, AsgiView, Forbidden, NotFound,