write_wrapper plugin hook for intercepting write operations (#2636)

* Implement write_wrapper plugin hook for intercepting database writes

Add a new `write_wrapper` plugin hook that lets plugins wrap write
operations with before/after logic using a generator-based context
manager pattern. The hook receives (datasette, database, request,
transaction) and returns a generator function that takes a conn,
yields once to let the write execute, and can run cleanup after.

The write result is sent back via `generator.send()` and exceptions
are thrown via `generator.throw()`, giving plugins full visibility.

Also adds `request=None` parameter to execute_write, execute_write_fn,
execute_write_script, and execute_write_many, and threads request
through all view-layer call sites (insert, upsert, update, delete,
drop, create table, canned queries).

* Add documentation for wrap_write hook, fix lint issues

Document the wrap_write plugin hook in plugin_hooks.rst with
parameter descriptions and two examples: a simple logging wrapper
and an advanced SQLite authorizer-based table protection pattern.

Also fix black formatting and remove unused variable flagged by ruff.

* Rename wrap_write hook to write_wrapper for consistency with asgi_wrapper
* Move write_wrapper docs to just below prepare_connection
* Refactor write_wrapper tests to use pytest.parametrize

Consolidate duplicate test cases: merge before/after tests for
execute_write_fn and execute_write into one parametrized test, and
merge three parameter-passing tests into one parametrized test.

Claude Code transcript: https://gisthost.github.io/?c4c12079434e69677e4aa8ac664b21b8/index.html
This commit is contained in:
Simon Willison 2026-02-09 13:20:33 -08:00 committed by GitHub
commit 80b7f987ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 604 additions and 13 deletions

View file

@ -130,25 +130,25 @@ class Database:
for connection in self._all_file_connections:
connection.close()
async def execute_write(self, sql, params=None, block=True):
async def execute_write(self, sql, params=None, block=True, request=None):
def _inner(conn):
return conn.execute(sql, params or [])
with trace("sql", database=self.name, sql=sql.strip(), params=params):
results = await self.execute_write_fn(_inner, block=block)
results = await self.execute_write_fn(_inner, block=block, request=request)
return results
async def execute_write_script(self, sql, block=True):
async def execute_write_script(self, sql, block=True, request=None):
def _inner(conn):
return conn.executescript(sql)
with trace("sql", database=self.name, sql=sql.strip(), executescript=True):
results = await self.execute_write_fn(
_inner, block=block, transaction=False
_inner, block=block, transaction=False, request=request
)
return results
async def execute_write_many(self, sql, params_seq, block=True):
async def execute_write_many(self, sql, params_seq, block=True, request=None):
def _inner(conn):
count = 0
@ -163,7 +163,9 @@ class Database:
with trace(
"sql", database=self.name, sql=sql.strip(), executemany=True
) as kwargs:
results, count = await self.execute_write_fn(_inner, block=block)
results, count = await self.execute_write_fn(
_inner, block=block, request=request
)
kwargs["count"] = count
return results
@ -187,7 +189,8 @@ class Database:
# Threaded mode - send to write thread
return await self._send_to_write_thread(fn, isolated_connection=True)
async def execute_write_fn(self, fn, block=True, transaction=True):
async def execute_write_fn(self, fn, block=True, transaction=True, request=None):
fn = self._wrap_fn_with_hooks(fn, request, transaction)
if self.ds.executor is None:
# non-threaded mode
if self._write_connection is None:
@ -203,6 +206,25 @@ class Database:
fn, block=block, transaction=transaction
)
def _wrap_fn_with_hooks(self, fn, request, transaction):
from .plugins import pm
wrappers = pm.hook.write_wrapper(
datasette=self.ds,
database=self.name,
request=request,
transaction=transaction,
)
wrappers = [w for w in wrappers if w is not None]
if not wrappers:
return fn
# Build the wrapped fn by nesting context manager generators.
# The first wrapper returned by pluggy is outermost.
original_fn = fn
for wrapper_factory in reversed(wrappers):
original_fn = _apply_write_wrapper(original_fn, wrapper_factory)
return original_fn
async def _send_to_write_thread(
self, fn, block=True, isolated_connection=False, transaction=True
):
@ -680,6 +702,47 @@ class Database:
return f"<Database: {self.name}{tags_str}>"
def _apply_write_wrapper(fn, wrapper_factory):
"""Apply a single write_wrapper context manager around fn.
``wrapper_factory`` is a callable that takes ``(conn)`` and returns a
generator that yields exactly once. Code before the yield runs before
``fn(conn)``, code after the yield runs after. The result of
``fn(conn)`` is sent into the generator via ``.send()``, and any
exception raised by ``fn(conn)`` is thrown via ``.throw()``.
"""
def wrapped(conn):
gen = wrapper_factory(conn)
# Advance to the yield point (run "before" code)
try:
next(gen)
except StopIteration:
# Generator didn't yield — just run fn unchanged
return fn(conn)
# Execute the actual write
try:
result = fn(conn)
except Exception:
# Throw exception into generator so it can handle it
try:
gen.throw(*sys.exc_info())
except StopIteration:
pass
# Re-raise the original exception
raise
else:
# Send the result back through the yield
try:
gen.send(result)
except StopIteration:
pass
return result
return wrapped
class WriteTask:
__slots__ = ("fn", "task_id", "reply_queue", "isolated_connection", "transaction")

