Implemented ASGI lifespan #272

Also did a little bit of lint cleanup
This commit is contained in:
Simon Willison 2019-06-23 13:31:03 -07:00
commit 28c31b228d
3 changed files with 48 additions and 19 deletions

View file

@ -1,11 +1,9 @@
import asyncio
import collections
import hashlib
import json
import os
import sys
import threading
import time
import traceback
import urllib.parse
from concurrent import futures
@ -14,7 +12,6 @@ from pathlib import Path
import click
from markupsafe import Markup
from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader
from sanic import Sanic, response
from sanic.exceptions import InvalidUsage, NotFound
from .views.base import DatasetteError, ureg, AsgiRouter
@ -36,8 +33,14 @@ from .utils import (
sqlite_timelimit,
to_css_class,
)
from .utils.asgi import asgi_static, asgi_send_html, asgi_send_json, asgi_send_redirect
from .tracer import capture_traces, trace, AsgiTracer
from .utils.asgi import (
AsgiLifespan,
asgi_static,
asgi_send_html,
asgi_send_json,
asgi_send_redirect,
)
from .tracer import trace, AsgiTracer
from .plugins import pm, DEFAULT_PLUGINS
from .version import __version__
@ -553,7 +556,6 @@ class Datasette:
def app(self):
"Returns an ASGI app function that serves the whole of Datasette"
# TODO: re-implement ?_trace= mechanism, see class TracingSanic
default_templates = str(app_root / "datasette" / "templates")
template_paths = []
if self.template_dir:
@ -665,7 +667,6 @@ class Datasette:
async def handle_500(self, scope, receive, send, exception):
title = None
help = None
if isinstance(exception, NotFound):
status = 404
info = {}
@ -703,13 +704,10 @@ class Datasette:
send, template.render(info), status=status, headers=headers
)
app = DatasetteRouter(routes)
# First time server starts up, calculate table counts for immutable databases
# TODO: re-enable this mechanism
# @app.listener("before_server_start")
# async def setup_db(app, loop):
# for dbname, database in self.databases.items():
# if not database.is_mutable:
# await database.table_counts(limit=60 * 60 * 1000)
async def setup_db():
# First time server starts up, calculate table counts for immutable databases
for dbname, database in self.databases.items():
if not database.is_mutable:
await database.table_counts(limit=60 * 60 * 1000)
return AsgiTracer(app)
return AsgiLifespan(AsgiTracer(DatasetteRouter(routes)), on_startup=setup_db)

View file

@ -33,15 +33,15 @@ def trace(type, **kwargs):
start = time.time()
yield
end = time.time()
trace = {
trace_info = {
"type": type,
"start": start,
"end": end,
"duration_ms": (end - start) * 1000,
"traceback": traceback.format_list(traceback.extract_stack(limit=6)[:-3]),
}
trace.update(kwargs)
tracer.append(trace)
trace_info.update(kwargs)
tracer.append(trace_info)
@contextmanager

View file

@ -3,6 +3,7 @@ from mimetypes import guess_type
from sanic.views import HTTPMethodView
from sanic.request import Request as SanicRequest
from pathlib import Path
from html import escape
import re
import aiofiles
@ -51,6 +52,36 @@ class AsgiRouter:
await send({"type": "http.response.body", "body": html.encode("utf8")})
class AsgiLifespan:
def __init__(self, app, on_startup=None, on_shutdown=None):
print("Wrapping {}".format(app))
self.app = app
on_startup = on_startup or []
on_shutdown = on_shutdown or []
if not isinstance(on_startup or [], list):
on_startup = [on_startup]
if not isinstance(on_shutdown or [], list):
on_shutdown = [on_shutdown]
self.on_startup = on_startup
self.on_shutdown = on_shutdown
async def __call__(self, scope, receive, send):
if scope["type"] == "lifespan":
while True:
message = await receive()
if message["type"] == "lifespan.startup":
for fn in self.on_startup:
await fn()
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
for fn in self.on_shutdown:
await fn()
await send({"type": "lifespan.shutdown.complete"})
return
else:
await self.app(scope, receive, send)
class AsgiView(HTTPMethodView):
@classmethod
def as_asgi(cls, *class_args, **class_kwargs):