From d2daa1b9f74ef33c4a819aa1d968e442328ec987 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 23 Jun 2019 07:36:54 -0700 Subject: [PATCH] Database download works again, refactored utils.py #272 Refactored utils.py into a datasette/utils package, refactored some of the ASGI helper code into datasette/utils/asgi.py --- datasette/app.py | 78 +------ datasette/{utils.py => utils/__init__.py} | 0 datasette/utils/asgi.py | 244 ++++++++++++++++++++++ datasette/views/base.py | 143 +------------ datasette/views/database.py | 7 +- 5 files changed, 250 insertions(+), 222 deletions(-) rename datasette/{utils.py => utils/__init__.py} (100%) create mode 100644 datasette/utils/asgi.py diff --git a/datasette/app.py b/datasette/app.py index 6ea208e6..f86a14b5 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1,6 +1,4 @@ import asyncio -import aiofiles -from mimetypes import guess_type import collections import hashlib import json @@ -38,6 +36,7 @@ from .utils import ( sqlite_timelimit, to_css_class, ) +from .utils.asgi import asgi_static, asgi_send_html, asgi_send_json from .tracer import capture_traces, trace from .plugins import pm, DEFAULT_PLUGINS from .version import __version__ @@ -714,78 +713,3 @@ class Datasette: # await database.table_counts(limit=60 * 60 * 1000) return app - - -async def asgi_send_json(send, info, status=200, headers=None): - headers = headers or {} - await asgi_send( - send, - json.dumps(info), - status=status, - headers=headers, - content_type="application/json", - ) - - -async def asgi_send_html(send, html, status=200, headers=None): - headers = headers or {} - await asgi_send( - send, html, status=status, headers=headers, content_type="text/html" - ) - - -async def asgi_send(send, content, status, headers, content_type="text/plain"): - await asgi_start(send, status, headers, content_type) - await send({"type": "http.response.body", "body": content.encode("utf8")}) - - -async def asgi_start(send, status, headers, content_type="text/plain"): - # Remove any existing content-type header - headers = dict([(k, v) for k, v in headers.items() if k.lower() != "content-type"]) - headers["content-type"] = content_type - await send( - { - "type": "http.response.start", - "status": status, - "headers": [ - [key.encode("latin1"), value.encode("latin1")] - for key, value in headers.items() - ], - } - ) - - -def asgi_static(root_path, chunk_size=4096): - async def inner_static(scope, receive, send): - path = scope["url_route"]["kwargs"]["path"] - full_path = (Path(root_path) / path).absolute() - # Ensure full_path is within root_path to avoid weird "../" tricks - try: - full_path.relative_to(root_path) - except ValueError: - await asgi_send_html(send, "404", 404) - return - first = True - try: - async with aiofiles.open(full_path, mode="rb") as fp: - if first: - await asgi_start( - send, 200, {}, guess_type(str(full_path))[0] or "text/plain" - ) - first = False - more_body = True - while more_body: - chunk = await fp.read(chunk_size) - more_body = len(chunk) == chunk_size - await send( - { - "type": "http.response.body", - "body": chunk, - "more_body": more_body, - } - ) - except FileNotFoundError: - await asgi_send_html(send, "404", 404) - return - - return inner_static diff --git a/datasette/utils.py b/datasette/utils/__init__.py similarity index 100% rename from datasette/utils.py rename to datasette/utils/__init__.py diff --git a/datasette/utils/asgi.py b/datasette/utils/asgi.py new file mode 100644 index 00000000..14ade563 --- /dev/null +++ b/datasette/utils/asgi.py @@ -0,0 +1,244 @@ +import json +from mimetypes import guess_type +from sanic.views import HTTPMethodView +from pathlib import Path +import re +import aiofiles + + +class AsgiRouter: + def __init__(self, routes=None): + routes = routes or [] + self.routes = [ + # Compile any strings to regular expressions + ((re.compile(pattern) if isinstance(pattern, str) else pattern), view) + for pattern, view in routes + ] + + async def __call__(self, scope, receive, send): + for regex, view in self.routes: + match = regex.match(scope["path"]) + if match is not None: + new_scope = dict(scope, url_route={"kwargs": match.groupdict()}) + try: + return await view(new_scope, receive, send) + except Exception as exception: + return await self.handle_500(scope, receive, send, exception) + return await self.handle_404(scope, receive, send) + + async def handle_404(self, scope, receive, send): + await send( + { + "type": "http.response.start", + "status": 404, + "headers": [[b"content-type", b"text/html"]], + } + ) + await send({"type": "http.response.body", "body": b"

404

"}) + + async def handle_500(self, scope, receive, send, exception): + await send( + { + "type": "http.response.start", + "status": 404, + "headers": [[b"content-type", b"text/html"]], + } + ) + html = "

500