View file

@ -220,3 +220,25 @@ def top_query(datasette, request, database, sql):
@hookspec
def top_canned_query(datasette, request, database, query_name):
"""HTML to include at the top of the canned query page"""
@hookspec
def write_wrapper(datasette, database, request, transaction):
"""Called when a write function is about to execute.
Return a generator function that accepts a ``conn`` argument.
The generator should ``yield`` exactly once: code before the
``yield`` runs before the write, code after the ``yield`` runs
after the write completes. The result of the write is sent
back through the ``yield``, so you can capture it with
``result = yield``.
If the write raises an exception, it is thrown into the generator
so you can handle it with a try/except around the ``yield``.
``request`` may be ``None`` for writes not originating from an
HTTP request. ``transaction`` is ``True`` if the write will
be wrapped in a transaction.
Return ``None`` to skip wrapping.
"""

View file

@ -466,7 +466,9 @@ class QueryView(View):
ok = None
redirect_url = None
try:
cursor = await db.execute_write(canned_query["sql"], params_for_query)
cursor = await db.execute_write(
canned_query["sql"], params_for_query, request=request
)
# success message can come from on_success_message or on_success_message_sql
message = None
message_type = datasette.INFO
@ -1119,7 +1121,7 @@ class TableCreateView(BaseView):
return table.schema
try:
schema = await db.execute_write_fn(create_table)
schema = await db.execute_write_fn(create_table, request=request)
except Exception as e:
return _error([str(e)])

View file

@ -245,7 +245,7 @@ class RowDeleteView(BaseView):
sqlite_utils.Database(conn)[resolved.table].delete(resolved.pk_values)
try:
await resolved.db.execute_write_fn(delete_row)
await resolved.db.execute_write_fn(delete_row, request=request)
except Exception as e:
return _error([str(e)], 500)
@ -305,7 +305,7 @@ class RowUpdateView(BaseView):
)
try:
await resolved.db.execute_write_fn(update_row)
await resolved.db.execute_write_fn(update_row, request=request)
except Exception as e:
return _error([str(e)], 400)

View file

