diff --git a/datasette/database.py b/datasette/database.py index e3c4bfec..657adfa5 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -84,6 +84,8 @@ class Database: self._write_thread = None self._write_queue = None self._closed = False + self._pending_execute_futures = set() + self._pending_execute_futures_lock = threading.Lock() # These are used when in non-threaded mode: self._read_connection = None self._write_connection = None @@ -98,6 +100,10 @@ class Database: "Database {!r} has been closed".format(self.name) ) + def _remove_pending_execute_future(self, future): + with self._pending_execute_futures_lock: + self._pending_execute_futures.discard(future) + @property def cached_table_counts(self): if self._cached_table_counts is not None: @@ -170,7 +176,11 @@ class Database: """ if self._closed: return - self._closed = True + with self._pending_execute_futures_lock: + if self._closed: + return + self._closed = True + pending_execute_futures = tuple(self._pending_execute_futures) # Shut down the write thread, if any, via a sentinel. The thread # drains any writes already queued before the sentinel and then # closes its own write connection and returns. @@ -185,6 +195,11 @@ class Database: ) ) sys.stderr.flush() + for future in pending_execute_futures: + try: + future.result() + except Exception: + pass # Close anything still tracked in _all_file_connections for connection in self._all_file_connections: try: @@ -456,9 +471,12 @@ class Database: setattr(connections, self._thread_local_id, conn) return fn(conn) - return await asyncio.get_event_loop().run_in_executor( - self.ds.executor, in_thread - ) + with self._pending_execute_futures_lock: + self._check_not_closed() + future = self.ds.executor.submit(in_thread) + self._pending_execute_futures.add(future) + future.add_done_callback(self._remove_pending_execute_future) + return await asyncio.wrap_future(future) async def execute( self, diff --git a/datasette/version.py b/datasette/version.py index cf908bb2..898d388c 100644 --- a/datasette/version.py +++ b/datasette/version.py @@ -1,2 +1,2 @@ -__version__ = "1.0a28" +__version__ = "1.0a28.post1" __version_info__ = tuple(__version__.split(".")) diff --git a/tests/test_internals_datasette.py b/tests/test_internals_datasette.py index d58c9a29..3f867eb0 100644 --- a/tests/test_internals_datasette.py +++ b/tests/test_internals_datasette.py @@ -2,8 +2,11 @@ Tests for the datasette.app.Datasette class """ +import asyncio import dataclasses import os +import sqlite3 +import time from datasette import Context from datasette.app import Datasette, Database, ResourcesSQL from datasette.database import DatasetteClosedError @@ -256,6 +259,52 @@ async def test_datasette_close_raises_on_use(): await ds.get_internal_database().execute("select 1") +async def _datasette_with_sleeping_execute(tmp_path, sleep_ms=200): + db_path = tmp_path / "data.db" + internal_path = tmp_path / "internal.db" + sqlite3.connect(db_path).close() + ds = Datasette([str(db_path)], internal=str(internal_path)) + loop = asyncio.get_running_loop() + sql_started = asyncio.Event() + original_prepare_connection = ds._prepare_connection + + def prepare_connection(conn, name): + original_prepare_connection(conn, name) + + def sleep_ms(ms): + loop.call_soon_threadsafe(sql_started.set) + time.sleep(ms / 1000) + return ms + + conn.create_function("sleep_ms", 1, sleep_ms) + + ds._prepare_connection = prepare_connection + task = asyncio.create_task( + ds.get_database().execute( + f"select sleep_ms({sleep_ms})", custom_time_limit=1000 + ) + ) + await asyncio.wait_for(sql_started.wait(), timeout=5) + return ds, task + + +@pytest.mark.asyncio +async def test_datasette_close_waits_for_in_flight_execute(tmp_path): + ds, task = await _datasette_with_sleeping_execute(tmp_path) + ds.close() + results = await task + assert [tuple(row) for row in results.rows] == [(200,)] + + +@pytest.mark.asyncio +async def test_datasette_close_waits_for_cancelled_in_flight_execute(tmp_path): + ds, task = await _datasette_with_sleeping_execute(tmp_path) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + ds.close() + + @pytest.mark.asyncio async def test_asgi_lifespan_shutdown_closes_datasette(): ds = Datasette(memory=True)