".format(escape(repr(exception))) + await send({"type": "http.response.body", "body": html.encode("utf8")}) + + +class AsgiView(HTTPMethodView): + @classmethod + def as_asgi(cls, *class_args, **class_kwargs): + async def view(scope, receive, send): + # Uses scope to create a Sanic-compatible request object, + # then dispatches that to self.get(...) or self.options(...) + # along with keyword arguments that were already tucked + # into scope["url_route"]["kwargs"] by the router + # https://channels.readthedocs.io/en/latest/topics/routing.html#urlrouter + path = scope.get("raw_path", scope["path"].encode("utf8")) + if scope["query_string"]: + path = path + b"?" + scope["query_string"] + request = SanicRequest( + path, + { + "Host": dict(scope.get("headers") or []) + .get(b"host", b"") + .decode("utf8") + }, + "1.1", + scope["method"], + None, + ) + + # TODO: Remove need for this + class Woo: + def get_extra_info(self, key): + return False + + request.app = Woo() + request.app.websocket_enabled = False + request.transport = Woo() + self = view.view_class(*class_args, **class_kwargs) + response = await self.dispatch_request( + request, **scope["url_route"]["kwargs"] + ) + if hasattr(response, "asgi_send"): + await response.asgi_send(send) + else: + await send( + { + "type": "http.response.start", + "status": response.status, + "headers": [ + [key.encode("utf-8"), value.encode("utf-8")] + for key, value in response.headers.items() + ], + } + ) + await send({"type": "http.response.body", "body": response.body}) + + view.view_class = cls + view.__doc__ = cls.__doc__ + view.__module__ = cls.__module__ + view.__name__ = cls.__name__ + return view + + +class AsgiStream: + def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"): + self.stream_fn = stream_fn + self.status = status + self.headers = headers or {} + self.content_type = content_type + + async def asgi_send(self, send): + # Remove any existing content-type header + headers = dict( + [(k, v) for k, v in self.headers.items() if k.lower() != "content-type"] + ) + headers["content-type"] = self.content_type + await send( + { + "type": "http.response.start", + "status": self.status, + "headers": [ + [key.encode("utf-8"), value.encode("utf-8")] + for key, value in headers.items() + ], + } + ) + w = AsgiWriter(send) + await self.stream_fn(w) + await send({"type": "http.response.body", "body": b""}) + + +class AsgiWriter: + def __init__(self, send): + self.send = send + + async def write(self, chunk): + await self.send( + { + "type": "http.response.body", + "body": chunk.encode("utf8"), + "more_body": True, + } + ) + + +async def asgi_send_json(send, info, status=200, headers=None): + headers = headers or {} + await asgi_send( + send, + json.dumps(info), + status=status, + headers=headers, + content_type="application/json", + ) + + +async def asgi_send_html(send, html, status=200, headers=None): + headers = headers or {} + await asgi_send( + send, html, status=status, headers=headers, content_type="text/html" + ) + + +async def asgi_send(send, content, status, headers, content_type="text/plain"): + await asgi_start(send, status, headers, content_type) + await send({"type": "http.response.body", "body": content.encode("utf8")}) + + +async def asgi_start(send, status, headers, content_type="text/plain"): + # Remove any existing content-type header + headers = dict([(k, v) for k, v in headers.items() if k.lower() != "content-type"]) + headers["content-type"] = content_type + await send( + { + "type": "http.response.start", + "status": status, + "headers": [ + [key.encode("latin1"), value.encode("latin1")] + for key, value in headers.items() + ], + } + ) + + +async def asgi_send_file( + send, filepath, filename=None, content_type=None, chunk_size=4096 +): + headers = {} + if filename: + headers["Content-Disposition"] = 'attachment; filename="{}"'.format(filename) + first = True + async with aiofiles.open(filepath, mode="rb") as fp: + if first: + await asgi_start( + send, + 200, + headers, + content_type or guess_type(str(filepath))[0] or "text/plain", + ) + first = False + more_body = True + while more_body: + chunk = await fp.read(chunk_size) + more_body = len(chunk) == chunk_size + await send( + {"type": "http.response.body", "body": chunk, "more_body": more_body} + ) + + +def asgi_static(root_path, chunk_size=4096, headers=None, content_type=None): + async def inner_static(scope, receive, send): + path = scope["url_route"]["kwargs"]["path"] + full_path = (Path(root_path) / path).absolute() + # Ensure full_path is within root_path to avoid weird "../" tricks + try: + full_path.relative_to(root_path) + except ValueError: + await asgi_send_html(send, "404", 404) + return + first = True + try: + await asgi_send_file(send, full_path, chunk_size=chunk_size) + except FileNotFoundError: + await asgi_send_html(send, "404", 404) + return + + return inner_static + + +class AsgiFileDownload: + def __init__( + self, filepath, filename=None, content_type="application/octet-stream" + ): + self.filepath = filepath + self.filename = filename + self.content_type = content_type + + async def asgi_send(self, send): + return await asgi_send_file(send, self.filepath, content_type=self.content_type) diff --git a/datasette/views/base.py b/datasette/views/base.py index 0b02a13b..a2d6571f 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -9,7 +9,6 @@ import jinja2 import pint from sanic import response from sanic.exceptions import NotFound -from sanic.views import HTTPMethodView from sanic.request import Request as SanicRequest from html import escape @@ -29,6 +28,7 @@ from datasette.utils import ( sqlite3, to_css_class, ) +from datasette.utils.asgi import AsgiStream, AsgiWriter, AsgiRouter, AsgiView ureg = pint.UnitRegistry() @@ -52,147 +52,6 @@ class DatasetteError(Exception): self.messagge_is_html = messagge_is_html -class AsgiRouter: - def __init__(self, routes=None): - routes = routes or [] - self.routes = [ - # Compile any strings to regular expressions - ((re.compile(pattern) if isinstance(pattern, str) else pattern), view) - for pattern, view in routes - ] - - async def __call__(self, scope, receive, send): - for regex, view in self.routes: - match = regex.match(scope["path"]) - if match is not None: - new_scope = dict(scope, url_route={"kwargs": match.groupdict()}) - try: - return await view(new_scope, receive, send) - except Exception as exception: - return await self.handle_500(scope, receive, send, exception) - return await self.handle_404(scope, receive, send) - - async def handle_404(self, scope, receive, send): - await send( - { - "type": "http.response.start", - "status": 404, - "headers": [[b"content-type", b"text/html"]], - } - ) - await send({"type": "http.response.body", "body": b"

