From 312f41b0c28eea66c76ab8dfac11db76aaf0000a Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Mon, 30 Mar 2026 11:20:46 -0700 Subject: [PATCH] RenameTableEvent, plus write connection track_event() mechanism (#2682) * Add track_event callback to execute_write_fn and write_wrapper Allows write functions and write_wrapper generators to queue events during a write operation that are dispatched after successful commit. The fn or wrapper can optionally accept a `track_event` parameter (detected via call_with_supported_arguments). Events are discarded if the write raises an exception. Does not yet handle the block=False (non-blocking) case - events queued during non-blocking writes are currently silently discarded. Refs https://github.com/simonw/datasette/issues/2681 * Dispatch track_event events for non-blocking (block=False) writes Spawns a background asyncio task that awaits the write thread's reply queue and dispatches pending events after a successful non-blocking write. Events are still discarded if the write raises an exception. Refs https://github.com/simonw/datasette/issues/2681 * Warn that events won't fire for other processes Refs https://github.com/simonw/datasette/issues/2681#issuecomment-4157118662 --- datasette/database.py | 67 ++++++--- datasette/events.py | 58 ++++++++ datasette/hookspecs.py | 18 ++- docs/events.md | 2 + docs/internals.rst | 30 ++++ docs/plugin_hooks.rst | 4 +- docs/plugins.rst | 3 +- tests/test_write_wrapper.py | 265 ++++++++++++++++++++++++++++++++++++ 8 files changed, 423 insertions(+), 24 deletions(-) diff --git a/datasette/database.py b/datasette/database.py index fcf69c7f..ffbbebba 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -10,6 +10,7 @@ import uuid from .tracer import trace from .utils import ( + call_with_supported_arguments, detect_fts, detect_primary_keys, detect_spatialite, @@ -190,7 +191,12 @@ class Database: return await self._send_to_write_thread(fn, isolated_connection=True) async def execute_write_fn(self, fn, block=True, transaction=True, request=None): - fn = self._wrap_fn_with_hooks(fn, request, transaction) + pending_events = [] + + def track_event(event): + pending_events.append(event) + + fn = self._wrap_fn_with_hooks(fn, request, transaction, track_event) if self.ds.executor is None: # non-threaded mode if self._write_connection is None: @@ -198,17 +204,44 @@ class Database: self.ds._prepare_connection(self._write_connection, self.name) if transaction: with self._write_connection: - return fn(self._write_connection) + result = fn(self._write_connection) else: - return fn(self._write_connection) + result = fn(self._write_connection) else: - return await self._send_to_write_thread( + result = await self._send_to_write_thread( fn, block=block, transaction=transaction ) + if block: + for event in pending_events: + await self.ds.track_event(event) + else: + # For non-blocking writes, spawn a background task to + # dispatch events after the write thread completes + task_id, reply_queue = result - def _wrap_fn_with_hooks(self, fn, request, transaction): + async def _dispatch_events_after_write(): + write_result = await reply_queue.async_q.get() + if not isinstance(write_result, Exception): + for event in pending_events: + await self.ds.track_event(event) + + asyncio.ensure_future(_dispatch_events_after_write()) + result = task_id + return result + + def _wrap_fn_with_hooks(self, fn, request, transaction, track_event): from .plugins import pm + # Wrap fn so it receives track_event if its signature supports it + original_fn = fn + + def fn_with_track_event(conn): + return call_with_supported_arguments( + original_fn, conn=conn, track_event=track_event + ) + + fn = fn_with_track_event + wrappers = pm.hook.write_wrapper( datasette=self.ds, database=self.name, @@ -220,10 +253,9 @@ class Database: return fn # Build the wrapped fn by nesting context manager generators. # The first wrapper returned by pluggy is outermost. - original_fn = fn for wrapper_factory in reversed(wrappers): - original_fn = _apply_write_wrapper(original_fn, wrapper_factory) - return original_fn + fn = _apply_write_wrapper(fn, wrapper_factory, track_event) + return fn async def _send_to_write_thread( self, fn, block=True, isolated_connection=False, transaction=True @@ -250,7 +282,7 @@ class Database: else: return result else: - return task_id + return task_id, reply_queue def _execute_writes(self): # Infinite looping thread that protects the single write connection @@ -682,18 +714,21 @@ class Database: return f"" -def _apply_write_wrapper(fn, wrapper_factory): +def _apply_write_wrapper(fn, wrapper_factory, track_event): """Apply a single write_wrapper context manager around fn. - ``wrapper_factory`` is a callable that takes ``(conn)`` and returns a - generator that yields exactly once. Code before the yield runs before - ``fn(conn)``, code after the yield runs after. The result of - ``fn(conn)`` is sent into the generator via ``.send()``, and any - exception raised by ``fn(conn)`` is thrown via ``.throw()``. + ``wrapper_factory`` is a callable that takes ``(conn)`` and optionally + ``track_event``, and returns a generator that yields exactly once. + Code before the yield runs before ``fn(conn)``, code after the yield + runs after. The result of ``fn(conn)`` is sent into the generator + via ``.send()``, and any exception raised by ``fn(conn)`` is thrown + via ``.throw()``. """ def wrapped(conn): - gen = wrapper_factory(conn) + gen = call_with_supported_arguments( + wrapper_factory, conn=conn, track_event=track_event + ) # Advance to the yield point (run "before" code) try: next(gen) diff --git a/datasette/events.py b/datasette/events.py index 5cd5ba3d..e8786da9 100644 --- a/datasette/events.py +++ b/datasette/events.py @@ -199,6 +199,27 @@ class UpdateRowEvent(Event): pks: list +@dataclass +class RenameTableEvent(Event): + """ + Event name: ``rename-table`` + + A table has been renamed. + + :ivar database: The name of the database containing the renamed table. + :type database: str + :ivar old_table: The previous name of the table. + :type old_table: str + :ivar new_table: The new name of the table. + :type new_table: str + """ + + name = "rename-table" + database: str + old_table: str + new_table: str + + @dataclass class DeleteRowEvent(Event): """ @@ -219,6 +240,42 @@ class DeleteRowEvent(Event): pks: list +@hookimpl +def write_wrapper(datasette, database, request, transaction): + def wrapper(conn, track_event): + # Snapshot rootpage -> name before the write + before = { + row[1]: row[0] + for row in conn.execute( + "select name, rootpage from sqlite_master" + " where type='table' and rootpage != 0" + ).fetchall() + } + yield + # Snapshot rootpage -> name after the write + after = { + row[1]: row[0] + for row in conn.execute( + "select name, rootpage from sqlite_master" + " where type='table' and rootpage != 0" + ).fetchall() + } + # Detect renames: same rootpage, different name + for rootpage, old_name in before.items(): + new_name = after.get(rootpage) + if new_name and new_name != old_name: + track_event( + RenameTableEvent( + actor=request.actor if request else None, + database=database, + old_table=old_name, + new_table=new_name, + ) + ) + + return wrapper + + @hookimpl def register_events(): return [ @@ -227,6 +284,7 @@ def register_events(): CreateTableEvent, CreateTokenEvent, AlterTableEvent, + RenameTableEvent, DropTableEvent, InsertRowsEvent, UpsertRowsEvent, diff --git a/datasette/hookspecs.py b/datasette/hookspecs.py index 2ab9d0c5..7af9cbce 100644 --- a/datasette/hookspecs.py +++ b/datasette/hookspecs.py @@ -246,12 +246,18 @@ def register_token_handler(datasette): def write_wrapper(datasette, database, request, transaction): """Called when a write function is about to execute. - Return a generator function that accepts a ``conn`` argument. - The generator should ``yield`` exactly once: code before the - ``yield`` runs before the write, code after the ``yield`` runs - after the write completes. The result of the write is sent - back through the ``yield``, so you can capture it with - ``result = yield``. + Return a generator function that accepts a ``conn`` argument and + optionally a ``track_event`` argument. The generator should + ``yield`` exactly once: code before the ``yield`` runs before + the write, code after the ``yield`` runs after the write + completes. The result of the write is sent back through the + ``yield``, so you can capture it with ``result = yield``. + + If your generator accepts ``track_event``, you can call + ``track_event(event)`` to queue an event that will be dispatched + via ``datasette.track_event()`` after the write commits + successfully. Events are discarded if the write raises an + exception. If the write raises an exception, it is thrown into the generator so you can handle it with a try/except around the ``yield``. diff --git a/docs/events.md b/docs/events.md index 399317e9..f63d1893 100644 --- a/docs/events.md +++ b/docs/events.md @@ -5,6 +5,8 @@ Datasette includes a mechanism for tracking events that occur while the software The core Datasette application triggers events when certain things happen. This page describes those events. +Note that these events will *not* fire for changes made to a SQLite database by a process other than Datasette itself. + Plugins can listen for events using the {ref}`plugin_hook_track_event` plugin hook, which will be called with instances of the following classes - or additional classes {ref}`registered by other plugins `. ```{eval-rst} diff --git a/docs/internals.rst b/docs/internals.rst index 829b4dd4..3b65d57a 100644 --- a/docs/internals.rst +++ b/docs/internals.rst @@ -1739,6 +1739,36 @@ For example: except Exception as e: print("An error occurred:", e) +Your function can optionally accept a ``track_event`` parameter in addition to ``conn``. If it does, it will be passed a callable that can be used to queue events for dispatch after the write transaction commits successfully. Events queued this way are discarded if the write raises an exception. + +.. code-block:: python + + from datasette.events import AlterTableEvent + + + def my_write(conn, track_event): + before_schema = conn.execute( + "select sql from sqlite_master where name = 'my_table'" + ).fetchone()[0] + conn.execute( + "alter table my_table add column new_col text" + ) + after_schema = conn.execute( + "select sql from sqlite_master where name = 'my_table'" + ).fetchone()[0] + track_event( + AlterTableEvent( + actor=None, + database="mydb", + table="my_table", + before_schema=before_schema, + after_schema=after_schema, + ) + ) + + + await database.execute_write_fn(my_write) + The value returned from ``await database.execute_write_fn(...)`` will be the return value from your function. If your function raises an exception that exception will be propagated up to the ``await`` line. diff --git a/docs/plugin_hooks.rst b/docs/plugin_hooks.rst index fdc392cb..79b3e669 100644 --- a/docs/plugin_hooks.rst +++ b/docs/plugin_hooks.rst @@ -78,12 +78,14 @@ write_wrapper(datasette, database, request, transaction) ``transaction`` - bool ``True`` if the write will be wrapped in a database transaction. -Return a generator function that accepts a ``conn`` argument (a SQLite connection object). The generator should ``yield`` exactly once. Code before the ``yield`` runs before the write function executes; code after the ``yield`` runs after it completes. +Return a generator function that accepts a ``conn`` argument (a SQLite connection object) and optionally a ``track_event`` argument. The generator should ``yield`` exactly once. Code before the ``yield`` runs before the write function executes; code after the ``yield`` runs after it completes. The result of the write function is sent back through the ``yield``, so you can capture it with ``result = yield``. If the write function raises an exception, it is thrown into the generator so you can handle it with a ``try`` / ``except`` around the ``yield``. +If your generator accepts ``track_event``, you can call ``track_event(event)`` to queue an event that will be dispatched via :ref:`datasette.track_event() ` after the write commits successfully. Events are discarded if the write raises an exception. + Return ``None`` to skip wrapping for this particular write. This example logs every write operation: diff --git a/docs/plugins.rst b/docs/plugins.rst index 03cbedeb..eb7b06e1 100644 --- a/docs/plugins.rst +++ b/docs/plugins.rst @@ -261,7 +261,8 @@ If you run ``datasette plugins --all`` it will include default plugins that ship "templates": false, "version": null, "hooks": [ - "register_events" + "register_events", + "write_wrapper" ] }, { diff --git a/tests/test_write_wrapper.py b/tests/test_write_wrapper.py index 55e0461e..c2ceb344 100644 --- a/tests/test_write_wrapper.py +++ b/tests/test_write_wrapper.py @@ -2,7 +2,9 @@ Tests for the write_wrapper plugin hook. """ +from dataclasses import dataclass from datasette.app import Datasette +from datasette.events import Event from datasette.hookspecs import hookimpl from datasette.plugins import pm import pytest @@ -10,6 +12,12 @@ import sqlite3 import time +@dataclass +class DummyEvent(Event): + name = "dummy" + message: str + + @pytest.fixture def datasette(tmp_path): db_path = str(tmp_path / "test.db") @@ -477,3 +485,260 @@ async def test_write_wrapper_set_authorizer(datasette, actor, table, should_deny assert result.rows[0][0] == "test" finally: pm.unregister(name="test_set_authorizer") + + +# --- Tests for track_event callback --- + + +@pytest.fixture +def ds_with_event_tracking(tmp_path): + """Datasette instance that records tracked events and registers DummyEvent.""" + db_path = str(tmp_path / "test.db") + ds = Datasette([db_path]) + ds._tracked_events = [] + # Set event_classes directly to avoid needing invoke_startup + ds.event_classes = (DummyEvent,) + + async def recording_track_event(event): + ds._tracked_events.append(event) + + ds.track_event = recording_track_event + + yield ds + + +@pytest.mark.asyncio +async def test_track_event_in_write_fn(ds_with_event_tracking): + """fn(conn, track_event) can queue events that are dispatched after commit.""" + ds = ds_with_event_tracking + db = ds.get_database("test") + + def my_write(conn, track_event): + conn.execute("create table if not exists te1 (id integer primary key)") + track_event(DummyEvent(actor=None, message="hello")) + + await db.execute_write_fn(my_write) + assert len(ds._tracked_events) == 1 + assert ds._tracked_events[0].message == "hello" + + +@pytest.mark.asyncio +async def test_track_event_discarded_on_exception(ds_with_event_tracking): + """Events are discarded if the write fn raises an exception.""" + ds = ds_with_event_tracking + db = ds.get_database("test") + + def my_write(conn, track_event): + track_event(DummyEvent(actor=None, message="should not fire")) + raise ValueError("deliberate error") + + with pytest.raises(ValueError, match="deliberate"): + await db.execute_write_fn(my_write) + assert len(ds._tracked_events) == 0 + + +@pytest.mark.asyncio +async def test_track_event_existing_fn_signature_still_works(ds_with_event_tracking): + """Existing fn(conn) signatures continue to work without track_event.""" + ds = ds_with_event_tracking + db = ds.get_database("test") + + await db.execute_write_fn( + lambda conn: conn.execute( + "create table if not exists te2 (id integer primary key)" + ) + ) + # No events, no errors + assert len(ds._tracked_events) == 0 + + +@pytest.mark.asyncio +async def test_track_event_in_write_wrapper(ds_with_event_tracking): + """write_wrapper generator with (conn, track_event) can queue events.""" + ds = ds_with_event_tracking + db = ds.get_database("test") + + class Plugin: + __name__ = "Plugin" + + @staticmethod + @hookimpl + def write_wrapper(datasette, database, request, transaction): + def wrapper(conn, track_event): + track_event(DummyEvent(actor=None, message="from wrapper before")) + yield + track_event(DummyEvent(actor=None, message="from wrapper after")) + + return wrapper + + pm.register(Plugin(), name="test_track_wrapper") + try: + await db.execute_write_fn( + lambda conn: conn.execute( + "create table if not exists te3 (id integer primary key)" + ) + ) + assert len(ds._tracked_events) == 2 + assert ds._tracked_events[0].message == "from wrapper before" + assert ds._tracked_events[1].message == "from wrapper after" + finally: + pm.unregister(name="test_track_wrapper") + + +@pytest.mark.asyncio +async def test_track_event_shared_between_fn_and_wrapper(ds_with_event_tracking): + """Both fn and wrapper can queue events, all dispatched in order.""" + ds = ds_with_event_tracking + db = ds.get_database("test") + + class Plugin: + __name__ = "Plugin" + + @staticmethod + @hookimpl + def write_wrapper(datasette, database, request, transaction): + def wrapper(conn, track_event): + track_event(DummyEvent(actor=None, message="wrapper-before")) + yield + track_event(DummyEvent(actor=None, message="wrapper-after")) + + return wrapper + + pm.register(Plugin(), name="test_track_shared") + try: + + def my_write(conn, track_event): + conn.execute("create table if not exists te4 (id integer primary key)") + track_event(DummyEvent(actor=None, message="from-fn")) + + await db.execute_write_fn(my_write) + messages = [e.message for e in ds._tracked_events] + assert messages == ["wrapper-before", "from-fn", "wrapper-after"] + finally: + pm.unregister(name="test_track_shared") + + +@pytest.mark.asyncio +async def test_track_event_with_block_false(ds_with_event_tracking): + """Events are dispatched even when block=False (non-blocking writes).""" + ds = ds_with_event_tracking + db = ds.get_database("test") + + def my_write(conn, track_event): + conn.execute("create table if not exists te5 (id integer primary key)") + track_event(DummyEvent(actor=None, message="non-blocking")) + + task_id = await db.execute_write_fn(my_write, block=False) + assert task_id is not None + + # Give the background task time to complete + import asyncio + + for _ in range(50): + if ds._tracked_events: + break + await asyncio.sleep(0.01) + + assert len(ds._tracked_events) == 1 + assert ds._tracked_events[0].message == "non-blocking" + + +# --- Tests for RenameTableEvent detection --- + + +@pytest.fixture +def ds_for_rename(tmp_path): + """Datasette instance that records tracked events for rename detection tests.""" + from datasette.events import RenameTableEvent + + db_path = str(tmp_path / "test.db") + ds = Datasette([db_path]) + ds._tracked_events = [] + ds.event_classes = (RenameTableEvent,) + + async def recording_track_event(event): + ds._tracked_events.append(event) + + ds.track_event = recording_track_event + return ds + + +@pytest.mark.asyncio +async def test_rename_table_fires_event(ds_for_rename): + """Renaming a table via ALTER TABLE fires a RenameTableEvent.""" + from datasette.events import RenameTableEvent + + ds = ds_for_rename + db = ds.get_database("test") + + await db.execute_write("create table old_name (id integer primary key)") + + def rename(conn): + conn.execute("alter table old_name rename to new_name") + + await db.execute_write_fn(rename) + + rename_events = [e for e in ds._tracked_events if isinstance(e, RenameTableEvent)] + assert len(rename_events) == 1 + assert rename_events[0].old_table == "old_name" + assert rename_events[0].new_table == "new_name" + assert rename_events[0].database == "test" + + +@pytest.mark.asyncio +async def test_no_rename_event_for_regular_writes(ds_for_rename): + """Regular writes (CREATE, INSERT) do not fire RenameTableEvent.""" + from datasette.events import RenameTableEvent + + ds = ds_for_rename + db = ds.get_database("test") + + await db.execute_write("create table t (id integer primary key)") + await db.execute_write_fn(lambda conn: conn.execute("insert into t values (1)")) + + rename_events = [e for e in ds._tracked_events if isinstance(e, RenameTableEvent)] + assert len(rename_events) == 0 + + +@pytest.mark.asyncio +async def test_no_rename_event_on_rollback(ds_for_rename): + """RenameTableEvent is not fired if the write raises an exception.""" + from datasette.events import RenameTableEvent + + ds = ds_for_rename + db = ds.get_database("test") + + await db.execute_write("create table rollback_test (id integer primary key)") + + def rename_then_fail(conn): + conn.execute("alter table rollback_test rename to renamed") + raise ValueError("deliberate error") + + with pytest.raises(ValueError, match="deliberate"): + await db.execute_write_fn(rename_then_fail) + + rename_events = [e for e in ds._tracked_events if isinstance(e, RenameTableEvent)] + assert len(rename_events) == 0 + + +@pytest.mark.asyncio +async def test_multiple_renames_in_one_write(ds_for_rename): + """Multiple renames in a single write fire multiple RenameTableEvents.""" + from datasette.events import RenameTableEvent + + ds = ds_for_rename + db = ds.get_database("test") + + await db.execute_write("create table alpha (id integer primary key)") + await db.execute_write("create table beta (id integer primary key)") + + def rename_both(conn): + conn.execute("alter table alpha rename to alpha2") + conn.execute("alter table beta rename to beta2") + + await db.execute_write_fn(rename_both) + + rename_events = [e for e in ds._tracked_events if isinstance(e, RenameTableEvent)] + assert len(rename_events) == 2 + names = {(e.old_table, e.new_table) for e in rename_events} + assert names == {("alpha", "alpha2"), ("beta", "beta2")}