@ -550,7 +550,7 @@ class TableInsertView(BaseView):
method_all(rows, **kwargs)
try:
rows = await db.execute_write_fn(insert_or_upsert_rows)
rows = await db.execute_write_fn(insert_or_upsert_rows, request=request)
except Exception as e:
return _error([str(e)])
result = {"ok": True}
@ -670,7 +670,7 @@ class TableDropView(BaseView):
def drop_table(conn):
sqlite_utils.Database(conn)[table_name].drop()
await db.execute_write_fn(drop_table)
await db.execute_write_fn(drop_table, request=request)
await self.ds.track_event(
DropTableEvent(
actor=request.actor, database=database_name, table=table_name

View file

@ -61,6 +61,92 @@ arguments and can be called like this::
Examples: `datasette-jellyfish <https://datasette.io/plugins/datasette-jellyfish>`__, `datasette-jq <https://datasette.io/plugins/datasette-jq>`__, `datasette-haversine <https://datasette.io/plugins/datasette-haversine>`__, `datasette-rure <https://datasette.io/plugins/datasette-rure>`__
.. _plugin_hook_write_wrapper:
write_wrapper(datasette, database, request, transaction)
--------------------------------------------------------
``datasette`` - :ref:`internals_datasette`
You can use this to access plugin configuration options via ``datasette.plugin_config(your_plugin_name)``.
``database`` - string
The name of the database being written to.
``request`` - :ref:`internals_request` or ``None``
The HTTP request that triggered this write, if available. This will be ``None`` for writes that do not originate from an HTTP request (e.g. writes triggered by plugins during startup).
``transaction`` - bool
``True`` if the write will be wrapped in a database transaction.
Return a generator function that accepts a ``conn`` argument (a SQLite connection object). The generator should ``yield`` exactly once. Code before the ``yield`` runs before the write function executes; code after the ``yield`` runs after it completes.
The result of the write function is sent back through the ``yield``, so you can capture it with ``result = yield``.
If the write function raises an exception, it is thrown into the generator so you can handle it with a ``try`` / ``except`` around the ``yield``.
Return ``None`` to skip wrapping for this particular write.
This example logs every write operation:
.. code-block:: python
from datasette import hookimpl
@hookimpl
def write_wrapper(datasette, database, request):
def wrapper(conn):
print(f"Before write to {database}")
result = yield
print(f"After write to {database}")
return wrapper
This more advanced example uses the SQLite authorizer callback to block writes to a specific table for non-admin users:
.. code-block:: python
import sqlite3
from datasette import hookimpl
WRITE_ACTIONS = (
sqlite3.SQLITE_INSERT,
sqlite3.SQLITE_UPDATE,
sqlite3.SQLITE_DELETE,
)
@hookimpl
def write_wrapper(datasette, database, request):
actor = None
if request:
actor = request.actor
if actor and actor.get("id") == "admin":
return None
def wrapper(conn):
def authorizer(
action, arg1, arg2, db_name, trigger
):
if (
action in WRITE_ACTIONS
and arg1 == "protected_table"
):
return sqlite3.SQLITE_DENY
return sqlite3.SQLITE_OK
conn.set_authorizer(authorizer)
try:
yield
finally:
conn.set_authorizer(None)
return wrapper
The ``conn`` object passed to the generator is the same connection that the write function will use. Because the generator and the write function execute together in a single call on the write thread, any state you set on the connection (authorizers, pragmas, temporary tables) is visible to the write and can be cleaned up afterwards.
When multiple plugins implement ``write_wrapper``, they are nested following pluggy's default calling convention.
.. _plugin_hook_prepare_jinja2_environment:
prepare_jinja2_environment(env, datasette)
@ -2249,3 +2335,4 @@ The plugin can then call ``datasette.track_event(...)`` to send a ``ban-user`` e
await datasette.track_event(
BanUserEvent(user={"id": 1, "username": "cleverbot"})
)

View file

@ -1524,6 +1524,36 @@ async def test_hook_register_events():
assert any(k.__name__ == "OneEvent" for k in datasette.event_classes)
@pytest.mark.asyncio
async def test_hook_write_wrapper():
datasette = Datasette(memory=True)
log = []
class WrapWritePlugin:
__name__ = "WrapWritePlugin"
@staticmethod
@hookimpl
def write_wrapper(datasette, database, request, transaction):
if database != "_memory":
return None
def wrapper(conn):
log.append("before")
yield
log.append("after")
return wrapper
pm.register(WrapWritePlugin(), name="WrapWritePluginTest")
try:
db = datasette.get_database("_memory")
await db.execute_write("create table t (id integer primary key)")
assert log == ["before", "after"]
finally:
pm.unregister(name="WrapWritePluginTest")
@pytest.mark.asyncio
async def test_hook_register_actions_view_collection():
datasette = Datasette(memory=True, plugins_dir=PLUGINS_DIR)

387
tests/test_write_wrapper.py Normal file
View file

@ -0,0 +1,387 @@
"""
Tests for the write_wrapper plugin hook.
"""
from datasette.app import Datasette
from datasette.hookspecs import hookimpl
from datasette.plugins import pm
import pytest
import time
@pytest.fixture
def datasette(tmp_path):
db_path = str(tmp_path / "test.db")
ds = Datasette([db_path])
return ds
@pytest.mark.asyncio
@pytest.mark.parametrize(
"use_execute_write",
(False, True),
ids=["execute_write_fn", "execute_write"],
)
async def test_write_wrapper_before_and_after(datasette, use_execute_write):
"""Test that code before and after yield both execute."""
log = []
class Plugin:
__name__ = "Plugin"
@staticmethod
@hookimpl
def write_wrapper(datasette, database, request, transaction):
def wrapper(conn):
log.append("before")
yield
log.append("after")
return wrapper
pm.register(Plugin(), name="test_before_after")
try:
db = datasette.get_database("test")
if use_execute_write:
await db.execute_write(
"create table if not exists t (id integer primary key)"
)
else:
await db.execute_write_fn(
lambda conn: conn.execute(
"create table if not exists t (id integer primary key)"
)
)
assert log == ["before", "after"]
finally:
pm.unregister(name="test_before_after")
@pytest.mark.asyncio
async def test_write_wrapper_receives_result_via_yield(datasette):
"""Test that the result of fn(conn) is sent back through yield."""
captured = {}
class Plugin:
__name__ = "Plugin"
@staticmethod
@hookimpl
def write_wrapper(datasette, database, request, transaction):
def wrapper(conn):
result = yield
captured["result"] = result
return wrapper
pm.register(Plugin(), name="test_result")
try:
db = datasette.get_database("test")
await db.execute_write_fn(
lambda conn: conn.execute(
"create table if not exists t2 (id integer primary key)"
)
)
assert "result" in captured
# Should be a sqlite3 Cursor
assert captured["result"] is not None
finally:
pm.unregister(name="test_result")
@pytest.mark.asyncio
async def test_write_wrapper_exception_thrown_into_generator(datasette):
"""Test that exceptions from fn(conn) are thrown into the generator."""
caught = {}
class Plugin:
__name__ = "Plugin"
@staticmethod
@hookimpl
def write_wrapper(datasette, database, request, transaction):
def wrapper(conn):
try:
yield
except Exception as e:
caught["error"] = e
return wrapper
pm.register(Plugin(), name="test_exception")
try:
db = datasette.get_database("test")
with pytest.raises(Exception, match="deliberate"):
await db.execute_write_fn(
lambda conn: (_ for _ in ()).throw(Exception("deliberate"))
)
assert "error" in caught
assert str(caught["error"]) == "deliberate"
finally:
pm.unregister(name="test_exception")
@pytest.mark.asyncio
async def test_write_wrapper_conn_is_usable(datasette):
"""Test that the conn passed to the wrapper can execute SQL."""
class Plugin:
__name__ = "Plugin"
@staticmethod
@hookimpl
def write_wrapper(datasette, database, request, transaction):
def wrapper(conn):
conn.execute("create table if not exists hook_log (msg text)")
conn.execute("insert into hook_log values ('before')")
yield
conn.execute("insert into hook_log values ('after')")
return wrapper
pm.register(Plugin(), name="test_conn")
try:
db = datasette.get_database("test")
await db.execute_write_fn(
lambda conn: conn.execute(
"create table if not exists t3 (id integer primary key)"
)
)
result = await db.execute("select msg from hook_log order by rowid")
messages = [row[0] for row in result.rows]
assert messages == ["before", "after"]
finally:
pm.unregister(name="test_conn")
@pytest.mark.asyncio
async def test_write_wrapper_multiple_plugins_nest(datasette):
"""Test that multiple write_wrapper plugins nest correctly."""
log = []
class PluginA:
__name__ = "PluginA"
@staticmethod
@hookimpl
def write_wrapper(datasette, database, request, transaction):
def wrapper(conn):
log.append("A-before")
yield
log.append("A-after")
return wrapper
class PluginB:
__name__ = "PluginB"
@staticmethod
@hookimpl
def write_wrapper(datasette, database, request, transaction):
def wrapper(conn):
log.append("B-before")
yield
log.append("B-after")
return wrapper
pm.register(PluginA(), name="PluginA")
pm.register(PluginB(), name="PluginB")
try:
db = datasette.get_database("test")
await db.execute_write_fn(
lambda conn: conn.execute(
"create table if not exists t4 (id integer primary key)"
)
)
assert set(log) == {"A-before", "A-after", "B-before", "B-after"}
# Verify proper nesting: each plugin's before/after should be
# symmetric around the write
a_before = log.index("A-before")
a_after = log.index("A-after")
b_before = log.index("B-before")
b_after = log.index("B-after")
if a_before < b_before:
assert a_after > b_after, "A is outer so A-after should come after B-after"
else:
assert b_after > a_after, "B is outer so B-after should come after A-after"
finally:
pm.unregister(name="PluginA")
pm.unregister(name="PluginB")
@pytest.mark.asyncio
async def test_write_wrapper_return_none_skips(datasette):
"""Test that returning None from write_wrapper means no wrapping."""
log = []
class Plugin:
__name__ = "Plugin"
@staticmethod
@hookimpl
def write_wrapper(datasette, database, request, transaction):
log.append("hook-called")
return None
pm.register(Plugin(), name="test_skip")
try:
db = datasette.get_database("test")
await db.execute_write_fn(
lambda conn: conn.execute(
"create table if not exists t5 (id integer primary key)"
)
)
assert log == ["hook-called"]
finally:
pm.unregister(name="test_skip")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"request_value,transaction_value,expected_request,expected_transaction",
(
("fake-request", True, "fake-request", True),
(None, True, None, True),
(None, False, None, False),
),
ids=["with-request", "request-none-by-default", "transaction-false"],
)
async def test_write_wrapper_hook_parameters(
datasette,
request_value,
transaction_value,
expected_request,
expected_transaction,
):
"""Test that request and transaction parameters are passed through."""
captured = {}
class Plugin:
__name__ = "Plugin"
@staticmethod
@hookimpl
def write_wrapper(datasette, database, request, transaction):
captured["request"] = request
captured["database"] = database
captured["transaction"] = transaction
pm.register(Plugin(), name="test_params")
try:
db = datasette.get_database("test")
kwargs = {"transaction": transaction_value}
if request_value is not None:
kwargs["request"] = request_value
await db.execute_write_fn(
lambda conn: conn.execute(
"create table if not exists t6 (id integer primary key)"
),
**kwargs,
)
assert captured["request"] == expected_request
assert captured["database"] == "test"
assert captured["transaction"] == expected_transaction
finally:
pm.unregister(name="test_params")
@pytest.mark.asyncio
async def test_write_wrapper_via_api(tmp_path):
"""Test that write_wrapper fires for API write operations."""
log = []
db_path = str(tmp_path / "test.db")
ds = Datasette([db_path], pdb=False)
ds.root_enabled = True
class Plugin:
__name__ = "Plugin"
@staticmethod
@hookimpl
def write_wrapper(datasette, database, request, transaction):
if database != "test":
return None
def wrapper(conn):
log.append("before")
yield
log.append("after")
return wrapper
pm.register(Plugin(), name="test_api")
try:
db = ds.get_database("test")
await db.execute_write(
"create table if not exists api_test (id integer primary key, name text)"
)
log.clear()
token = "dstok_{}".format(
ds.sign(
{"a": "root", "token": "dstok", "t": int(time.time())},
namespace="token",
)
)
response = await ds.client.post(
"/test/api_test/-/insert",
json={"row": {"name": "test"}, "return": True},
headers={
"Authorization": "Bearer {}".format(token),
"Content-Type": "application/json",
},
)
assert response.status_code == 201, response.json()
assert log == ["before", "after"]
finally:
pm.unregister(name="test_api")
@pytest.mark.asyncio
async def test_write_wrapper_change_group_pattern(datasette):
"""Test the motivating use case: activating a change group around a write."""
db = datasette.get_database("test")
await db.execute_write(
"create table if not exists groups (id integer primary key, current integer)"
)
await db.execute_write(
"create table if not exists data (id integer primary key, value text)"
)
await db.execute_write("insert into groups (id, current) values (1, null)")
class Plugin:
__name__ = "Plugin"
@staticmethod
@hookimpl
def write_wrapper(datasette, database, request, transaction):
if request and getattr(request, "group_id", None):
group_id = request.group_id
def wrapper(conn):
conn.execute(
"update groups set current = 1 where id = ?", [group_id]
)
yield
conn.execute("update groups set current = null where current = 1")
return wrapper
pm.register(Plugin(), name="test_change_group")
try:
class FakeRequest:
group_id = 1
await db.execute_write_fn(
lambda conn: conn.execute("insert into data (value) values ('test')"),
request=FakeRequest(),
)
result = await db.execute("select current from groups where id = 1")
assert result.rows[0][0] is None
finally:
pm.unregister(name="test_change_group")