mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
Replace AsgiLifespan with AsgiRunOnFirstRequest, refs #1955
This commit is contained in:
parent
4ba8d57bb1
commit
96b3a86d7f
4 changed files with 22 additions and 49 deletions
|
|
@ -63,17 +63,14 @@ from .utils import (
|
|||
to_css_class,
|
||||
)
|
||||
from .utils.asgi import (
|
||||
AsgiLifespan,
|
||||
Base400,
|
||||
Forbidden,
|
||||
NotFound,
|
||||
Request,
|
||||
Response,
|
||||
AsgiRunOnFirstRequest,
|
||||
asgi_static,
|
||||
asgi_send,
|
||||
asgi_send_file,
|
||||
asgi_send_html,
|
||||
asgi_send_json,
|
||||
asgi_send_redirect,
|
||||
)
|
||||
from .utils.internal_db import init_internal_db, populate_schema_tables
|
||||
|
|
@ -1260,7 +1257,7 @@ class Datasette:
|
|||
|
||||
async def setup_db():
|
||||
# First time server starts up, calculate table counts for immutable databases
|
||||
for dbname, database in self.databases.items():
|
||||
for database in self.databases.values():
|
||||
if not database.is_mutable:
|
||||
await database.table_counts(limit=60 * 60 * 1000)
|
||||
|
||||
|
|
@ -1274,10 +1271,7 @@ class Datasette:
|
|||
)
|
||||
if self.setting("trace_debug"):
|
||||
asgi = AsgiTracer(asgi)
|
||||
asgi = AsgiLifespan(
|
||||
asgi,
|
||||
on_startup=setup_db,
|
||||
)
|
||||
asgi = AsgiRunOnFirstRequest(asgi, on_startup=[setup_db, self.invoke_startup])
|
||||
for wrapper in pm.hook.asgi_wrapper(datasette=self):
|
||||
asgi = wrapper(asgi)
|
||||
return asgi
|
||||
|
|
@ -1566,42 +1560,34 @@ class DatasetteClient:
|
|||
return path
|
||||
|
||||
async def get(self, path, **kwargs):
|
||||
await self.ds.invoke_startup()
|
||||
async with httpx.AsyncClient(app=self.app) as client:
|
||||
return await client.get(self._fix(path), **kwargs)
|
||||
|
||||
async def options(self, path, **kwargs):
|
||||
await self.ds.invoke_startup()
|
||||
async with httpx.AsyncClient(app=self.app) as client:
|
||||
return await client.options(self._fix(path), **kwargs)
|
||||
|
||||
async def head(self, path, **kwargs):
|
||||
await self.ds.invoke_startup()
|
||||
async with httpx.AsyncClient(app=self.app) as client:
|
||||
return await client.head(self._fix(path), **kwargs)
|
||||
|
||||
async def post(self, path, **kwargs):
|
||||
await self.ds.invoke_startup()
|
||||
async with httpx.AsyncClient(app=self.app) as client:
|
||||
return await client.post(self._fix(path), **kwargs)
|
||||
|
||||
async def put(self, path, **kwargs):
|
||||
await self.ds.invoke_startup()
|
||||
async with httpx.AsyncClient(app=self.app) as client:
|
||||
return await client.put(self._fix(path), **kwargs)
|
||||
|
||||
async def patch(self, path, **kwargs):
|
||||
await self.ds.invoke_startup()
|
||||
async with httpx.AsyncClient(app=self.app) as client:
|
||||
return await client.patch(self._fix(path), **kwargs)
|
||||
|
||||
async def delete(self, path, **kwargs):
|
||||
await self.ds.invoke_startup()
|
||||
async with httpx.AsyncClient(app=self.app) as client:
|
||||
return await client.delete(self._fix(path), **kwargs)
|
||||
|
||||
async def request(self, method, path, **kwargs):
|
||||
await self.ds.invoke_startup()
|
||||
avoid_path_rewrites = kwargs.pop("avoid_path_rewrites", None)
|
||||
async with httpx.AsyncClient(app=self.app) as client:
|
||||
return await client.request(
|
||||
|
|
|
|||
|
|
@ -135,35 +135,6 @@ class Request:
|
|||
return cls(scope, None)
|
||||
|
||||
|
||||
class AsgiLifespan:
|
||||
def __init__(self, app, on_startup=None, on_shutdown=None):
|
||||
self.app = app
|
||||
on_startup = on_startup or []
|
||||
on_shutdown = on_shutdown or []
|
||||
if not isinstance(on_startup or [], list):
|
||||
on_startup = [on_startup]
|
||||
if not isinstance(on_shutdown or [], list):
|
||||
on_shutdown = [on_shutdown]
|
||||
self.on_startup = on_startup
|
||||
self.on_shutdown = on_shutdown
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "lifespan":
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.startup":
|
||||
for fn in self.on_startup:
|
||||
await fn()
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
elif message["type"] == "lifespan.shutdown":
|
||||
for fn in self.on_shutdown:
|
||||
await fn()
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
return
|
||||
else:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
|
||||
class AsgiStream:
|
||||
def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"):
|
||||
self.stream_fn = stream_fn
|
||||
|
|
@ -428,3 +399,18 @@ class AsgiFileDownload:
|
|||
content_type=self.content_type,
|
||||
headers=self.headers,
|
||||
)
|
||||
|
||||
|
||||
class AsgiRunOnFirstRequest:
|
||||
def __init__(self, asgi, on_startup):
|
||||
assert isinstance(on_startup, list)
|
||||
self.asgi = asgi
|
||||
self.on_startup = on_startup
|
||||
self._started = False
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if not self._started:
|
||||
self._started = True
|
||||
for hook in self.on_startup:
|
||||
await hook()
|
||||
return await self.asgi(scope, receive, send)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue