diff --git a/datasette/app.py b/datasette/app.py index 367f38f9..358081ef 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -2338,10 +2338,13 @@ class Datasette: if not database.is_mutable: await database.table_counts(limit=60 * 60 * 1000) + async def _close_on_shutdown(): + self.close() + asgi = CrossOriginProtectionMiddleware(DatasetteRouter(self, routes), self) if self.setting("trace_debug"): asgi = AsgiTracer(asgi) - asgi = AsgiLifespan(asgi) + asgi = AsgiLifespan(asgi, on_shutdown=[_close_on_shutdown]) asgi = AsgiRunOnFirstRequest(asgi, on_startup=[setup_db, self.invoke_startup]) for wrapper in pm.hook.asgi_wrapper(datasette=self): asgi = wrapper(asgi) diff --git a/tests/test_internals_datasette.py b/tests/test_internals_datasette.py index 5f773658..11463eda 100644 --- a/tests/test_internals_datasette.py +++ b/tests/test_internals_datasette.py @@ -256,6 +256,29 @@ async def test_datasette_close_raises_on_use(): await ds.get_internal_database().execute("select 1") +@pytest.mark.asyncio +async def test_asgi_lifespan_shutdown_closes_datasette(): + ds = Datasette(memory=True) + app = ds.app() + # Drive an ASGI lifespan: startup, then shutdown. + messages_sent = [] + inbox = [ + {"type": "lifespan.startup"}, + {"type": "lifespan.shutdown"}, + ] + + async def receive(): + return inbox.pop(0) + + async def send(message): + messages_sent.append(message) + + await app({"type": "lifespan"}, receive, send) + assert {"type": "lifespan.startup.complete"} in messages_sent + assert {"type": "lifespan.shutdown.complete"} in messages_sent + assert ds._closed + + @pytest.mark.asyncio async def test_datasette_close_continues_past_db_error(): # If one Database raises during close(), the others still get closed.