mirror of
https://github.com/simonw/datasette.git
synced 2026-05-28 12:56:18 +02:00
Add Database.request_connection() for per-request scoped connections
Opens a fresh sqlite3 connection scoped to an async-with block, intended for short-lived per-request work such as installing a set_authorizer callback derived from request.actor before running queries. The connection is not pooled, not cached, and closed when the context exits. - Database.request_connection(write=False, run_prepare_connection_hook=False) yields a RequestConnection with .connection, .execute() and .execute_fn() - Datasette._prepare_connection gains run_plugin_hook=True kwarg; the prepare_connection plugin hook is skipped by default for these connections to keep them cheap - Database.connect gains track=False to skip _all_file_connections tracking when the caller owns lifecycle - Database.execute body extracted into _make_sql_operation so both pooled execute() and RequestConnection.execute() share it https://claude.ai/code/session_01GdaNscbub6d2MPaUANrhJj
This commit is contained in:
parent
40e78e0927
commit
e34aac77a9
3 changed files with 255 additions and 14 deletions
|
|
@ -1229,7 +1229,7 @@ class Datasette:
|
|||
if query:
|
||||
return query
|
||||
|
||||
def _prepare_connection(self, conn, database):
|
||||
def _prepare_connection(self, conn, database, *, run_plugin_hook=True):
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.text_factory = lambda x: str(x, "utf-8", "replace")
|
||||
if self.sqlite_extensions and database != INTERNAL_DB_NAME:
|
||||
|
|
@ -1245,7 +1245,7 @@ class Datasette:
|
|||
if self.setting("cache_size_kb"):
|
||||
conn.execute(f"PRAGMA cache_size=-{self.setting('cache_size_kb')}")
|
||||
# pylint: disable=no-member
|
||||
if database != INTERNAL_DB_NAME:
|
||||
if database != INTERNAL_DB_NAME and run_plugin_hook:
|
||||
pm.hook.prepare_connection(conn=conn, database=database, datasette=self)
|
||||
# If self.crossdb and this is _memory, connect the first SQLITE_LIMIT_ATTACHED databases
|
||||
if self.crossdb and database == "_memory":
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import atexit
|
||||
from collections import namedtuple
|
||||
from contextlib import asynccontextmanager
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
|
@ -131,7 +132,7 @@ class Database:
|
|||
else:
|
||||
return "db"
|
||||
|
||||
def connect(self, write=False):
|
||||
def connect(self, write=False, track=True):
|
||||
extra_kwargs = {}
|
||||
if write:
|
||||
extra_kwargs["isolation_level"] = "IMMEDIATE"
|
||||
|
|
@ -161,7 +162,8 @@ class Database:
|
|||
conn = sqlite3.connect(
|
||||
f"file:{self.path}{qs}", uri=True, check_same_thread=False, **extra_kwargs
|
||||
)
|
||||
self._all_file_connections.append(conn)
|
||||
if track:
|
||||
self._all_file_connections.append(conn)
|
||||
if self.is_temp_disk and not self._wal_enabled:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
self._wal_enabled = True
|
||||
|
|
@ -478,17 +480,9 @@ class Database:
|
|||
future.add_done_callback(self._remove_pending_execute_future)
|
||||
return await asyncio.wrap_future(future)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
sql,
|
||||
params=None,
|
||||
truncate=False,
|
||||
custom_time_limit=None,
|
||||
page_size=None,
|
||||
log_sql_errors=True,
|
||||
def _make_sql_operation(
|
||||
self, sql, params, truncate, custom_time_limit, page_size, log_sql_errors
|
||||
):
|
||||
"""Executes sql against db_name in a thread"""
|
||||
self._check_not_closed()
|
||||
page_size = page_size or self.ds.page_size
|
||||
|
||||
def sql_operation_in_thread(conn):
|
||||
|
|
@ -528,10 +522,72 @@ class Database:
|
|||
else:
|
||||
return Results(rows, False, cursor.description)
|
||||
|
||||
return sql_operation_in_thread
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
sql,
|
||||
params=None,
|
||||
truncate=False,
|
||||
custom_time_limit=None,
|
||||
page_size=None,
|
||||
log_sql_errors=True,
|
||||
):
|
||||
"""Executes sql against db_name in a thread"""
|
||||
self._check_not_closed()
|
||||
sql_operation_in_thread = self._make_sql_operation(
|
||||
sql, params, truncate, custom_time_limit, page_size, log_sql_errors
|
||||
)
|
||||
with trace("sql", database=self.name, sql=sql.strip(), params=params):
|
||||
results = await self.execute_fn(sql_operation_in_thread)
|
||||
return results
|
||||
|
||||
async def _execute_fn_on_connection(self, conn, fn):
|
||||
"""Run fn(conn) on the shared executor (or inline in non-threaded mode).
|
||||
|
||||
The caller owns the connection's lifecycle; this method does not
|
||||
cache or close it. Used by request_connection().
|
||||
"""
|
||||
self._check_not_closed()
|
||||
if self.ds.executor is None:
|
||||
return fn(conn)
|
||||
|
||||
def in_thread():
|
||||
return fn(conn)
|
||||
|
||||
with self._pending_execute_futures_lock:
|
||||
self._check_not_closed()
|
||||
future = self.ds.executor.submit(in_thread)
|
||||
self._pending_execute_futures.add(future)
|
||||
future.add_done_callback(self._remove_pending_execute_future)
|
||||
return await asyncio.wrap_future(future)
|
||||
|
||||
@asynccontextmanager
|
||||
async def request_connection(self, write=False, run_prepare_connection_hook=False):
|
||||
"""Open a fresh sqlite3 connection scoped to an ``async with`` block.
|
||||
|
||||
Intended for short-lived per-request work — for example, installing
|
||||
a ``set_authorizer`` callback derived from ``request.actor`` before
|
||||
running queries. The connection is not added to the pool, not
|
||||
cached, and is closed when the context exits.
|
||||
|
||||
Pass ``run_prepare_connection_hook=True`` to opt into the
|
||||
``prepare_connection`` plugin hook; by default it is skipped so
|
||||
these connections stay cheap.
|
||||
"""
|
||||
self._check_not_closed()
|
||||
conn = self.connect(write=write, track=False)
|
||||
self.ds._prepare_connection(
|
||||
conn, self.name, run_plugin_hook=run_prepare_connection_hook
|
||||
)
|
||||
try:
|
||||
yield RequestConnection(self, conn)
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@property
|
||||
def hash(self):
|
||||
if self.cached_hash is not None:
|
||||
|
|
@ -931,6 +987,38 @@ def _deliver_write_result(task, result, exception):
|
|||
pass
|
||||
|
||||
|
||||
class RequestConnection:
|
||||
"""Thin async wrapper around a single sqlite3.Connection.
|
||||
|
||||
Yielded by :meth:`Database.request_connection`. Exposes ``execute`` and
|
||||
``execute_fn`` with the same semantics as :class:`Database`, but bound
|
||||
to the underlying ``connection`` so a caller can attach per-request
|
||||
state (e.g. ``set_authorizer``) without touching the pool.
|
||||
"""
|
||||
|
||||
def __init__(self, db, connection):
|
||||
self._db = db
|
||||
self.connection = connection
|
||||
|
||||
async def execute_fn(self, fn):
|
||||
return await self._db._execute_fn_on_connection(self.connection, fn)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
sql,
|
||||
params=None,
|
||||
truncate=False,
|
||||
custom_time_limit=None,
|
||||
page_size=None,
|
||||
log_sql_errors=True,
|
||||
):
|
||||
fn = self._db._make_sql_operation(
|
||||
sql, params, truncate, custom_time_limit, page_size, log_sql_errors
|
||||
)
|
||||
with trace("sql", database=self._db.name, sql=sql.strip(), params=params):
|
||||
return await self._db._execute_fn_on_connection(self.connection, fn)
|
||||
|
||||
|
||||
class QueryInterrupted(Exception):
|
||||
def __init__(self, e, sql, params):
|
||||
self.e = e
|
||||
|
|
|
|||
|
|
@ -923,3 +923,156 @@ async def test_database_close_is_idempotent(tmpdir):
|
|||
# Second call should be a no-op, not raise
|
||||
db.close()
|
||||
ds._internal_database.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_connection_basic(db):
|
||||
async with db.request_connection() as conn:
|
||||
results = await conn.execute("select 1 + 1")
|
||||
assert results.single_value() == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_connection_exposes_raw_sqlite_connection(db):
|
||||
async with db.request_connection() as conn:
|
||||
assert isinstance(conn.connection, sqlite3.Connection)
|
||||
# Direct execution on the raw connection works too
|
||||
assert conn.connection.execute("select 1").fetchone()[0] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_connection_fresh_each_call(db):
|
||||
async with db.request_connection() as conn_a:
|
||||
async with db.request_connection() as conn_b:
|
||||
assert conn_a.connection is not conn_b.connection
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_connection_closes_on_exit(db):
|
||||
async with db.request_connection() as conn:
|
||||
raw = conn.connection
|
||||
await conn.execute("select 1")
|
||||
with pytest.raises(sqlite3.ProgrammingError):
|
||||
raw.execute("select 1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_connection_readonly_by_default(db):
|
||||
async with db.request_connection() as conn:
|
||||
with pytest.raises(sqlite3.OperationalError):
|
||||
await conn.execute(
|
||||
"create table should_not_exist (id integer primary key)"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_connection_writable():
|
||||
ds = Datasette(memory=True)
|
||||
db = ds.add_memory_database("test_request_connection_writable")
|
||||
async with db.request_connection(write=True) as conn:
|
||||
await conn.execute_fn(
|
||||
lambda c: c.execute("create table t (id integer primary key)")
|
||||
)
|
||||
await conn.execute_fn(lambda c: c.execute("insert into t (id) values (1)"))
|
||||
row = await conn.execute_fn(
|
||||
lambda c: c.execute("select id from t").fetchone()
|
||||
)
|
||||
assert row[0] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_connection_authorizer_does_not_leak_to_pool(db):
|
||||
def deny_all(action, *args):
|
||||
return sqlite3.SQLITE_DENY
|
||||
|
||||
async with db.request_connection() as conn:
|
||||
conn.connection.set_authorizer(deny_all)
|
||||
with pytest.raises(sqlite3.DatabaseError):
|
||||
await conn.execute("select 1")
|
||||
|
||||
# Plain pool execute should still succeed
|
||||
assert (await db.execute("select 1")).single_value() == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_connection_skips_prepare_connection_hook_by_default(db):
|
||||
# The `fixtures` db has my_plugin loaded, which would register
|
||||
# convert_units via the prepare_connection hook. The per-request
|
||||
# connection should NOT have that function available by default.
|
||||
async with db.request_connection() as conn:
|
||||
with pytest.raises(sqlite3.OperationalError):
|
||||
await conn.execute("select convert_units(100, 'm', 'ft')")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_connection_runs_prepare_connection_hook_when_opted_in(db):
|
||||
async with db.request_connection(run_prepare_connection_hook=True) as conn:
|
||||
result = await conn.execute("select convert_units(100, 'm', 'ft')")
|
||||
assert result.single_value() == pytest.approx(328.0839)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_connection_row_factory_applied(db):
|
||||
async with db.request_connection() as conn:
|
||||
row = (await conn.execute("select 1 as one")).first()
|
||||
assert isinstance(row, sqlite3.Row)
|
||||
assert row["one"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_connection_not_tracked_in_all_file_connections(db):
|
||||
before = list(db._all_file_connections)
|
||||
async with db.request_connection() as conn:
|
||||
pass
|
||||
assert db._all_file_connections == before
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_connection_cleans_up_on_exception(db):
|
||||
raw = None
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
async with db.request_connection() as conn:
|
||||
raw = conn.connection
|
||||
raise RuntimeError("boom")
|
||||
with pytest.raises(sqlite3.ProgrammingError):
|
||||
raw.execute("select 1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_connection_execute_fn(db):
|
||||
async with db.request_connection() as conn:
|
||||
value = await conn.execute_fn(
|
||||
lambda c: c.execute("select 41 + 1").fetchone()[0]
|
||||
)
|
||||
assert value == 42
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("disable_threads", (False, True))
|
||||
async def test_request_connection_non_threaded(disable_threads):
|
||||
if disable_threads:
|
||||
ds = Datasette(memory=True, settings={"num_sql_threads": 0})
|
||||
else:
|
||||
ds = Datasette(memory=True)
|
||||
db = ds.add_memory_database("test_request_connection_non_threaded")
|
||||
async with db.request_connection() as conn:
|
||||
assert (await conn.execute("select 1")).single_value() == 1
|
||||
assert (
|
||||
await conn.execute_fn(lambda c: c.execute("select 2").fetchone()[0])
|
||||
) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_connection_raises_after_database_closed(tmpdir):
|
||||
path = str(tmpdir / "closed_req.db")
|
||||
conn = sqlite3.connect(path)
|
||||
conn.execute("create table t (id integer primary key)")
|
||||
conn.close()
|
||||
ds = Datasette([path])
|
||||
db = ds.get_database("closed_req")
|
||||
await db.execute("select 1")
|
||||
db.close()
|
||||
with pytest.raises(DatasetteClosedError):
|
||||
async with db.request_connection() as conn:
|
||||
pass
|
||||
ds._internal_database.close()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue