Implemented ASGI lifespan #272

Also did a little bit of lint cleanup
This commit is contained in:
Simon Willison 2019-06-23 13:31:03 -07:00
commit 28c31b228d
3 changed files with 48 additions and 19 deletions

View file

@ -1,11 +1,9 @@
import asyncio import asyncio
import collections import collections
import hashlib import hashlib
import json
import os import os
import sys import sys
import threading import threading
import time
import traceback import traceback
import urllib.parse import urllib.parse
from concurrent import futures from concurrent import futures
@ -14,7 +12,6 @@ from pathlib import Path
import click import click
from markupsafe import Markup from markupsafe import Markup
from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader
from sanic import Sanic, response
from sanic.exceptions import InvalidUsage, NotFound from sanic.exceptions import InvalidUsage, NotFound
from .views.base import DatasetteError, ureg, AsgiRouter from .views.base import DatasetteError, ureg, AsgiRouter
@ -36,8 +33,14 @@ from .utils import (
sqlite_timelimit, sqlite_timelimit,
to_css_class, to_css_class,
) )
from .utils.asgi import asgi_static, asgi_send_html, asgi_send_json, asgi_send_redirect from .utils.asgi import (
from .tracer import capture_traces, trace, AsgiTracer AsgiLifespan,
asgi_static,
asgi_send_html,
asgi_send_json,
asgi_send_redirect,
)
from .tracer import trace, AsgiTracer
from .plugins import pm, DEFAULT_PLUGINS from .plugins import pm, DEFAULT_PLUGINS
from .version import __version__ from .version import __version__
@ -553,7 +556,6 @@ class Datasette:
def app(self): def app(self):
"Returns an ASGI app function that serves the whole of Datasette" "Returns an ASGI app function that serves the whole of Datasette"
# TODO: re-implement ?_trace= mechanism, see class TracingSanic
default_templates = str(app_root / "datasette" / "templates") default_templates = str(app_root / "datasette" / "templates")
template_paths = [] template_paths = []
if self.template_dir: if self.template_dir:
@ -665,7 +667,6 @@ class Datasette:
async def handle_500(self, scope, receive, send, exception): async def handle_500(self, scope, receive, send, exception):
title = None title = None
help = None
if isinstance(exception, NotFound): if isinstance(exception, NotFound):
status = 404 status = 404
info = {} info = {}
@ -703,13 +704,10 @@ class Datasette:
send, template.render(info), status=status, headers=headers send, template.render(info), status=status, headers=headers
) )
app = DatasetteRouter(routes) async def setup_db():
# First time server starts up, calculate table counts for immutable databases # First time server starts up, calculate table counts for immutable databases
# TODO: re-enable this mechanism for dbname, database in self.databases.items():
# @app.listener("before_server_start") if not database.is_mutable:
# async def setup_db(app, loop): await database.table_counts(limit=60 * 60 * 1000)
# for dbname, database in self.databases.items():
# if not database.is_mutable:
# await database.table_counts(limit=60 * 60 * 1000)
return AsgiTracer(app) return AsgiLifespan(AsgiTracer(DatasetteRouter(routes)), on_startup=setup_db)

View file

@ -33,15 +33,15 @@ def trace(type, **kwargs):
start = time.time() start = time.time()
yield yield
end = time.time() end = time.time()
trace = { trace_info = {
"type": type, "type": type,
"start": start, "start": start,
"end": end, "end": end,
"duration_ms": (end - start) * 1000, "duration_ms": (end - start) * 1000,
"traceback": traceback.format_list(traceback.extract_stack(limit=6)[:-3]), "traceback": traceback.format_list(traceback.extract_stack(limit=6)[:-3]),
} }
trace.update(kwargs) trace_info.update(kwargs)
tracer.append(trace) tracer.append(trace_info)
@contextmanager @contextmanager

View file

@ -3,6 +3,7 @@ from mimetypes import guess_type
from sanic.views import HTTPMethodView from sanic.views import HTTPMethodView
from sanic.request import Request as SanicRequest from sanic.request import Request as SanicRequest
from pathlib import Path from pathlib import Path
from html import escape
import re import re
import aiofiles import aiofiles
@ -51,6 +52,36 @@ class AsgiRouter:
await send({"type": "http.response.body", "body": html.encode("utf8")}) await send({"type": "http.response.body", "body": html.encode("utf8")})
class AsgiLifespan:
def __init__(self, app, on_startup=None, on_shutdown=None):
print("Wrapping {}".format(app))
self.app = app
on_startup = on_startup or []
on_shutdown = on_shutdown or []
if not isinstance(on_startup or [], list):
on_startup = [on_startup]
if not isinstance(on_shutdown or [], list):
on_shutdown = [on_shutdown]
self.on_startup = on_startup
self.on_shutdown = on_shutdown
async def __call__(self, scope, receive, send):
if scope["type"] == "lifespan":
while True:
message = await receive()
if message["type"] == "lifespan.startup":
for fn in self.on_startup:
await fn()
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
for fn in self.on_shutdown:
await fn()
await send({"type": "lifespan.shutdown.complete"})
return
else:
await self.app(scope, receive, send)
class AsgiView(HTTPMethodView): class AsgiView(HTTPMethodView):
@classmethod @classmethod
def as_asgi(cls, *class_args, **class_kwargs): def as_asgi(cls, *class_args, **class_kwargs):