Implemented custom 404/500, more tests pass #272

This commit is contained in:
Simon Willison 2019-06-22 18:57:10 -07:00
commit 55fc993667
3 changed files with 102 additions and 3 deletions

View file

@ -644,7 +644,50 @@ class Datasette:
)
self.register_custom_units()
app = AsgiRouter(routes)
outer_self = self
class DatasetteRouter(AsgiRouter):
async def handle_500(self, scope, receive, send, exception):
title = None
help = None
if isinstance(exception, NotFound):
status = 404
info = {}
message = exception.args[0]
elif isinstance(exception, InvalidUsage):
status = 405
info = {}
message = exception.args[0]
elif isinstance(exception, DatasetteError):
status = exception.status
info = exception.error_dict
message = exception.message
if exception.messagge_is_html:
message = Markup(message)
title = exception.title
else:
status = 500
info = {}
message = str(exception)
traceback.print_exc()
templates = ["500.html"]
if status != 500:
templates = ["{}.html".format(status)] + templates
info.update(
{"ok": False, "error": message, "status": status, "title": title}
)
headers = {}
if outer_self.cors:
headers["Access-Control-Allow-Origin"] = "*"
if scope["path"].split("?")[0].endswith(".json"):
await asgi_send_json(send, info, status=status, headers=headers)
else:
template = outer_self.jinja_env.select_template(templates)
await asgi_send_html(
send, template.render(info), status=status, headers=headers
)
app = DatasetteRouter(routes)
# On 404 with a trailing slash redirect to path without that slash:
# pylint: disable=unused-variable
# TODO: re-enable this
@ -665,3 +708,37 @@ class Datasette:
# await database.table_counts(limit=60 * 60 * 1000)
return app
async def asgi_send_json(send, info, status=200, headers=None):
headers = headers or {}
await asgi_send(
send,
json.dumps(info),
status=status,
headers=headers,
content_type="application/json",
)
async def asgi_send_html(send, html, status=200, headers=None):
headers = headers or {}
await asgi_send(
send, html, status=status, headers=headers, content_type="text/html"
)
async def asgi_send(send, content, status, headers, content_type="text/plain"):
# TODO: watch out for Content-Type due to mixed case:
headers["content-type"] = content_type
await send(
{
"type": "http.response.start",
"status": status,
"headers": [
[key.encode("latin1"), value.encode("latin1")]
for key, value in headers.items()
],
}
)
await send({"type": "http.response.body", "body": content.encode("utf8")})

View file

@ -12,6 +12,8 @@ from sanic.exceptions import NotFound
from sanic.views import HTTPMethodView
from sanic.request import Request as SanicRequest
from html import escape
from datasette import __version__
from datasette.plugins import pm
from datasette.utils import (
@ -64,7 +66,10 @@ class AsgiRouter:
match = regex.match(scope["path"])
if match is not None:
new_scope = dict(scope, url_route={"kwargs": match.groupdict()})
return await view(new_scope, receive, send)
try:
return await view(new_scope, receive, send)
except Exception as exception:
return await self.handle_500(scope, receive, send, exception)
return await self.handle_404(scope, receive, send)
async def handle_404(self, scope, receive, send):
@ -77,6 +82,17 @@ class AsgiRouter:
)
await send({"type": "http.response.body", "body": b"<h1>404</h1>"})
async def handle_500(self, scope, receive, send, exception):
await send(
{
"type": "http.response.start",
"status": 404,
"headers": [[b"content-type", b"text/html"]],
}
)
html = "<h1>500</h1><pre{}></pre>".format(escape(repr(exception)))
await send({"type": "http.response.body", "body": html.encode("utf8")})
class AsgiView(HTTPMethodView):
@classmethod

View file

@ -24,6 +24,10 @@ class TestResponse:
def json(self):
return json.loads(self.body)
@property
def text(self):
return self.body.decode("utf8")
class TestClient:
def __init__(self, asgi_app):
@ -49,7 +53,9 @@ class TestClient:
# First message back should be response.start with headers and status
start = await instance.receive_output(2)
assert start["type"] == "http.response.start"
headers = start["headers"]
headers = dict(
[(k.decode("utf8"), v.decode("utf8")) for k, v in start["headers"]]
)
status = start["status"]
# Now loop until we run out of response.body
body = b""