diff --git a/datasette/app.py b/datasette/app.py index 358081ef..f141ca68 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1229,7 +1229,7 @@ class Datasette: if query: return query - def _prepare_connection(self, conn, database): + def _prepare_connection(self, conn, database, *, run_plugin_hook=True): conn.row_factory = sqlite3.Row conn.text_factory = lambda x: str(x, "utf-8", "replace") if self.sqlite_extensions and database != INTERNAL_DB_NAME: @@ -1245,7 +1245,7 @@ class Datasette: if self.setting("cache_size_kb"): conn.execute(f"PRAGMA cache_size=-{self.setting('cache_size_kb')}") # pylint: disable=no-member - if database != INTERNAL_DB_NAME: + if database != INTERNAL_DB_NAME and run_plugin_hook: pm.hook.prepare_connection(conn=conn, database=database, datasette=self) # If self.crossdb and this is _memory, connect the first SQLITE_LIMIT_ATTACHED databases if self.crossdb and database == "_memory": diff --git a/datasette/database.py b/datasette/database.py index 66d50ffa..21d62a86 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -1,6 +1,7 @@ import asyncio import atexit from collections import namedtuple +from contextlib import asynccontextmanager import inspect import os from pathlib import Path @@ -131,7 +132,7 @@ class Database: else: return "db" - def connect(self, write=False): + def connect(self, write=False, track=True): extra_kwargs = {} if write: extra_kwargs["isolation_level"] = "IMMEDIATE" @@ -161,7 +162,8 @@ class Database: conn = sqlite3.connect( f"file:{self.path}{qs}", uri=True, check_same_thread=False, **extra_kwargs ) - self._all_file_connections.append(conn) + if track: + self._all_file_connections.append(conn) if self.is_temp_disk and not self._wal_enabled: conn.execute("PRAGMA journal_mode=WAL") self._wal_enabled = True @@ -478,17 +480,9 @@ class Database: future.add_done_callback(self._remove_pending_execute_future) return await asyncio.wrap_future(future) - async def execute( - self, - sql, - params=None, - truncate=False, - custom_time_limit=None, - page_size=None, - log_sql_errors=True, + def _make_sql_operation( + self, sql, params, truncate, custom_time_limit, page_size, log_sql_errors ): - """Executes sql against db_name in a thread""" - self._check_not_closed() page_size = page_size or self.ds.page_size def sql_operation_in_thread(conn): @@ -528,10 +522,72 @@ class Database: else: return Results(rows, False, cursor.description) + return sql_operation_in_thread + + async def execute( + self, + sql, + params=None, + truncate=False, + custom_time_limit=None, + page_size=None, + log_sql_errors=True, + ): + """Executes sql against db_name in a thread""" + self._check_not_closed() + sql_operation_in_thread = self._make_sql_operation( + sql, params, truncate, custom_time_limit, page_size, log_sql_errors + ) with trace("sql", database=self.name, sql=sql.strip(), params=params): results = await self.execute_fn(sql_operation_in_thread) return results + async def _execute_fn_on_connection(self, conn, fn): + """Run fn(conn) on the shared executor (or inline in non-threaded mode). + + The caller owns the connection's lifecycle; this method does not + cache or close it. Used by request_connection(). + """ + self._check_not_closed() + if self.ds.executor is None: + return fn(conn) + + def in_thread(): + return fn(conn) + + with self._pending_execute_futures_lock: + self._check_not_closed() + future = self.ds.executor.submit(in_thread) + self._pending_execute_futures.add(future) + future.add_done_callback(self._remove_pending_execute_future) + return await asyncio.wrap_future(future) + + @asynccontextmanager + async def request_connection(self, write=False, run_prepare_connection_hook=False): + """Open a fresh sqlite3 connection scoped to an ``async with`` block. + + Intended for short-lived per-request work — for example, installing + a ``set_authorizer`` callback derived from ``request.actor`` before + running queries. The connection is not added to the pool, not + cached, and is closed when the context exits. + + Pass ``run_prepare_connection_hook=True`` to opt into the + ``prepare_connection`` plugin hook; by default it is skipped so + these connections stay cheap. + """ + self._check_not_closed() + conn = self.connect(write=write, track=False) + self.ds._prepare_connection( + conn, self.name, run_plugin_hook=run_prepare_connection_hook + ) + try: + yield RequestConnection(self, conn) + finally: + try: + conn.close() + except Exception: + pass + @property def hash(self): if self.cached_hash is not None: @@ -931,6 +987,38 @@ def _deliver_write_result(task, result, exception): pass +class RequestConnection: + """Thin async wrapper around a single sqlite3.Connection. + + Yielded by :meth:`Database.request_connection`. Exposes ``execute`` and + ``execute_fn`` with the same semantics as :class:`Database`, but bound + to the underlying ``connection`` so a caller can attach per-request + state (e.g. ``set_authorizer``) without touching the pool. + """ + + def __init__(self, db, connection): + self._db = db + self.connection = connection + + async def execute_fn(self, fn): + return await self._db._execute_fn_on_connection(self.connection, fn) + + async def execute( + self, + sql, + params=None, + truncate=False, + custom_time_limit=None, + page_size=None, + log_sql_errors=True, + ): + fn = self._db._make_sql_operation( + sql, params, truncate, custom_time_limit, page_size, log_sql_errors + ) + with trace("sql", database=self._db.name, sql=sql.strip(), params=params): + return await self._db._execute_fn_on_connection(self.connection, fn) + + class QueryInterrupted(Exception): def __init__(self, e, sql, params): self.e = e diff --git a/tests/test_internals_database.py b/tests/test_internals_database.py index 75ae8d39..434c0e37 100644 --- a/tests/test_internals_database.py +++ b/tests/test_internals_database.py @@ -923,3 +923,156 @@ async def test_database_close_is_idempotent(tmpdir): # Second call should be a no-op, not raise db.close() ds._internal_database.close() + + +@pytest.mark.asyncio +async def test_request_connection_basic(db): + async with db.request_connection() as conn: + results = await conn.execute("select 1 + 1") + assert results.single_value() == 2 + + +@pytest.mark.asyncio +async def test_request_connection_exposes_raw_sqlite_connection(db): + async with db.request_connection() as conn: + assert isinstance(conn.connection, sqlite3.Connection) + # Direct execution on the raw connection works too + assert conn.connection.execute("select 1").fetchone()[0] == 1 + + +@pytest.mark.asyncio +async def test_request_connection_fresh_each_call(db): + async with db.request_connection() as conn_a: + async with db.request_connection() as conn_b: + assert conn_a.connection is not conn_b.connection + + +@pytest.mark.asyncio +async def test_request_connection_closes_on_exit(db): + async with db.request_connection() as conn: + raw = conn.connection + await conn.execute("select 1") + with pytest.raises(sqlite3.ProgrammingError): + raw.execute("select 1") + + +@pytest.mark.asyncio +async def test_request_connection_readonly_by_default(db): + async with db.request_connection() as conn: + with pytest.raises(sqlite3.OperationalError): + await conn.execute( + "create table should_not_exist (id integer primary key)" + ) + + +@pytest.mark.asyncio +async def test_request_connection_writable(): + ds = Datasette(memory=True) + db = ds.add_memory_database("test_request_connection_writable") + async with db.request_connection(write=True) as conn: + await conn.execute_fn( + lambda c: c.execute("create table t (id integer primary key)") + ) + await conn.execute_fn(lambda c: c.execute("insert into t (id) values (1)")) + row = await conn.execute_fn( + lambda c: c.execute("select id from t").fetchone() + ) + assert row[0] == 1 + + +@pytest.mark.asyncio +async def test_request_connection_authorizer_does_not_leak_to_pool(db): + def deny_all(action, *args): + return sqlite3.SQLITE_DENY + + async with db.request_connection() as conn: + conn.connection.set_authorizer(deny_all) + with pytest.raises(sqlite3.DatabaseError): + await conn.execute("select 1") + + # Plain pool execute should still succeed + assert (await db.execute("select 1")).single_value() == 1 + + +@pytest.mark.asyncio +async def test_request_connection_skips_prepare_connection_hook_by_default(db): + # The `fixtures` db has my_plugin loaded, which would register + # convert_units via the prepare_connection hook. The per-request + # connection should NOT have that function available by default. + async with db.request_connection() as conn: + with pytest.raises(sqlite3.OperationalError): + await conn.execute("select convert_units(100, 'm', 'ft')") + + +@pytest.mark.asyncio +async def test_request_connection_runs_prepare_connection_hook_when_opted_in(db): + async with db.request_connection(run_prepare_connection_hook=True) as conn: + result = await conn.execute("select convert_units(100, 'm', 'ft')") + assert result.single_value() == pytest.approx(328.0839) + + +@pytest.mark.asyncio +async def test_request_connection_row_factory_applied(db): + async with db.request_connection() as conn: + row = (await conn.execute("select 1 as one")).first() + assert isinstance(row, sqlite3.Row) + assert row["one"] == 1 + + +@pytest.mark.asyncio +async def test_request_connection_not_tracked_in_all_file_connections(db): + before = list(db._all_file_connections) + async with db.request_connection() as conn: + pass + assert db._all_file_connections == before + + +@pytest.mark.asyncio +async def test_request_connection_cleans_up_on_exception(db): + raw = None + with pytest.raises(RuntimeError, match="boom"): + async with db.request_connection() as conn: + raw = conn.connection + raise RuntimeError("boom") + with pytest.raises(sqlite3.ProgrammingError): + raw.execute("select 1") + + +@pytest.mark.asyncio +async def test_request_connection_execute_fn(db): + async with db.request_connection() as conn: + value = await conn.execute_fn( + lambda c: c.execute("select 41 + 1").fetchone()[0] + ) + assert value == 42 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("disable_threads", (False, True)) +async def test_request_connection_non_threaded(disable_threads): + if disable_threads: + ds = Datasette(memory=True, settings={"num_sql_threads": 0}) + else: + ds = Datasette(memory=True) + db = ds.add_memory_database("test_request_connection_non_threaded") + async with db.request_connection() as conn: + assert (await conn.execute("select 1")).single_value() == 1 + assert ( + await conn.execute_fn(lambda c: c.execute("select 2").fetchone()[0]) + ) == 2 + + +@pytest.mark.asyncio +async def test_request_connection_raises_after_database_closed(tmpdir): + path = str(tmpdir / "closed_req.db") + conn = sqlite3.connect(path) + conn.execute("create table t (id integer primary key)") + conn.close() + ds = Datasette([path]) + db = ds.get_database("closed_req") + await db.execute("select 1") + db.close() + with pytest.raises(DatasetteClosedError): + async with db.request_connection() as conn: + pass + ds._internal_database.close()