""" Tests for the write_wrapper plugin hook. """ import asyncio from dataclasses import dataclass from datasette.app import Datasette from datasette.events import Event from datasette.hookspecs import hookimpl from datasette.plugins import pm import pytest import sqlite3 import time @dataclass class DummyEvent(Event): name = "dummy" message: str @pytest.fixture def datasette(tmp_path): db_path = str(tmp_path / "test.db") ds = Datasette([db_path]) return ds @pytest.mark.asyncio @pytest.mark.parametrize( "use_execute_write", (False, True), ids=["execute_write_fn", "execute_write"], ) async def test_write_wrapper_before_and_after(datasette, use_execute_write): """Test that code before and after yield both execute.""" log = [] class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): def wrapper(conn): log.append("before") yield log.append("after") return wrapper pm.register(Plugin(), name="test_before_after") try: db = datasette.get_database("test") if use_execute_write: await db.execute_write( "create table if not exists t (id integer primary key)" ) else: await db.execute_write_fn( lambda conn: conn.execute( "create table if not exists t (id integer primary key)" ) ) assert log == ["before", "after"] finally: pm.unregister(name="test_before_after") @pytest.mark.asyncio async def test_write_wrapper_receives_result_via_yield(datasette): """Test that the result of fn(conn) is sent back through yield.""" captured = {} class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): def wrapper(conn): result = yield captured["result"] = result return wrapper pm.register(Plugin(), name="test_result") try: db = datasette.get_database("test") await db.execute_write_fn( lambda conn: conn.execute( "create table if not exists t2 (id integer primary key)" ) ) assert "result" in captured # Should be a sqlite3 Cursor assert captured["result"] is not None finally: 
pm.unregister(name="test_result") @pytest.mark.asyncio async def test_write_wrapper_exception_thrown_into_generator(datasette): """Test that exceptions from fn(conn) are thrown into the generator.""" caught = {} class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): def wrapper(conn): try: yield except Exception as e: caught["error"] = e return wrapper pm.register(Plugin(), name="test_exception") try: db = datasette.get_database("test") with pytest.raises(Exception, match="deliberate"): await db.execute_write_fn( lambda conn: (_ for _ in ()).throw(Exception("deliberate")) ) assert "error" in caught assert str(caught["error"]) == "deliberate" finally: pm.unregister(name="test_exception") @pytest.mark.asyncio async def test_write_wrapper_conn_is_usable(datasette): """Test that the conn passed to the wrapper can execute SQL.""" class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): def wrapper(conn): conn.execute("create table if not exists hook_log (msg text)") conn.execute("insert into hook_log values ('before')") yield conn.execute("insert into hook_log values ('after')") return wrapper pm.register(Plugin(), name="test_conn") try: db = datasette.get_database("test") await db.execute_write_fn( lambda conn: conn.execute( "create table if not exists t3 (id integer primary key)" ) ) result = await db.execute("select msg from hook_log order by rowid") messages = [row[0] for row in result.rows] assert messages == ["before", "after"] finally: pm.unregister(name="test_conn") @pytest.mark.asyncio async def test_write_wrapper_multiple_plugins_nest(datasette): """Test that multiple write_wrapper plugins nest correctly.""" log = [] class PluginA: __name__ = "PluginA" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): def wrapper(conn): log.append("A-before") yield log.append("A-after") return wrapper class PluginB: 
__name__ = "PluginB" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): def wrapper(conn): log.append("B-before") yield log.append("B-after") return wrapper pm.register(PluginA(), name="PluginA") pm.register(PluginB(), name="PluginB") try: db = datasette.get_database("test") await db.execute_write_fn( lambda conn: conn.execute( "create table if not exists t4 (id integer primary key)" ) ) assert set(log) == {"A-before", "A-after", "B-before", "B-after"} # Verify proper nesting: each plugin's before/after should be # symmetric around the write a_before = log.index("A-before") a_after = log.index("A-after") b_before = log.index("B-before") b_after = log.index("B-after") if a_before < b_before: assert a_after > b_after, "A is outer so A-after should come after B-after" else: assert b_after > a_after, "B is outer so B-after should come after A-after" finally: pm.unregister(name="PluginA") pm.unregister(name="PluginB") @pytest.mark.asyncio async def test_write_wrapper_return_none_skips(datasette): """Test that returning None from write_wrapper means no wrapping.""" log = [] class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): log.append("hook-called") return None pm.register(Plugin(), name="test_skip") try: db = datasette.get_database("test") await db.execute_write_fn( lambda conn: conn.execute( "create table if not exists t5 (id integer primary key)" ) ) assert log == ["hook-called"] finally: pm.unregister(name="test_skip") @pytest.mark.asyncio @pytest.mark.parametrize( "request_value,transaction_value,expected_request,expected_transaction", ( ("fake-request", True, "fake-request", True), (None, True, None, True), (None, False, None, False), ), ids=["with-request", "request-none-by-default", "transaction-false"], ) async def test_write_wrapper_hook_parameters( datasette, request_value, transaction_value, expected_request, expected_transaction, ): """Test that request 
and transaction parameters are passed through.""" captured = {} class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): captured["request"] = request captured["database"] = database captured["transaction"] = transaction pm.register(Plugin(), name="test_params") try: db = datasette.get_database("test") kwargs = {"transaction": transaction_value} if request_value is not None: kwargs["request"] = request_value await db.execute_write_fn( lambda conn: conn.execute( "create table if not exists t6 (id integer primary key)" ), **kwargs, ) assert captured["request"] == expected_request assert captured["database"] == "test" assert captured["transaction"] == expected_transaction finally: pm.unregister(name="test_params") @pytest.mark.asyncio async def test_write_wrapper_via_api(tmp_path): """Test that write_wrapper fires for API write operations.""" log = [] db_path = str(tmp_path / "test.db") ds = Datasette([db_path], pdb=False) ds.root_enabled = True class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): if database != "test": return None def wrapper(conn): log.append("before") yield log.append("after") return wrapper pm.register(Plugin(), name="test_api") try: db = ds.get_database("test") await db.execute_write( "create table if not exists api_test (id integer primary key, name text)" ) log.clear() token = "dstok_{}".format( ds.sign( {"a": "root", "token": "dstok", "t": int(time.time())}, namespace="token", ) ) response = await ds.client.post( "/test/api_test/-/insert", json={"row": {"name": "test"}, "return": True}, headers={ "Authorization": "Bearer {}".format(token), "Content-Type": "application/json", }, ) assert response.status_code == 201, response.json() assert log == ["before", "after"] finally: pm.unregister(name="test_api") @pytest.mark.asyncio async def test_write_wrapper_change_group_pattern(datasette): """Test the motivating use 
case: activating a change group around a write.""" db = datasette.get_database("test") await db.execute_write( "create table if not exists groups (id integer primary key, current integer)" ) await db.execute_write( "create table if not exists data (id integer primary key, value text)" ) await db.execute_write("insert into groups (id, current) values (1, null)") class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): if request and getattr(request, "group_id", None): group_id = request.group_id def wrapper(conn): conn.execute( "update groups set current = 1 where id = ?", [group_id] ) yield conn.execute("update groups set current = null where current = 1") return wrapper pm.register(Plugin(), name="test_change_group") try: class FakeRequest: group_id = 1 await db.execute_write_fn( lambda conn: conn.execute("insert into data (value) values ('test')"), request=FakeRequest(), ) result = await db.execute("select current from groups where id = 1") assert result.rows[0][0] is None finally: pm.unregister(name="test_change_group") WRITE_ACTIONS = ( sqlite3.SQLITE_INSERT, sqlite3.SQLITE_UPDATE, sqlite3.SQLITE_DELETE, ) @pytest.mark.asyncio @pytest.mark.parametrize( "actor,table,should_deny", ( (None, "protected_table", True), ({"id": "regular"}, "protected_table", True), ({"id": "admin"}, "protected_table", False), (None, "other_table", False), ({"id": "regular"}, "other_table", False), ), ids=[ "no-actor-protected", "regular-user-protected", "admin-protected", "no-actor-other", "regular-user-other", ], ) async def test_write_wrapper_set_authorizer(datasette, actor, table, should_deny): """Test the docs example that uses set_authorizer to block writes to a protected table.""" db = datasette.get_database("test") await db.execute_write( "create table if not exists protected_table (id integer primary key, value text)" ) await db.execute_write( "create table if not exists other_table (id integer primary key, value 
text)" ) class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): actor = None if request: actor = request.actor if actor and actor.get("id") == "admin": return None def wrapper(conn): def authorizer(action, arg1, arg2, db_name, trigger): if action in WRITE_ACTIONS and arg1 == "protected_table": return sqlite3.SQLITE_DENY return sqlite3.SQLITE_OK conn.set_authorizer(authorizer) try: yield finally: conn.set_authorizer(lambda *args: sqlite3.SQLITE_OK) return wrapper class FakeRequest: def __init__(self, actor): self.actor = actor pm.register(Plugin(), name="test_set_authorizer") try: request = FakeRequest(actor) if should_deny: with pytest.raises(Exception): await db.execute_write_fn( lambda conn: conn.execute( f"insert into {table} (value) values ('test')" ), request=request, ) else: await db.execute_write_fn( lambda conn: conn.execute( f"insert into {table} (value) values ('test')" ), request=request, ) result = await db.execute( f"select value from {table} order by rowid desc limit 1" ) assert result.rows[0][0] == "test" finally: pm.unregister(name="test_set_authorizer") # --- Tests for track_event callback --- @pytest.fixture def ds_with_event_tracking(tmp_path): """Datasette instance that records tracked events and registers DummyEvent.""" db_path = str(tmp_path / "test.db") ds = Datasette([db_path]) ds._tracked_events = [] # Set event_classes directly to avoid needing invoke_startup ds.event_classes = (DummyEvent,) async def recording_track_event(event): ds._tracked_events.append(event) ds.track_event = recording_track_event yield ds ds.close() @pytest.mark.asyncio async def test_track_event_in_write_fn(ds_with_event_tracking): """fn(conn, track_event) can queue events that are dispatched after commit.""" ds = ds_with_event_tracking db = ds.get_database("test") def my_write(conn, track_event): conn.execute("create table if not exists te1 (id integer primary key)") track_event(DummyEvent(actor=None, 
message="hello")) await db.execute_write_fn(my_write) assert len(ds._tracked_events) == 1 assert ds._tracked_events[0].message == "hello" @pytest.mark.asyncio async def test_track_event_discarded_on_exception(ds_with_event_tracking): """Events are discarded if the write fn raises an exception.""" ds = ds_with_event_tracking db = ds.get_database("test") def my_write(conn, track_event): track_event(DummyEvent(actor=None, message="should not fire")) raise ValueError("deliberate error") with pytest.raises(ValueError, match="deliberate"): await db.execute_write_fn(my_write) assert len(ds._tracked_events) == 0 @pytest.mark.asyncio async def test_track_event_existing_fn_signature_still_works(ds_with_event_tracking): """Existing fn(conn) signatures continue to work without track_event.""" ds = ds_with_event_tracking db = ds.get_database("test") await db.execute_write_fn( lambda conn: conn.execute( "create table if not exists te2 (id integer primary key)" ) ) # No events, no errors assert len(ds._tracked_events) == 0 @pytest.mark.asyncio async def test_track_event_in_write_wrapper(ds_with_event_tracking): """write_wrapper generator with (conn, track_event) can queue events.""" ds = ds_with_event_tracking db = ds.get_database("test") class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): def wrapper(conn, track_event): track_event(DummyEvent(actor=None, message="from wrapper before")) yield track_event(DummyEvent(actor=None, message="from wrapper after")) return wrapper pm.register(Plugin(), name="test_track_wrapper") try: await db.execute_write_fn( lambda conn: conn.execute( "create table if not exists te3 (id integer primary key)" ) ) assert len(ds._tracked_events) == 2 assert ds._tracked_events[0].message == "from wrapper before" assert ds._tracked_events[1].message == "from wrapper after" finally: pm.unregister(name="test_track_wrapper") @pytest.mark.asyncio async def 
test_track_event_shared_between_fn_and_wrapper(ds_with_event_tracking): """Both fn and wrapper can queue events, all dispatched in order.""" ds = ds_with_event_tracking db = ds.get_database("test") class Plugin: __name__ = "Plugin" @staticmethod @hookimpl def write_wrapper(datasette, database, request, transaction): def wrapper(conn, track_event): track_event(DummyEvent(actor=None, message="wrapper-before")) yield track_event(DummyEvent(actor=None, message="wrapper-after")) return wrapper pm.register(Plugin(), name="test_track_shared") try: def my_write(conn, track_event): conn.execute("create table if not exists te4 (id integer primary key)") track_event(DummyEvent(actor=None, message="from-fn")) await db.execute_write_fn(my_write) messages = [e.message for e in ds._tracked_events] assert messages == ["wrapper-before", "from-fn", "wrapper-after"] finally: pm.unregister(name="test_track_shared") @pytest.mark.asyncio async def test_track_event_with_block_false(ds_with_event_tracking): """Events are dispatched even when block=False (non-blocking writes).""" ds = ds_with_event_tracking db = ds.get_database("test") def my_write(conn, track_event): conn.execute("create table if not exists te5 (id integer primary key)") track_event(DummyEvent(actor=None, message="non-blocking")) task_id = await db.execute_write_fn(my_write, block=False) assert task_id is not None # Give the background task time to complete for _ in range(50): if ds._tracked_events: break await asyncio.sleep(0.01) assert len(ds._tracked_events) == 1 assert ds._tracked_events[0].message == "non-blocking" @pytest.mark.asyncio async def test_track_event_with_block_false_discarded_on_exception( ds_with_event_tracking, ): """Events queued by a non-blocking write are discarded if the write fails.""" ds = ds_with_event_tracking db = ds.get_database("test") def my_write(conn, track_event): track_event(DummyEvent(actor=None, message="should not fire")) raise ValueError("deliberate error") task_id = await 
db.execute_write_fn(my_write, block=False) assert task_id is not None # A following blocking write proves the failed non-blocking task has # completed; one more loop turn lets its event-dispatch task observe the # exception and exit. await db.execute_write_fn(lambda conn: conn.execute("select 1")) await asyncio.sleep(0) assert ds._tracked_events == [] # --- Tests for RenameTableEvent detection --- @pytest.fixture def ds_for_rename(tmp_path): """Datasette instance that records tracked events for rename detection tests.""" from datasette.events import RenameTableEvent db_path = str(tmp_path / "test.db") ds = Datasette([db_path]) ds._tracked_events = [] ds.event_classes = (RenameTableEvent,) async def recording_track_event(event): ds._tracked_events.append(event) ds.track_event = recording_track_event return ds @pytest.mark.asyncio async def test_rename_table_fires_event(ds_for_rename): """Renaming a table via ALTER TABLE fires a RenameTableEvent.""" from datasette.events import RenameTableEvent ds = ds_for_rename db = ds.get_database("test") await db.execute_write("create table old_name (id integer primary key)") def rename(conn): conn.execute("alter table old_name rename to new_name") await db.execute_write_fn(rename) rename_events = [e for e in ds._tracked_events if isinstance(e, RenameTableEvent)] assert len(rename_events) == 1 assert rename_events[0].old_table == "old_name" assert rename_events[0].new_table == "new_name" assert rename_events[0].database == "test" @pytest.mark.asyncio async def test_no_rename_event_for_regular_writes(ds_for_rename): """Regular writes (CREATE, INSERT) do not fire RenameTableEvent.""" from datasette.events import RenameTableEvent ds = ds_for_rename db = ds.get_database("test") await db.execute_write("create table t (id integer primary key)") await db.execute_write_fn(lambda conn: conn.execute("insert into t values (1)")) rename_events = [e for e in ds._tracked_events if isinstance(e, RenameTableEvent)] assert len(rename_events) == 0 
@pytest.mark.asyncio
async def test_no_rename_event_on_rollback(ds_for_rename):
    """A write that renames a table and then raises emits no RenameTableEvent."""
    from datasette.events import RenameTableEvent

    database = ds_for_rename.get_database("test")
    await database.execute_write("create table rollback_test (id integer primary key)")

    def rename_then_fail(conn):
        conn.execute("alter table rollback_test rename to renamed")
        raise ValueError("deliberate error")

    with pytest.raises(ValueError, match="deliberate"):
        await database.execute_write_fn(rename_then_fail)

    # Nothing recorded may be a rename event once the write has failed.
    assert not any(
        isinstance(event, RenameTableEvent)
        for event in ds_for_rename._tracked_events
    )


@pytest.mark.asyncio
async def test_multiple_renames_in_one_write(ds_for_rename):
    """Each rename inside a single write produces its own RenameTableEvent."""
    from datasette.events import RenameTableEvent

    database = ds_for_rename.get_database("test")
    create_statements = (
        "create table alpha (id integer primary key)",
        "create table beta (id integer primary key)",
    )
    for create_sql in create_statements:
        await database.execute_write(create_sql)

    rename_statements = (
        "alter table alpha rename to alpha2",
        "alter table beta rename to beta2",
    )

    def rename_both(conn):
        for rename_sql in rename_statements:
            conn.execute(rename_sql)

    await database.execute_write_fn(rename_both)

    # Exactly two events, one per rename, in any order.
    observed = sorted(
        (event.old_table, event.new_table)
        for event in ds_for_rename._tracked_events
        if isinstance(event, RenameTableEvent)
    )
    assert observed == [("alpha", "alpha2"), ("beta", "beta2")]