Refactored out AsgiRouter, refs #870

This commit is contained in:
Simon Willison 2020-06-28 13:45:17 -07:00
commit a8bcafc177
3 changed files with 28 additions and 59 deletions

View file

@ -25,7 +25,7 @@ from jinja2.environment import Template
from jinja2.exceptions import TemplateNotFound from jinja2.exceptions import TemplateNotFound
import uvicorn import uvicorn
from .views.base import DatasetteError, ureg, AsgiRouter from .views.base import DatasetteError, ureg
from .views.database import DatabaseDownload, DatabaseView from .views.database import DatabaseDownload, DatabaseView
from .views.index import IndexView from .views.index import IndexView
from .views.special import ( from .views.special import (
@ -902,10 +902,23 @@ class Datasette:
return asgi return asgi
class DatasetteRouter(AsgiRouter): class DatasetteRouter:
def __init__(self, datasette, routes): def __init__(self, datasette, routes):
self.ds = datasette self.ds = datasette
super().__init__(routes) routes = routes or []
self.routes = [
# Compile any strings to regular expressions
((re.compile(pattern) if isinstance(pattern, str) else pattern), view)
for pattern, view in routes
]
async def __call__(self, scope, receive, send):
# Because we care about "foo/bar" v.s. "foo%2Fbar" we decode raw_path ourselves
path = scope["path"]
raw_path = scope.get("raw_path")
if raw_path:
path = raw_path.decode("ascii")
return await self.route_path(scope, receive, send, path)
async def route_path(self, scope, receive, send, path): async def route_path(self, scope, receive, send, path):
# Strip off base_url if present before routing # Strip off base_url if present before routing
@ -933,9 +946,18 @@ class DatasetteRouter(AsgiRouter):
if actor: if actor:
break break
scope_modifications["actor"] = actor or default_actor scope_modifications["actor"] = actor or default_actor
return await super().route_path( scope = dict(scope, **scope_modifications)
dict(scope, **scope_modifications), receive, send, path for regex, view in self.routes:
) match = regex.match(path)
if match is not None:
new_scope = dict(scope, url_route={"kwargs": match.groupdict()})
try:
return await view(new_scope, receive, send)
except NotFound as exception:
return await self.handle_404(scope, receive, send, exception)
except Exception as exception:
return await self.handle_500(scope, receive, send, exception)
return await self.handle_404(scope, receive, send)
async def handle_404(self, scope, receive, send, exception=None): async def handle_404(self, scope, receive, send, exception=None):
# If URL has a trailing slash, redirect to URL without it # If URL has a trailing slash, redirect to URL without it

View file

@ -118,58 +118,6 @@ class Request:
return cls(scope, None) return cls(scope, None)
class AsgiRouter:
def __init__(self, routes=None):
routes = routes or []
self.routes = [
# Compile any strings to regular expressions
((re.compile(pattern) if isinstance(pattern, str) else pattern), view)
for pattern, view in routes
]
async def __call__(self, scope, receive, send):
# Because we care about "foo/bar" v.s. "foo%2Fbar" we decode raw_path ourselves
path = scope["path"]
raw_path = scope.get("raw_path")
if raw_path:
path = raw_path.decode("ascii")
return await self.route_path(scope, receive, send, path)
async def route_path(self, scope, receive, send, path):
for regex, view in self.routes:
match = regex.match(path)
if match is not None:
new_scope = dict(scope, url_route={"kwargs": match.groupdict()})
try:
return await view(new_scope, receive, send)
except NotFound as exception:
return await self.handle_404(scope, receive, send, exception)
except Exception as exception:
return await self.handle_500(scope, receive, send, exception)
return await self.handle_404(scope, receive, send)
async def handle_404(self, scope, receive, send, exception=None):
await send(
{
"type": "http.response.start",
"status": 404,
"headers": [[b"content-type", b"text/html; charset=utf-8"]],
}
)
await send({"type": "http.response.body", "body": b"<h1>404</h1>"})
async def handle_500(self, scope, receive, send, exception):
await send(
{
"type": "http.response.start",
"status": 404,
"headers": [[b"content-type", b"text/html; charset=utf-8"]],
}
)
html = "<h1>500</h1><pre{}></pre>".format(escape(repr(exception)))
await send({"type": "http.response.body", "body": html.encode("utf-8")})
class AsgiLifespan: class AsgiLifespan:
def __init__(self, app, on_startup=None, on_shutdown=None): def __init__(self, app, on_startup=None, on_shutdown=None):
self.app = app self.app = app

View file

@ -27,7 +27,6 @@ from datasette.utils import (
from datasette.utils.asgi import ( from datasette.utils.asgi import (
AsgiStream, AsgiStream,
AsgiWriter, AsgiWriter,
AsgiRouter,
AsgiView, AsgiView,
Forbidden, Forbidden,
NotFound, NotFound,