""" Tests for the write_wrapper plugin hook. """ from dataclasses import dataclass from datasette.app import Datasette from datasette.events import Event from datasette.hookspecs import hookimpl from datasette.plugins import pm import pytest import sqlite3 import time @dataclass class DummyEvent(Event): name = "dummy" message: str @pytest.fixture def datasette(tmp_path): db_path = str(tmp_path / "test.db") ds = Datasette([db_path]) return ds @pytest.mark.asyncio @pytest.mark.parametrize( "use_execute_write", (False, True), ids=["execute_write_fn", "execute_write"], ) async def test_write_wrapper_before_and_after(datasette, use_execute_write): """Test that code before and after yield both execute.""" log = [] class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): def wrapper(conn): log.append("before") yield log.append("after") return wrapper pm.register(Plugin(), name="test_before_after") try: db = datasette.get_database("test") if use_execute_write: await db.execute_write( "create table if not exists t (id integer primary key)" ) else: await db.execute_write_fn( lambda conn: conn.execute( "create table if not exists t (id integer primary key)" ) ) assert log == ["before", "after"] finally: pm.unregister(name="test_before_after") @pytest.mark.asyncio async def test_write_wrapper_receives_result_via_yield(datasette): """Test that the result of fn(conn) is sent back through yield.""" captured = {} class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): def wrapper(conn): result = yield captured["result"] = result return wrapper pm.register(Plugin(), name="test_result") try: db = datasette.get_database("test") await db.execute_write_fn( lambda conn: conn.execute( "create table if not exists t2 (id integer primary key)" ) ) assert "result" in captured # Should be a sqlite3 Cursor assert captured["result"] is not None finally: pm.unregister(name="test_result") @pytest.mark.asyncio async def test_write_wrapper_exception_thrown_into_generator(datasette): """Test that exceptions from fn(conn) are thrown into the generator.""" caught = {} class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): def wrapper(conn): try: yield except Exception as e: caught["error"] = e return wrapper pm.register(Plugin(), name="test_exception") try: db = datasette.get_database("test") with pytest.raises(Exception, match="deliberate"): await db.execute_write_fn( lambda conn: (_ for _ in ()).throw(Exception("deliberate")) ) assert "error" in caught assert str(caught["error"]) == "deliberate" finally: pm.unregister(name="test_exception") @pytest.mark.asyncio async def test_write_wrapper_conn_is_usable(datasette): """Test that the conn passed to the wrapper can execute SQL.""" class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): def wrapper(conn): conn.execute("create table if not exists hook_log (msg text)") conn.execute("insert into hook_log values ('before')") yield conn.execute("insert into hook_log values ('after')") return wrapper pm.register(Plugin(), name="test_conn") try: db = datasette.get_database("test") await db.execute_write_fn( lambda conn: conn.execute( "create table if not exists t3 (id integer primary key)" ) ) result = await db.execute("select msg from hook_log order by rowid") messages = [row[0] for row in result.rows] assert messages == ["before", "after"] finally: pm.unregister(name="test_conn") @pytest.mark.asyncio async def test_write_wrapper_multiple_plugins_nest(datasette): """Test that multiple write_wrapper plugins nest correctly.""" log = [] class PluginA: __name__ = "PluginA" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): def wrapper(conn): log.append("A-before") yield log.append("A-after") return wrapper class PluginB: __name__ = "PluginB" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): def wrapper(conn): log.append("B-before") yield log.append("B-after") return wrapper pm.register(PluginA(), name="PluginA") pm.register(PluginB(), name="PluginB") try: db = datasette.get_database("test") await db.execute_write_fn( lambda conn: conn.execute( "create table if not exists t4 (id integer primary key)" ) ) assert set(log) == {"A-before", "A-after", "B-before", "B-after"} # Verify proper nesting: each plugin's before/after should be # symmetric around the write a_before = log.index("A-before") a_after = log.index("A-after") b_before = log.index("B-before") b_after = log.index("B-after") if a_before < b_before: assert a_after > b_after, "A is outer so A-after should come after B-after" else: assert b_after > a_after, "B is outer so B-after should come after A-after" finally: pm.unregister(name="PluginA") pm.unregister(name="PluginB") @pytest.mark.asyncio async def test_write_wrapper_return_none_skips(datasette): """Test that returning None from write_wrapper means no wrapping.""" log = [] class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): log.append("hook-called") return None pm.register(Plugin(), name="test_skip") try: db = datasette.get_database("test") await db.execute_write_fn( lambda conn: conn.execute( "create table if not exists t5 (id integer primary key)" ) ) assert log == ["hook-called"] finally: pm.unregister(name="test_skip") @pytest.mark.asyncio @pytest.mark.parametrize( "request_value,transaction_value,expected_request,expected_transaction", ( ("fake-request", True, "fake-request", True), (None, True, None, True), (None, False, None, False), ), ids=["with-request", "request-none-by-default", "transaction-false"], ) async def test_write_wrapper_hook_parameters( datasette, request_value, transaction_value, expected_request, expected_transaction, ): """Test that request and transaction parameters are passed through.""" captured = {} class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): captured["request"] = request captured["database"] = database captured["transaction"] = transaction pm.register(Plugin(), name="test_params") try: db = datasette.get_database("test") kwargs = {"transaction": transaction_value} if request_value is not None: kwargs["request"] = request_value await db.execute_write_fn( lambda conn: conn.execute( "create table if not exists t6 (id integer primary key)" ), **kwargs, ) assert captured["request"] == expected_request assert captured["database"] == "test" assert captured["transaction"] == expected_transaction finally: pm.unregister(name="test_params") @pytest.mark.asyncio async def test_write_wrapper_via_api(tmp_path): """Test that write_wrapper fires for API write operations.""" log = [] db_path = str(tmp_path / "test.db") ds = Datasette([db_path], pdb=False) ds.root_enabled = True class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): if database != "test": return None def wrapper(conn): log.append("before") yield log.append("after") return wrapper pm.register(Plugin(), name="test_api") try: db = ds.get_database("test") await db.execute_write( "create table if not exists api_test (id integer primary key, name text)" ) log.clear() token = "dstok_{}".format( ds.sign( {"a": "root", "token": "dstok", "t": int(time.time())}, namespace="token", ) ) response = await ds.client.post( "/test/api_test/-/insert", json={"row": {"name": "test"}, "return": True}, headers={ "Authorization": "Bearer {}".format(token), "Content-Type": "application/json", }, ) assert response.status_code == 201, response.json() assert log == ["before", "after"] finally: pm.unregister(name="test_api") @pytest.mark.asyncio async def test_write_wrapper_change_group_pattern(datasette): """Test the motivating use case: activating a change group around a write.""" db = datasette.get_database("test") await db.execute_write( "create table if not exists groups (id integer primary key, current integer)" ) await db.execute_write( "create table if not exists data (id integer primary key, value text)" ) await db.execute_write("insert into groups (id, current) values (1, null)") class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): if request and getattr(request, "group_id", None): group_id = request.group_id def wrapper(conn): conn.execute( "update groups set current = 1 where id = ?", [group_id] ) yield conn.execute("update groups set current = null where current = 1") return wrapper pm.register(Plugin(), name="test_change_group") try: class FakeRequest: group_id = 1 await db.execute_write_fn( lambda conn: conn.execute("insert into data (value) values ('test')"), request=FakeRequest(), ) result = await db.execute("select current from groups where id = 1") assert result.rows[0][0] is None finally: pm.unregister(name="test_change_group") WRITE_ACTIONS = ( sqlite3.SQLITE_INSERT, sqlite3.SQLITE_UPDATE, sqlite3.SQLITE_DELETE, ) @pytest.mark.asyncio @pytest.mark.parametrize( "actor,table,should_deny", ( (None, "protected_table", True), ({"id": "regular"}, "protected_table", True), ({"id": "admin"}, "protected_table", False), (None, "other_table", False), ({"id": "regular"}, "other_table", False), ), ids=[ "no-actor-protected", "regular-user-protected", "admin-protected", "no-actor-other", "regular-user-other", ], ) async def test_write_wrapper_set_authorizer(datasette, actor, table, should_deny): """Test the docs example that uses set_authorizer to block writes to a protected table.""" db = datasette.get_database("test") await db.execute_write( "create table if not exists protected_table (id integer primary key, value text)" ) await db.execute_write( "create table if not exists other_table (id integer primary key, value text)" ) class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): actor = None if request: actor = request.actor if actor and actor.get("id") == "admin": return None def wrapper(conn): def authorizer(action, arg1, arg2, db_name, trigger): if action in WRITE_ACTIONS and arg1 == "protected_table": return sqlite3.SQLITE_DENY return sqlite3.SQLITE_OK conn.set_authorizer(authorizer) try: yield finally: conn.set_authorizer(lambda *args: sqlite3.SQLITE_OK) return wrapper class FakeRequest: def __init__(self, actor): self.actor = actor pm.register(Plugin(), name="test_set_authorizer") try: request = FakeRequest(actor) if should_deny: with pytest.raises(Exception): await db.execute_write_fn( lambda conn: conn.execute( f"insert into {table} (value) values ('test')" ), request=request, ) else: await db.execute_write_fn( lambda conn: conn.execute( f"insert into {table} (value) values ('test')" ), request=request, ) result = await db.execute( f"select value from {table} order by rowid desc limit 1" ) assert result.rows[0][0] == "test" finally: pm.unregister(name="test_set_authorizer") # --- Tests for track_event callback --- @pytest.fixture def ds_with_event_tracking(tmp_path): """Datasette instance that records tracked events and registers DummyEvent.""" db_path = str(tmp_path / "test.db") ds = Datasette([db_path]) ds._tracked_events = [] # Set event_classes directly to avoid needing invoke_startup ds.event_classes = (DummyEvent,) async def recording_track_event(event): ds._tracked_events.append(event) ds.track_event = recording_track_event yield ds @pytest.mark.asyncio async def test_track_event_in_write_fn(ds_with_event_tracking): """fn(conn, track_event) can queue events that are dispatched after commit.""" ds = ds_with_event_tracking db = ds.get_database("test") def my_write(conn, track_event): conn.execute("create table if not exists te1 (id integer primary key)") track_event(DummyEvent(actor=None, message="hello")) await db.execute_write_fn(my_write) assert len(ds._tracked_events) == 1 assert ds._tracked_events[0].message == "hello" @pytest.mark.asyncio async def test_track_event_discarded_on_exception(ds_with_event_tracking): """Events are discarded if the write fn raises an exception.""" ds = ds_with_event_tracking db = ds.get_database("test") def my_write(conn, track_event): track_event(DummyEvent(actor=None, message="should not fire")) raise ValueError("deliberate error") with pytest.raises(ValueError, match="deliberate"): await db.execute_write_fn(my_write) assert len(ds._tracked_events) == 0 @pytest.mark.asyncio async def test_track_event_existing_fn_signature_still_works(ds_with_event_tracking): """Existing fn(conn) signatures continue to work without track_event.""" ds = ds_with_event_tracking db = ds.get_database("test") await db.execute_write_fn( lambda conn: conn.execute( "create table if not exists te2 (id integer primary key)" ) ) # No events, no errors assert len(ds._tracked_events) == 0 @pytest.mark.asyncio async def test_track_event_in_write_wrapper(ds_with_event_tracking): """write_wrapper generator with (conn, track_event) can queue events.""" ds = ds_with_event_tracking db = ds.get_database("test") class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): def wrapper(conn, track_event): track_event(DummyEvent(actor=None, message="from wrapper before")) yield track_event(DummyEvent(actor=None, message="from wrapper after")) return wrapper pm.register(Plugin(), name="test_track_wrapper") try: await db.execute_write_fn( lambda conn: conn.execute( "create table if not exists te3 (id integer primary key)" ) ) assert len(ds._tracked_events) == 2 assert ds._tracked_events[0].message == "from wrapper before" assert ds._tracked_events[1].message == "from wrapper after" finally: pm.unregister(name="test_track_wrapper") @pytest.mark.asyncio async def test_track_event_shared_between_fn_and_wrapper(ds_with_event_tracking): """Both fn and wrapper can queue events, all dispatched in order.""" ds = ds_with_event_tracking db = ds.get_database("test") class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): def wrapper(conn, track_event): track_event(DummyEvent(actor=None, message="wrapper-before")) yield track_event(DummyEvent(actor=None, message="wrapper-after")) return wrapper pm.register(Plugin(), name="test_track_shared") try: def my_write(conn, track_event): conn.execute("create table if not exists te4 (id integer primary key)") track_event(DummyEvent(actor=None, message="from-fn")) await db.execute_write_fn(my_write) messages = [e.message for e in ds._tracked_events] assert messages == ["wrapper-before", "from-fn", "wrapper-after"] finally: pm.unregister(name="test_track_shared") @pytest.mark.asyncio async def test_track_event_with_block_false(ds_with_event_tracking): """Events are dispatched even when block=False (non-blocking writes).""" ds = ds_with_event_tracking db = ds.get_database("test") def my_write(conn, track_event): conn.execute("create table if not exists te5 (id integer primary key)") track_event(DummyEvent(actor=None, message="non-blocking")) task_id = await db.execute_write_fn(my_write, block=False) assert task_id is not None # Give the background task time to complete import asyncio for _ in range(50): if ds._tracked_events: break await asyncio.sleep(0.01) assert len(ds._tracked_events) == 1 assert ds._tracked_events[0].message == "non-blocking" # --- Tests for RenameTableEvent detection --- @pytest.fixture def ds_for_rename(tmp_path): """Datasette instance that records tracked events for rename detection tests.""" from datasette.events import RenameTableEvent db_path = str(tmp_path / "test.db") ds = Datasette([db_path]) ds._tracked_events = [] ds.event_classes = (RenameTableEvent,) async def recording_track_event(event): ds._tracked_events.append(event) ds.track_event = recording_track_event return ds @pytest.mark.asyncio async def test_rename_table_fires_event(ds_for_rename): """Renaming a table via ALTER TABLE fires a RenameTableEvent.""" from datasette.events import RenameTableEvent ds = ds_for_rename db = ds.get_database("test") await db.execute_write("create table old_name (id integer primary key)") def rename(conn): conn.execute("alter table old_name rename to new_name") await db.execute_write_fn(rename) rename_events = [e for e in ds._tracked_events if isinstance(e, RenameTableEvent)] assert len(rename_events) == 1 assert rename_events[0].old_table == "old_name" assert rename_events[0].new_table == "new_name" assert rename_events[0].database == "test" @pytest.mark.asyncio async def test_no_rename_event_for_regular_writes(ds_for_rename): """Regular writes (CREATE, INSERT) do not fire RenameTableEvent.""" from datasette.events import RenameTableEvent ds = ds_for_rename db = ds.get_database("test") await db.execute_write("create table t (id integer primary key)") await db.execute_write_fn(lambda conn: conn.execute("insert into t values (1)")) rename_events = [e for e in ds._tracked_events if isinstance(e, RenameTableEvent)] assert len(rename_events) == 0 @pytest.mark.asyncio async def test_no_rename_event_on_rollback(ds_for_rename): """RenameTableEvent is not fired if the write raises an exception.""" from datasette.events import RenameTableEvent ds = ds_for_rename db = ds.get_database("test") await db.execute_write("create table rollback_test (id integer primary key)") def rename_then_fail(conn): conn.execute("alter table rollback_test rename to renamed") raise ValueError("deliberate error") with pytest.raises(ValueError, match="deliberate"): await db.execute_write_fn(rename_then_fail) rename_events = [e for e in ds._tracked_events if isinstance(e, RenameTableEvent)] assert len(rename_events) == 0 @pytest.mark.asyncio async def test_multiple_renames_in_one_write(ds_for_rename): """Multiple renames in a single write fire multiple RenameTableEvents.""" from datasette.events import RenameTableEvent ds = ds_for_rename db = ds.get_database("test") await db.execute_write("create table alpha (id integer primary key)") await db.execute_write("create table beta (id integer primary key)") def rename_both(conn): conn.execute("alter table alpha rename to alpha2") conn.execute("alter table beta rename to beta2") await db.execute_write_fn(rename_both) rename_events = [e for e in ds._tracked_events if isinstance(e, RenameTableEvent)] assert len(rename_events) == 2 names = {(e.old_table, e.new_table) for e in rename_events} assert names == {("alpha", "alpha2"), ("beta", "beta2")}