From 28c31b228d93d22facff728d7199bcf9bc1310d4 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 23 Jun 2019 13:31:03 -0700 Subject: [PATCH] Implemented ASGI lifespan #272 Also did a little bit of lint cleanup --- datasette/app.py | 30 ++++++++++++++---------------- datasette/tracer.py | 6 +++--- datasette/utils/asgi.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 19 deletions(-) diff --git a/datasette/app.py b/datasette/app.py index 4c85a78e..4552d9d4 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1,11 +1,9 @@ import asyncio import collections import hashlib -import json import os import sys import threading -import time import traceback import urllib.parse from concurrent import futures @@ -14,7 +12,6 @@ from pathlib import Path import click from markupsafe import Markup from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader -from sanic import Sanic, response from sanic.exceptions import InvalidUsage, NotFound from .views.base import DatasetteError, ureg, AsgiRouter @@ -36,8 +33,14 @@ from .utils import ( sqlite_timelimit, to_css_class, ) -from .utils.asgi import asgi_static, asgi_send_html, asgi_send_json, asgi_send_redirect -from .tracer import capture_traces, trace, AsgiTracer +from .utils.asgi import ( + AsgiLifespan, + asgi_static, + asgi_send_html, + asgi_send_json, + asgi_send_redirect, +) +from .tracer import trace, AsgiTracer from .plugins import pm, DEFAULT_PLUGINS from .version import __version__ @@ -553,7 +556,6 @@ class Datasette: def app(self): "Returns an ASGI app function that serves the whole of Datasette" - # TODO: re-implement ?_trace= mechanism, see class TracingSanic default_templates = str(app_root / "datasette" / "templates") template_paths = [] if self.template_dir: @@ -665,7 +667,6 @@ class Datasette: async def handle_500(self, scope, receive, send, exception): title = None - help = None if isinstance(exception, NotFound): status = 404 info = {} @@ -703,13 +704,10 @@ class Datasette: send, template.render(info), status=status, headers=headers ) - app = DatasetteRouter(routes) - # First time server starts up, calculate table counts for immutable databases - # TODO: re-enable this mechanism - # @app.listener("before_server_start") - # async def setup_db(app, loop): - # for dbname, database in self.databases.items(): - # if not database.is_mutable: - # await database.table_counts(limit=60 * 60 * 1000) + async def setup_db(): + # First time server starts up, calculate table counts for immutable databases + for dbname, database in self.databases.items(): + if not database.is_mutable: + await database.table_counts(limit=60 * 60 * 1000) - return AsgiTracer(app) + return AsgiLifespan(AsgiTracer(DatasetteRouter(routes)), on_startup=setup_db) diff --git a/datasette/tracer.py b/datasette/tracer.py index 4a46f1e6..e46a6fda 100644 --- a/datasette/tracer.py +++ b/datasette/tracer.py @@ -33,15 +33,15 @@ def trace(type, **kwargs): start = time.time() yield end = time.time() - trace = { + trace_info = { "type": type, "start": start, "end": end, "duration_ms": (end - start) * 1000, "traceback": traceback.format_list(traceback.extract_stack(limit=6)[:-3]), } - trace.update(kwargs) - tracer.append(trace) + trace_info.update(kwargs) + tracer.append(trace_info) @contextmanager diff --git a/datasette/utils/asgi.py b/datasette/utils/asgi.py index 2ae8ab6e..be53627f 100644 --- a/datasette/utils/asgi.py +++ b/datasette/utils/asgi.py @@ -3,6 +3,7 @@ from mimetypes import guess_type from sanic.views import HTTPMethodView from sanic.request import Request as SanicRequest from pathlib import Path +from html import escape import re import aiofiles @@ -51,6 +52,36 @@ class AsgiRouter: await send({"type": "http.response.body", "body": html.encode("utf8")}) +class AsgiLifespan: + def __init__(self, app, on_startup=None, on_shutdown=None): + print("Wrapping {}".format(app)) + self.app = app + on_startup = on_startup or [] + on_shutdown = on_shutdown or [] + if not isinstance(on_startup or [], list): + on_startup = [on_startup] + if not isinstance(on_shutdown or [], list): + on_shutdown = [on_shutdown] + self.on_startup = on_startup + self.on_shutdown = on_shutdown + + async def __call__(self, scope, receive, send): + if scope["type"] == "lifespan": + while True: + message = await receive() + if message["type"] == "lifespan.startup": + for fn in self.on_startup: + await fn() + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + for fn in self.on_shutdown: + await fn() + await send({"type": "lifespan.shutdown.complete"}) + return + else: + await self.app(scope, receive, send) + + class AsgiView(HTTPMethodView): @classmethod def as_asgi(cls, *class_args, **class_kwargs):