404

"}) - - async def handle_500(self, scope, receive, send, exception): - await send( - { - "type": "http.response.start", - "status": 404, - "headers": [[b"content-type", b"text/html"]], - } - ) - html = "

500

".format(escape(repr(exception))) - await send({"type": "http.response.body", "body": html.encode("utf8")}) - - -class AsgiView(HTTPMethodView): - @classmethod - def as_asgi(cls, *class_args, **class_kwargs): - async def view(scope, receive, send): - # Uses scope to create a Sanic-compatible request object, - # then dispatches that to self.get(...) or self.options(...) - # along with keyword arguments that were already tucked - # into scope["url_route"]["kwargs"] by the router - # https://channels.readthedocs.io/en/latest/topics/routing.html#urlrouter - path = scope.get("raw_path", scope["path"].encode("utf8")) - if scope["query_string"]: - path = path + b"?" + scope["query_string"] - request = SanicRequest( - path, - { - "Host": dict(scope.get("headers") or []) - .get(b"host", b"") - .decode("utf8") - }, - "1.1", - scope["method"], - None, - ) - - class Woo: - def get_extra_info(self, key): - return False - - request.app = Woo() - request.app.websocket_enabled = False - request.transport = Woo() - self = view.view_class(*class_args, **class_kwargs) - response = await self.dispatch_request( - request, **scope["url_route"]["kwargs"] - ) - if hasattr(response, "asgi_send"): - await response.asgi_send(send) - else: - await send( - { - "type": "http.response.start", - "status": response.status, - "headers": [ - [key.encode("utf-8"), value.encode("utf-8")] - for key, value in response.headers.items() - ], - } - ) - await send({"type": "http.response.body", "body": response.body}) - - view.view_class = cls - view.__doc__ = cls.__doc__ - view.__module__ = cls.__module__ - view.__name__ = cls.__name__ - return view - - -class AsgiStream: - def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"): - self.stream_fn = stream_fn - self.status = status - self.headers = headers or {} - self.content_type = content_type - - async def asgi_send(self, send): - # Remove any existing content-type header - headers = dict( - [(k, v) for k, v in self.headers.items() if k.lower() != "content-type"] - ) - headers["content-type"] = self.content_type - await send( - { - "type": "http.response.start", - "status": self.status, - "headers": [ - [key.encode("utf-8"), value.encode("utf-8")] - for key, value in headers.items() - ], - } - ) - w = AsgiWriter(send) - await self.stream_fn(w) - await send({"type": "http.response.body", "body": b""}) - - -class AsgiWriter: - def __init__(self, send): - self.send = send - - async def write(self, chunk): - await self.send( - { - "type": "http.response.body", - "body": chunk.encode("utf8"), - "more_body": True, - } - ) - - class BaseView(AsgiView): ds = None diff --git a/datasette/views/database.py b/datasette/views/database.py index a5b606f1..4809fef0 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -3,8 +3,9 @@ import os from sanic import response from datasette.utils import to_css_class, validate_sql_select +from datasette.utils.asgi import AsgiFileDownload -from .base import DataView, DatasetteError +from .base import DatasetteError, DataView class DatabaseView(DataView): @@ -79,8 +80,8 @@ class DatabaseDownload(DataView): if not db.path: raise DatasetteError("Cannot download database", status=404) filepath = db.path - return await response.file_stream( + return AsgiFileDownload( filepath, filename=os.path.basename(filepath), - mime_type="application/octet-stream", + content_type="application/octet-stream", )