diff --git a/datasette/database.py b/datasette/database.py index 8e4ee2b6..1e6f9032 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -130,25 +130,25 @@ class Database: for connection in self._all_file_connections: connection.close() - async def execute_write(self, sql, params=None, block=True): + async def execute_write(self, sql, params=None, block=True, request=None): def _inner(conn): return conn.execute(sql, params or []) with trace("sql", database=self.name, sql=sql.strip(), params=params): - results = await self.execute_write_fn(_inner, block=block) + results = await self.execute_write_fn(_inner, block=block, request=request) return results - async def execute_write_script(self, sql, block=True): + async def execute_write_script(self, sql, block=True, request=None): def _inner(conn): return conn.executescript(sql) with trace("sql", database=self.name, sql=sql.strip(), executescript=True): results = await self.execute_write_fn( - _inner, block=block, transaction=False + _inner, block=block, transaction=False, request=request ) return results - async def execute_write_many(self, sql, params_seq, block=True): + async def execute_write_many(self, sql, params_seq, block=True, request=None): def _inner(conn): count = 0 @@ -163,7 +163,9 @@ class Database: with trace( "sql", database=self.name, sql=sql.strip(), executemany=True ) as kwargs: - results, count = await self.execute_write_fn(_inner, block=block) + results, count = await self.execute_write_fn( + _inner, block=block, request=request + ) kwargs["count"] = count return results @@ -187,7 +189,8 @@ class Database: # Threaded mode - send to write thread return await self._send_to_write_thread(fn, isolated_connection=True) - async def execute_write_fn(self, fn, block=True, transaction=True): + async def execute_write_fn(self, fn, block=True, transaction=True, request=None): + fn = self._wrap_fn_with_hooks(fn, request, transaction) if self.ds.executor is None: # non-threaded mode if self._write_connection is None: @@ -203,6 +206,25 @@ class Database: fn, block=block, transaction=transaction ) + def _wrap_fn_with_hooks(self, fn, request, transaction): + from .plugins import pm + + wrappers = pm.hook.write_wrapper( + datasette=self.ds, + database=self.name, + request=request, + transaction=transaction, + ) + wrappers = [w for w in wrappers if w is not None] + if not wrappers: + return fn + # Build the wrapped fn by nesting context manager generators. + # The first wrapper returned by pluggy is outermost. + original_fn = fn + for wrapper_factory in reversed(wrappers): + original_fn = _apply_write_wrapper(original_fn, wrapper_factory) + return original_fn + async def _send_to_write_thread( self, fn, block=True, isolated_connection=False, transaction=True ): @@ -680,6 +702,47 @@ class Database: return f"" +def _apply_write_wrapper(fn, wrapper_factory): + """Apply a single write_wrapper context manager around fn. + + ``wrapper_factory`` is a callable that takes ``(conn)`` and returns a + generator that yields exactly once. Code before the yield runs before + ``fn(conn)``, code after the yield runs after. The result of + ``fn(conn)`` is sent into the generator via ``.send()``, and any + exception raised by ``fn(conn)`` is thrown via ``.throw()``. + """ + + def wrapped(conn): + gen = wrapper_factory(conn) + # Advance to the yield point (run "before" code) + try: + next(gen) + except StopIteration: + # Generator didn't yield — just run fn unchanged + return fn(conn) + + # Execute the actual write + try: + result = fn(conn) + except Exception: + # Throw exception into generator so it can handle it + try: + gen.throw(*sys.exc_info()) + except StopIteration: + pass + # Re-raise the original exception + raise + else: + # Send the result back through the yield + try: + gen.send(result) + except StopIteration: + pass + return result + + return wrapped + + class WriteTask: __slots__ = ("fn", "task_id", "reply_queue", "isolated_connection", "transaction") diff --git a/datasette/hookspecs.py b/datasette/hookspecs.py index 3f6a1425..b993fb61 100644 --- a/datasette/hookspecs.py +++ b/datasette/hookspecs.py @@ -220,3 +220,25 @@ def top_query(datasette, request, database, sql): @hookspec def top_canned_query(datasette, request, database, query_name): """HTML to include at the top of the canned query page""" + + +@hookspec +def write_wrapper(datasette, database, request, transaction): + """Called when a write function is about to execute. + + Return a generator function that accepts a ``conn`` argument. + The generator should ``yield`` exactly once: code before the + ``yield`` runs before the write, code after the ``yield`` runs + after the write completes. The result of the write is sent + back through the ``yield``, so you can capture it with + ``result = yield``. + + If the write raises an exception, it is thrown into the generator + so you can handle it with a try/except around the ``yield``. + + ``request`` may be ``None`` for writes not originating from an + HTTP request. ``transaction`` is ``True`` if the write will + be wrapped in a transaction. + + Return ``None`` to skip wrapping. + """ diff --git a/datasette/views/database.py b/datasette/views/database.py index 51c752a0..e5f2cf16 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -466,7 +466,9 @@ class QueryView(View): ok = None redirect_url = None try: - cursor = await db.execute_write(canned_query["sql"], params_for_query) + cursor = await db.execute_write( + canned_query["sql"], params_for_query, request=request + ) # success message can come from on_success_message or on_success_message_sql message = None message_type = datasette.INFO @@ -1119,7 +1121,7 @@ class TableCreateView(BaseView): return table.schema try: - schema = await db.execute_write_fn(create_table) + schema = await db.execute_write_fn(create_table, request=request) except Exception as e: return _error([str(e)]) diff --git a/datasette/views/row.py b/datasette/views/row.py index 718ee00c..ff0a3594 100644 --- a/datasette/views/row.py +++ b/datasette/views/row.py @@ -245,7 +245,7 @@ class RowDeleteView(BaseView): sqlite_utils.Database(conn)[resolved.table].delete(resolved.pk_values) try: - await resolved.db.execute_write_fn(delete_row) + await resolved.db.execute_write_fn(delete_row, request=request) except Exception as e: return _error([str(e)], 500) @@ -305,7 +305,7 @@ class RowUpdateView(BaseView): ) try: - await resolved.db.execute_write_fn(update_row) + await resolved.db.execute_write_fn(update_row, request=request) except Exception as e: return _error([str(e)], 400) diff --git a/datasette/views/table.py b/datasette/views/table.py index b07b62ae..d4dbc194 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -550,7 +550,7 @@ class TableInsertView(BaseView): method_all(rows, **kwargs) try: - rows = await db.execute_write_fn(insert_or_upsert_rows) + rows = await db.execute_write_fn(insert_or_upsert_rows, request=request) except Exception as e: return _error([str(e)]) result = {"ok": True} @@ -670,7 +670,7 @@ class TableDropView(BaseView): def drop_table(conn): sqlite_utils.Database(conn)[table_name].drop() - await db.execute_write_fn(drop_table) + await db.execute_write_fn(drop_table, request=request) await self.ds.track_event( DropTableEvent( actor=request.actor, database=database_name, table=table_name diff --git a/docs/plugin_hooks.rst b/docs/plugin_hooks.rst index ad4a70f8..468b0ade 100644 --- a/docs/plugin_hooks.rst +++ b/docs/plugin_hooks.rst @@ -61,6 +61,92 @@ arguments and can be called like this:: Examples: `datasette-jellyfish `__, `datasette-jq `__, `datasette-haversine `__, `datasette-rure `__ +.. _plugin_hook_write_wrapper: + +write_wrapper(datasette, database, request, transaction) +-------------------------------------------------------- + +``datasette`` - :ref:`internals_datasette` + You can use this to access plugin configuration options via ``datasette.plugin_config(your_plugin_name)``. + +``database`` - string + The name of the database being written to. + +``request`` - :ref:`internals_request` or ``None`` + The HTTP request that triggered this write, if available. This will be ``None`` for writes that do not originate from an HTTP request (e.g. writes triggered by plugins during startup). + +``transaction`` - bool + ``True`` if the write will be wrapped in a database transaction. + +Return a generator function that accepts a ``conn`` argument (a SQLite connection object). The generator should ``yield`` exactly once. Code before the ``yield`` runs before the write function executes; code after the ``yield`` runs after it completes. + +The result of the write function is sent back through the ``yield``, so you can capture it with ``result = yield``. + +If the write function raises an exception, it is thrown into the generator so you can handle it with a ``try`` / ``except`` around the ``yield``. + +Return ``None`` to skip wrapping for this particular write. + +This example logs every write operation: + +.. code-block:: python + + from datasette import hookimpl + + + @hookimpl + def write_wrapper(datasette, database, request): + def wrapper(conn): + print(f"Before write to {database}") + result = yield + print(f"After write to {database}") + + return wrapper + +This more advanced example uses the SQLite authorizer callback to block writes to a specific table for non-admin users: + +.. code-block:: python + + import sqlite3 + from datasette import hookimpl + + WRITE_ACTIONS = ( + sqlite3.SQLITE_INSERT, + sqlite3.SQLITE_UPDATE, + sqlite3.SQLITE_DELETE, + ) + + + @hookimpl + def write_wrapper(datasette, database, request): + actor = None + if request: + actor = request.actor + if actor and actor.get("id") == "admin": + return None + + def wrapper(conn): + def authorizer( + action, arg1, arg2, db_name, trigger + ): + if ( + action in WRITE_ACTIONS + and arg1 == "protected_table" + ): + return sqlite3.SQLITE_DENY + return sqlite3.SQLITE_OK + + conn.set_authorizer(authorizer) + try: + yield + finally: + conn.set_authorizer(None) + + return wrapper + +The ``conn`` object passed to the generator is the same connection that the write function will use. Because the generator and the write function execute together in a single call on the write thread, any state you set on the connection (authorizers, pragmas, temporary tables) is visible to the write and can be cleaned up afterwards. + +When multiple plugins implement ``write_wrapper``, they are nested following pluggy's default calling convention. + .. _plugin_hook_prepare_jinja2_environment: prepare_jinja2_environment(env, datasette) @@ -2249,3 +2335,4 @@ The plugin can then call ``datasette.track_event(...)`` to send a ``ban-user`` e await datasette.track_event( BanUserEvent(user={"id": 1, "username": "cleverbot"}) ) + diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 6c23b3ef..7c2180e8 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -1524,6 +1524,36 @@ async def test_hook_register_events(): assert any(k.__name__ == "OneEvent" for k in datasette.event_classes) +@pytest.mark.asyncio +async def test_hook_write_wrapper(): + datasette = Datasette(memory=True) + log = [] + + class WrapWritePlugin: + __name__ = "WrapWritePlugin" + + @staticmethod + @hookimpl + def write_wrapper(datasette, database, request, transaction): + if database != "_memory": + return None + + def wrapper(conn): + log.append("before") + yield + log.append("after") + + return wrapper + + pm.register(WrapWritePlugin(), name="WrapWritePluginTest") + try: + db = datasette.get_database("_memory") + await db.execute_write("create table t (id integer primary key)") + assert log == ["before", "after"] + finally: + pm.unregister(name="WrapWritePluginTest") + + @pytest.mark.asyncio async def test_hook_register_actions_view_collection(): datasette = Datasette(memory=True, plugins_dir=PLUGINS_DIR) diff --git a/tests/test_write_wrapper.py b/tests/test_write_wrapper.py new file mode 100644 index 00000000..e05a2a9f --- /dev/null +++ b/tests/test_write_wrapper.py @@ -0,0 +1,387 @@ +""" +Tests for the write_wrapper plugin hook. +""" + +from datasette.app import Datasette +from datasette.hookspecs import hookimpl +from datasette.plugins import pm +import pytest +import time + + +@pytest.fixture +def datasette(tmp_path): + db_path = str(tmp_path / "test.db") + ds = Datasette([db_path]) + return ds + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "use_execute_write", + (False, True), + ids=["execute_write_fn", "execute_write"], +) +async def test_write_wrapper_before_and_after(datasette, use_execute_write): + """Test that code before and after yield both execute.""" + log = [] + + class Plugin: + __name__ = "Plugin" + + @staticmethod + @hookimpl + def write_wrapper(datasette, database, request, transaction): + def wrapper(conn): + log.append("before") + yield + log.append("after") + + return wrapper + + pm.register(Plugin(), name="test_before_after") + try: + db = datasette.get_database("test") + if use_execute_write: + await db.execute_write( + "create table if not exists t (id integer primary key)" + ) + else: + await db.execute_write_fn( + lambda conn: conn.execute( + "create table if not exists t (id integer primary key)" + ) + ) + assert log == ["before", "after"] + finally: + pm.unregister(name="test_before_after") + + +@pytest.mark.asyncio +async def test_write_wrapper_receives_result_via_yield(datasette): + """Test that the result of fn(conn) is sent back through yield.""" + captured = {} + + class Plugin: + __name__ = "Plugin" + + @staticmethod + @hookimpl + def write_wrapper(datasette, database, request, transaction): + def wrapper(conn): + result = yield + captured["result"] = result + + return wrapper + + pm.register(Plugin(), name="test_result") + try: + db = datasette.get_database("test") + await db.execute_write_fn( + lambda conn: conn.execute( + "create table if not exists t2 (id integer primary key)" + ) + ) + assert "result" in captured + # Should be a sqlite3 Cursor + assert captured["result"] is not None + finally: + pm.unregister(name="test_result") + + +@pytest.mark.asyncio +async def test_write_wrapper_exception_thrown_into_generator(datasette): + """Test that exceptions from fn(conn) are thrown into the generator.""" + caught = {} + + class Plugin: + __name__ = "Plugin" + + @staticmethod + @hookimpl + def write_wrapper(datasette, database, request, transaction): + def wrapper(conn): + try: + yield + except Exception as e: + caught["error"] = e + + return wrapper + + pm.register(Plugin(), name="test_exception") + try: + db = datasette.get_database("test") + with pytest.raises(Exception, match="deliberate"): + await db.execute_write_fn( + lambda conn: (_ for _ in ()).throw(Exception("deliberate")) + ) + assert "error" in caught + assert str(caught["error"]) == "deliberate" + finally: + pm.unregister(name="test_exception") + + +@pytest.mark.asyncio +async def test_write_wrapper_conn_is_usable(datasette): + """Test that the conn passed to the wrapper can execute SQL.""" + + class Plugin: + __name__ = "Plugin" + + @staticmethod + @hookimpl + def write_wrapper(datasette, database, request, transaction): + def wrapper(conn): + conn.execute("create table if not exists hook_log (msg text)") + conn.execute("insert into hook_log values ('before')") + yield + conn.execute("insert into hook_log values ('after')") + + return wrapper + + pm.register(Plugin(), name="test_conn") + try: + db = datasette.get_database("test") + await db.execute_write_fn( + lambda conn: conn.execute( + "create table if not exists t3 (id integer primary key)" + ) + ) + result = await db.execute("select msg from hook_log order by rowid") + messages = [row[0] for row in result.rows] + assert messages == ["before", "after"] + finally: + pm.unregister(name="test_conn") + + +@pytest.mark.asyncio +async def test_write_wrapper_multiple_plugins_nest(datasette): + """Test that multiple write_wrapper plugins nest correctly.""" + log = [] + + class PluginA: + __name__ = "PluginA" + + @staticmethod + @hookimpl + def write_wrapper(datasette, database, request, transaction): + def wrapper(conn): + log.append("A-before") + yield + log.append("A-after") + + return wrapper + + class PluginB: + __name__ = "PluginB" + + @staticmethod + @hookimpl + def write_wrapper(datasette, database, request, transaction): + def wrapper(conn): + log.append("B-before") + yield + log.append("B-after") + + return wrapper + + pm.register(PluginA(), name="PluginA") + pm.register(PluginB(), name="PluginB") + try: + db = datasette.get_database("test") + await db.execute_write_fn( + lambda conn: conn.execute( + "create table if not exists t4 (id integer primary key)" + ) + ) + assert set(log) == {"A-before", "A-after", "B-before", "B-after"} + # Verify proper nesting: each plugin's before/after should be + # symmetric around the write + a_before = log.index("A-before") + a_after = log.index("A-after") + b_before = log.index("B-before") + b_after = log.index("B-after") + if a_before < b_before: + assert a_after > b_after, "A is outer so A-after should come after B-after" + else: + assert b_after > a_after, "B is outer so B-after should come after A-after" + finally: + pm.unregister(name="PluginA") + pm.unregister(name="PluginB") + + +@pytest.mark.asyncio +async def test_write_wrapper_return_none_skips(datasette): + """Test that returning None from write_wrapper means no wrapping.""" + log = [] + + class Plugin: + __name__ = "Plugin" + + @staticmethod + @hookimpl + def write_wrapper(datasette, database, request, transaction): + log.append("hook-called") + return None + + pm.register(Plugin(), name="test_skip") + try: + db = datasette.get_database("test") + await db.execute_write_fn( + lambda conn: conn.execute( + "create table if not exists t5 (id integer primary key)" + ) + ) + assert log == ["hook-called"] + finally: + pm.unregister(name="test_skip") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "request_value,transaction_value,expected_request,expected_transaction", + ( + ("fake-request", True, "fake-request", True), + (None, True, None, True), + (None, False, None, False), + ), + ids=["with-request", "request-none-by-default", "transaction-false"], +) +async def test_write_wrapper_hook_parameters( + datasette, + request_value, + transaction_value, + expected_request, + expected_transaction, +): + """Test that request and transaction parameters are passed through.""" + captured = {} + + class Plugin: + __name__ = "Plugin" + + @staticmethod + @hookimpl + def write_wrapper(datasette, database, request, transaction): + captured["request"] = request + captured["database"] = database + captured["transaction"] = transaction + + pm.register(Plugin(), name="test_params") + try: + db = datasette.get_database("test") + kwargs = {"transaction": transaction_value} + if request_value is not None: + kwargs["request"] = request_value + await db.execute_write_fn( + lambda conn: conn.execute( + "create table if not exists t6 (id integer primary key)" + ), + **kwargs, + ) + assert captured["request"] == expected_request + assert captured["database"] == "test" + assert captured["transaction"] == expected_transaction + finally: + pm.unregister(name="test_params") + + +@pytest.mark.asyncio +async def test_write_wrapper_via_api(tmp_path): + """Test that write_wrapper fires for API write operations.""" + log = [] + + db_path = str(tmp_path / "test.db") + ds = Datasette([db_path], pdb=False) + ds.root_enabled = True + + class Plugin: + __name__ = "Plugin" + + @staticmethod + @hookimpl + def write_wrapper(datasette, database, request, transaction): + if database != "test": + return None + + def wrapper(conn): + log.append("before") + yield + log.append("after") + + return wrapper + + pm.register(Plugin(), name="test_api") + try: + db = ds.get_database("test") + await db.execute_write( + "create table if not exists api_test (id integer primary key, name text)" + ) + log.clear() + + token = "dstok_{}".format( + ds.sign( + {"a": "root", "token": "dstok", "t": int(time.time())}, + namespace="token", + ) + ) + response = await ds.client.post( + "/test/api_test/-/insert", + json={"row": {"name": "test"}, "return": True}, + headers={ + "Authorization": "Bearer {}".format(token), + "Content-Type": "application/json", + }, + ) + assert response.status_code == 201, response.json() + assert log == ["before", "after"] + finally: + pm.unregister(name="test_api") + + +@pytest.mark.asyncio +async def test_write_wrapper_change_group_pattern(datasette): + """Test the motivating use case: activating a change group around a write.""" + db = datasette.get_database("test") + + await db.execute_write( + "create table if not exists groups (id integer primary key, current integer)" + ) + await db.execute_write( + "create table if not exists data (id integer primary key, value text)" + ) + await db.execute_write("insert into groups (id, current) values (1, null)") + + class Plugin: + __name__ = "Plugin" + + @staticmethod + @hookimpl + def write_wrapper(datasette, database, request, transaction): + if request and getattr(request, "group_id", None): + group_id = request.group_id + + def wrapper(conn): + conn.execute( + "update groups set current = 1 where id = ?", [group_id] + ) + yield + conn.execute("update groups set current = null where current = 1") + + return wrapper + + pm.register(Plugin(), name="test_change_group") + try: + + class FakeRequest: + group_id = 1 + + await db.execute_write_fn( + lambda conn: conn.execute("insert into data (value) values ('test')"), + request=FakeRequest(), + ) + + result = await db.execute("select current from groups where id = 1") + assert result.rows[0][0] is None + finally: + pm.unregister(name="test_change_group")