mirror of
https://github.com/simonw/datasette.git
synced 2026-06-06 00:56:57 +02:00
write_wrapper plugin hook for intercepting write operations (#2636)
* Implement write_wrapper plugin hook for intercepting database writes Add a new `write_wrapper` plugin hook that lets plugins wrap write operations with before/after logic using a generator-based context manager pattern. The hook receives (datasette, database, request, transaction) and returns a generator function that takes a conn, yields once to let the write execute, and can run cleanup after. The write result is sent back via `generator.send()` and exceptions are thrown via `generator.throw()`, giving plugins full visibility. Also adds `request=None` parameter to execute_write, execute_write_fn, execute_write_script, and execute_write_many, and threads request through all view-layer call sites (insert, upsert, update, delete, drop, create table, canned queries). * Add documentation for wrap_write hook, fix lint issues Document the wrap_write plugin hook in plugin_hooks.rst with parameter descriptions and two examples: a simple logging wrapper and an advanced SQLite authorizer-based table protection pattern. Also fix black formatting and remove unused variable flagged by ruff. * Rename wrap_write hook to write_wrapper for consistency with asgi_wrapper * Move write_wrapper docs to just below prepare_connection * Refactor write_wrapper tests to use pytest.parametrize Consolidate duplicate test cases: merge before/after tests for execute_write_fn and execute_write into one parametrized test, and merge three parameter-passing tests into one parametrized test. Claude Code transcript: https://gisthost.github.io/?c4c12079434e69677e4aa8ac664b21b8/index.html
This commit is contained in:
parent
5873578d49
commit
80b7f987ca
8 changed files with 604 additions and 13 deletions
|
|
@ -130,25 +130,25 @@ class Database:
|
|||
for connection in self._all_file_connections:
|
||||
connection.close()
|
||||
|
||||
async def execute_write(self, sql, params=None, block=True):
|
||||
async def execute_write(self, sql, params=None, block=True, request=None):
|
||||
def _inner(conn):
|
||||
return conn.execute(sql, params or [])
|
||||
|
||||
with trace("sql", database=self.name, sql=sql.strip(), params=params):
|
||||
results = await self.execute_write_fn(_inner, block=block)
|
||||
results = await self.execute_write_fn(_inner, block=block, request=request)
|
||||
return results
|
||||
|
||||
async def execute_write_script(self, sql, block=True):
|
||||
async def execute_write_script(self, sql, block=True, request=None):
|
||||
def _inner(conn):
|
||||
return conn.executescript(sql)
|
||||
|
||||
with trace("sql", database=self.name, sql=sql.strip(), executescript=True):
|
||||
results = await self.execute_write_fn(
|
||||
_inner, block=block, transaction=False
|
||||
_inner, block=block, transaction=False, request=request
|
||||
)
|
||||
return results
|
||||
|
||||
async def execute_write_many(self, sql, params_seq, block=True):
|
||||
async def execute_write_many(self, sql, params_seq, block=True, request=None):
|
||||
def _inner(conn):
|
||||
count = 0
|
||||
|
||||
|
|
@ -163,7 +163,9 @@ class Database:
|
|||
with trace(
|
||||
"sql", database=self.name, sql=sql.strip(), executemany=True
|
||||
) as kwargs:
|
||||
results, count = await self.execute_write_fn(_inner, block=block)
|
||||
results, count = await self.execute_write_fn(
|
||||
_inner, block=block, request=request
|
||||
)
|
||||
kwargs["count"] = count
|
||||
return results
|
||||
|
||||
|
|
@ -187,7 +189,8 @@ class Database:
|
|||
# Threaded mode - send to write thread
|
||||
return await self._send_to_write_thread(fn, isolated_connection=True)
|
||||
|
||||
async def execute_write_fn(self, fn, block=True, transaction=True):
|
||||
async def execute_write_fn(self, fn, block=True, transaction=True, request=None):
|
||||
fn = self._wrap_fn_with_hooks(fn, request, transaction)
|
||||
if self.ds.executor is None:
|
||||
# non-threaded mode
|
||||
if self._write_connection is None:
|
||||
|
|
@ -203,6 +206,25 @@ class Database:
|
|||
fn, block=block, transaction=transaction
|
||||
)
|
||||
|
||||
def _wrap_fn_with_hooks(self, fn, request, transaction):
|
||||
from .plugins import pm
|
||||
|
||||
wrappers = pm.hook.write_wrapper(
|
||||
datasette=self.ds,
|
||||
database=self.name,
|
||||
request=request,
|
||||
transaction=transaction,
|
||||
)
|
||||
wrappers = [w for w in wrappers if w is not None]
|
||||
if not wrappers:
|
||||
return fn
|
||||
# Build the wrapped fn by nesting context manager generators.
|
||||
# The first wrapper returned by pluggy is outermost.
|
||||
original_fn = fn
|
||||
for wrapper_factory in reversed(wrappers):
|
||||
original_fn = _apply_write_wrapper(original_fn, wrapper_factory)
|
||||
return original_fn
|
||||
|
||||
async def _send_to_write_thread(
|
||||
self, fn, block=True, isolated_connection=False, transaction=True
|
||||
):
|
||||
|
|
@ -680,6 +702,47 @@ class Database:
|
|||
return f"<Database: {self.name}{tags_str}>"
|
||||
|
||||
|
||||
def _apply_write_wrapper(fn, wrapper_factory):
|
||||
"""Apply a single write_wrapper context manager around fn.
|
||||
|
||||
``wrapper_factory`` is a callable that takes ``(conn)`` and returns a
|
||||
generator that yields exactly once. Code before the yield runs before
|
||||
``fn(conn)``, code after the yield runs after. The result of
|
||||
``fn(conn)`` is sent into the generator via ``.send()``, and any
|
||||
exception raised by ``fn(conn)`` is thrown via ``.throw()``.
|
||||
"""
|
||||
|
||||
def wrapped(conn):
|
||||
gen = wrapper_factory(conn)
|
||||
# Advance to the yield point (run "before" code)
|
||||
try:
|
||||
next(gen)
|
||||
except StopIteration:
|
||||
# Generator didn't yield — just run fn unchanged
|
||||
return fn(conn)
|
||||
|
||||
# Execute the actual write
|
||||
try:
|
||||
result = fn(conn)
|
||||
except Exception:
|
||||
# Throw exception into generator so it can handle it
|
||||
try:
|
||||
gen.throw(*sys.exc_info())
|
||||
except StopIteration:
|
||||
pass
|
||||
# Re-raise the original exception
|
||||
raise
|
||||
else:
|
||||
# Send the result back through the yield
|
||||
try:
|
||||
gen.send(result)
|
||||
except StopIteration:
|
||||
pass
|
||||
return result
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
class WriteTask:
|
||||
__slots__ = ("fn", "task_id", "reply_queue", "isolated_connection", "transaction")
|
||||
|
||||
|
|
|
|||
|
|
@ -220,3 +220,25 @@ def top_query(datasette, request, database, sql):
|
|||
@hookspec
|
||||
def top_canned_query(datasette, request, database, query_name):
|
||||
"""HTML to include at the top of the canned query page"""
|
||||
|
||||
|
||||
@hookspec
|
||||
def write_wrapper(datasette, database, request, transaction):
|
||||
"""Called when a write function is about to execute.
|
||||
|
||||
Return a generator function that accepts a ``conn`` argument.
|
||||
The generator should ``yield`` exactly once: code before the
|
||||
``yield`` runs before the write, code after the ``yield`` runs
|
||||
after the write completes. The result of the write is sent
|
||||
back through the ``yield``, so you can capture it with
|
||||
``result = yield``.
|
||||
|
||||
If the write raises an exception, it is thrown into the generator
|
||||
so you can handle it with a try/except around the ``yield``.
|
||||
|
||||
``request`` may be ``None`` for writes not originating from an
|
||||
HTTP request. ``transaction`` is ``True`` if the write will
|
||||
be wrapped in a transaction.
|
||||
|
||||
Return ``None`` to skip wrapping.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -466,7 +466,9 @@ class QueryView(View):
|
|||
ok = None
|
||||
redirect_url = None
|
||||
try:
|
||||
cursor = await db.execute_write(canned_query["sql"], params_for_query)
|
||||
cursor = await db.execute_write(
|
||||
canned_query["sql"], params_for_query, request=request
|
||||
)
|
||||
# success message can come from on_success_message or on_success_message_sql
|
||||
message = None
|
||||
message_type = datasette.INFO
|
||||
|
|
@ -1119,7 +1121,7 @@ class TableCreateView(BaseView):
|
|||
return table.schema
|
||||
|
||||
try:
|
||||
schema = await db.execute_write_fn(create_table)
|
||||
schema = await db.execute_write_fn(create_table, request=request)
|
||||
except Exception as e:
|
||||
return _error([str(e)])
|
||||
|
||||
|
|
|
|||
|
|
@ -245,7 +245,7 @@ class RowDeleteView(BaseView):
|
|||
sqlite_utils.Database(conn)[resolved.table].delete(resolved.pk_values)
|
||||
|
||||
try:
|
||||
await resolved.db.execute_write_fn(delete_row)
|
||||
await resolved.db.execute_write_fn(delete_row, request=request)
|
||||
except Exception as e:
|
||||
return _error([str(e)], 500)
|
||||
|
||||
|
|
@ -305,7 +305,7 @@ class RowUpdateView(BaseView):
|
|||
)
|
||||
|
||||
try:
|
||||
await resolved.db.execute_write_fn(update_row)
|
||||
await resolved.db.execute_write_fn(update_row, request=request)
|
||||
except Exception as e:
|
||||
return _error([str(e)], 400)
|
||||
|
||||
|
|
|
|||
|
|
@ -550,7 +550,7 @@ class TableInsertView(BaseView):
|
|||
method_all(rows, **kwargs)
|
||||
|
||||
try:
|
||||
rows = await db.execute_write_fn(insert_or_upsert_rows)
|
||||
rows = await db.execute_write_fn(insert_or_upsert_rows, request=request)
|
||||
except Exception as e:
|
||||
return _error([str(e)])
|
||||
result = {"ok": True}
|
||||
|
|
@ -670,7 +670,7 @@ class TableDropView(BaseView):
|
|||
def drop_table(conn):
|
||||
sqlite_utils.Database(conn)[table_name].drop()
|
||||
|
||||
await db.execute_write_fn(drop_table)
|
||||
await db.execute_write_fn(drop_table, request=request)
|
||||
await self.ds.track_event(
|
||||
DropTableEvent(
|
||||
actor=request.actor, database=database_name, table=table_name
|
||||
|
|
|
|||
|
|
@ -61,6 +61,92 @@ arguments and can be called like this::
|
|||
|
||||
Examples: `datasette-jellyfish <https://datasette.io/plugins/datasette-jellyfish>`__, `datasette-jq <https://datasette.io/plugins/datasette-jq>`__, `datasette-haversine <https://datasette.io/plugins/datasette-haversine>`__, `datasette-rure <https://datasette.io/plugins/datasette-rure>`__
|
||||
|
||||
.. _plugin_hook_write_wrapper:
|
||||
|
||||
write_wrapper(datasette, database, request, transaction)
|
||||
--------------------------------------------------------
|
||||
|
||||
``datasette`` - :ref:`internals_datasette`
|
||||
You can use this to access plugin configuration options via ``datasette.plugin_config(your_plugin_name)``.
|
||||
|
||||
``database`` - string
|
||||
The name of the database being written to.
|
||||
|
||||
``request`` - :ref:`internals_request` or ``None``
|
||||
The HTTP request that triggered this write, if available. This will be ``None`` for writes that do not originate from an HTTP request (e.g. writes triggered by plugins during startup).
|
||||
|
||||
``transaction`` - bool
|
||||
``True`` if the write will be wrapped in a database transaction.
|
||||
|
||||
Return a generator function that accepts a ``conn`` argument (a SQLite connection object). The generator should ``yield`` exactly once. Code before the ``yield`` runs before the write function executes; code after the ``yield`` runs after it completes.
|
||||
|
||||
The result of the write function is sent back through the ``yield``, so you can capture it with ``result = yield``.
|
||||
|
||||
If the write function raises an exception, it is thrown into the generator so you can handle it with a ``try`` / ``except`` around the ``yield``.
|
||||
|
||||
Return ``None`` to skip wrapping for this particular write.
|
||||
|
||||
This example logs every write operation:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from datasette import hookimpl
|
||||
|
||||
|
||||
@hookimpl
|
||||
def write_wrapper(datasette, database, request):
|
||||
def wrapper(conn):
|
||||
print(f"Before write to {database}")
|
||||
result = yield
|
||||
print(f"After write to {database}")
|
||||
|
||||
return wrapper
|
||||
|
||||
This more advanced example uses the SQLite authorizer callback to block writes to a specific table for non-admin users:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import sqlite3
|
||||
from datasette import hookimpl
|
||||
|
||||
WRITE_ACTIONS = (
|
||||
sqlite3.SQLITE_INSERT,
|
||||
sqlite3.SQLITE_UPDATE,
|
||||
sqlite3.SQLITE_DELETE,
|
||||
)
|
||||
|
||||
|
||||
@hookimpl
|
||||
def write_wrapper(datasette, database, request):
|
||||
actor = None
|
||||
if request:
|
||||
actor = request.actor
|
||||
if actor and actor.get("id") == "admin":
|
||||
return None
|
||||
|
||||
def wrapper(conn):
|
||||
def authorizer(
|
||||
action, arg1, arg2, db_name, trigger
|
||||
):
|
||||
if (
|
||||
action in WRITE_ACTIONS
|
||||
and arg1 == "protected_table"
|
||||
):
|
||||
return sqlite3.SQLITE_DENY
|
||||
return sqlite3.SQLITE_OK
|
||||
|
||||
conn.set_authorizer(authorizer)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
conn.set_authorizer(None)
|
||||
|
||||
return wrapper
|
||||
|
||||
The ``conn`` object passed to the generator is the same connection that the write function will use. Because the generator and the write function execute together in a single call on the write thread, any state you set on the connection (authorizers, pragmas, temporary tables) is visible to the write and can be cleaned up afterwards.
|
||||
|
||||
When multiple plugins implement ``write_wrapper``, they are nested following pluggy's default calling convention.
|
||||
|
||||
.. _plugin_hook_prepare_jinja2_environment:
|
||||
|
||||
prepare_jinja2_environment(env, datasette)
|
||||
|
|
@ -2249,3 +2335,4 @@ The plugin can then call ``datasette.track_event(...)`` to send a ``ban-user`` e
|
|||
await datasette.track_event(
|
||||
BanUserEvent(user={"id": 1, "username": "cleverbot"})
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1524,6 +1524,36 @@ async def test_hook_register_events():
|
|||
assert any(k.__name__ == "OneEvent" for k in datasette.event_classes)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hook_write_wrapper():
|
||||
datasette = Datasette(memory=True)
|
||||
log = []
|
||||
|
||||
class WrapWritePlugin:
|
||||
__name__ = "WrapWritePlugin"
|
||||
|
||||
@staticmethod
|
||||
@hookimpl
|
||||
def write_wrapper(datasette, database, request, transaction):
|
||||
if database != "_memory":
|
||||
return None
|
||||
|
||||
def wrapper(conn):
|
||||
log.append("before")
|
||||
yield
|
||||
log.append("after")
|
||||
|
||||
return wrapper
|
||||
|
||||
pm.register(WrapWritePlugin(), name="WrapWritePluginTest")
|
||||
try:
|
||||
db = datasette.get_database("_memory")
|
||||
await db.execute_write("create table t (id integer primary key)")
|
||||
assert log == ["before", "after"]
|
||||
finally:
|
||||
pm.unregister(name="WrapWritePluginTest")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hook_register_actions_view_collection():
|
||||
datasette = Datasette(memory=True, plugins_dir=PLUGINS_DIR)
|
||||
|
|
|
|||
387
tests/test_write_wrapper.py
Normal file
387
tests/test_write_wrapper.py
Normal file
|
|
@ -0,0 +1,387 @@
|
|||
"""
|
||||
Tests for the write_wrapper plugin hook.
|
||||
"""
|
||||
|
||||
from datasette.app import Datasette
|
||||
from datasette.hookspecs import hookimpl
|
||||
from datasette.plugins import pm
|
||||
import pytest
|
||||
import time
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def datasette(tmp_path):
|
||||
db_path = str(tmp_path / "test.db")
|
||||
ds = Datasette([db_path])
|
||||
return ds
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"use_execute_write",
|
||||
(False, True),
|
||||
ids=["execute_write_fn", "execute_write"],
|
||||
)
|
||||
async def test_write_wrapper_before_and_after(datasette, use_execute_write):
|
||||
"""Test that code before and after yield both execute."""
|
||||
log = []
|
||||
|
||||
class Plugin:
|
||||
__name__ = "Plugin"
|
||||
|
||||
@staticmethod
|
||||
@hookimpl
|
||||
def write_wrapper(datasette, database, request, transaction):
|
||||
def wrapper(conn):
|
||||
log.append("before")
|
||||
yield
|
||||
log.append("after")
|
||||
|
||||
return wrapper
|
||||
|
||||
pm.register(Plugin(), name="test_before_after")
|
||||
try:
|
||||
db = datasette.get_database("test")
|
||||
if use_execute_write:
|
||||
await db.execute_write(
|
||||
"create table if not exists t (id integer primary key)"
|
||||
)
|
||||
else:
|
||||
await db.execute_write_fn(
|
||||
lambda conn: conn.execute(
|
||||
"create table if not exists t (id integer primary key)"
|
||||
)
|
||||
)
|
||||
assert log == ["before", "after"]
|
||||
finally:
|
||||
pm.unregister(name="test_before_after")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_wrapper_receives_result_via_yield(datasette):
|
||||
"""Test that the result of fn(conn) is sent back through yield."""
|
||||
captured = {}
|
||||
|
||||
class Plugin:
|
||||
__name__ = "Plugin"
|
||||
|
||||
@staticmethod
|
||||
@hookimpl
|
||||
def write_wrapper(datasette, database, request, transaction):
|
||||
def wrapper(conn):
|
||||
result = yield
|
||||
captured["result"] = result
|
||||
|
||||
return wrapper
|
||||
|
||||
pm.register(Plugin(), name="test_result")
|
||||
try:
|
||||
db = datasette.get_database("test")
|
||||
await db.execute_write_fn(
|
||||
lambda conn: conn.execute(
|
||||
"create table if not exists t2 (id integer primary key)"
|
||||
)
|
||||
)
|
||||
assert "result" in captured
|
||||
# Should be a sqlite3 Cursor
|
||||
assert captured["result"] is not None
|
||||
finally:
|
||||
pm.unregister(name="test_result")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_wrapper_exception_thrown_into_generator(datasette):
|
||||
"""Test that exceptions from fn(conn) are thrown into the generator."""
|
||||
caught = {}
|
||||
|
||||
class Plugin:
|
||||
__name__ = "Plugin"
|
||||
|
||||
@staticmethod
|
||||
@hookimpl
|
||||
def write_wrapper(datasette, database, request, transaction):
|
||||
def wrapper(conn):
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
caught["error"] = e
|
||||
|
||||
return wrapper
|
||||
|
||||
pm.register(Plugin(), name="test_exception")
|
||||
try:
|
||||
db = datasette.get_database("test")
|
||||
with pytest.raises(Exception, match="deliberate"):
|
||||
await db.execute_write_fn(
|
||||
lambda conn: (_ for _ in ()).throw(Exception("deliberate"))
|
||||
)
|
||||
assert "error" in caught
|
||||
assert str(caught["error"]) == "deliberate"
|
||||
finally:
|
||||
pm.unregister(name="test_exception")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_wrapper_conn_is_usable(datasette):
|
||||
"""Test that the conn passed to the wrapper can execute SQL."""
|
||||
|
||||
class Plugin:
|
||||
__name__ = "Plugin"
|
||||
|
||||
@staticmethod
|
||||
@hookimpl
|
||||
def write_wrapper(datasette, database, request, transaction):
|
||||
def wrapper(conn):
|
||||
conn.execute("create table if not exists hook_log (msg text)")
|
||||
conn.execute("insert into hook_log values ('before')")
|
||||
yield
|
||||
conn.execute("insert into hook_log values ('after')")
|
||||
|
||||
return wrapper
|
||||
|
||||
pm.register(Plugin(), name="test_conn")
|
||||
try:
|
||||
db = datasette.get_database("test")
|
||||
await db.execute_write_fn(
|
||||
lambda conn: conn.execute(
|
||||
"create table if not exists t3 (id integer primary key)"
|
||||
)
|
||||
)
|
||||
result = await db.execute("select msg from hook_log order by rowid")
|
||||
messages = [row[0] for row in result.rows]
|
||||
assert messages == ["before", "after"]
|
||||
finally:
|
||||
pm.unregister(name="test_conn")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_wrapper_multiple_plugins_nest(datasette):
|
||||
"""Test that multiple write_wrapper plugins nest correctly."""
|
||||
log = []
|
||||
|
||||
class PluginA:
|
||||
__name__ = "PluginA"
|
||||
|
||||
@staticmethod
|
||||
@hookimpl
|
||||
def write_wrapper(datasette, database, request, transaction):
|
||||
def wrapper(conn):
|
||||
log.append("A-before")
|
||||
yield
|
||||
log.append("A-after")
|
||||
|
||||
return wrapper
|
||||
|
||||
class PluginB:
|
||||
__name__ = "PluginB"
|
||||
|
||||
@staticmethod
|
||||
@hookimpl
|
||||
def write_wrapper(datasette, database, request, transaction):
|
||||
def wrapper(conn):
|
||||
log.append("B-before")
|
||||
yield
|
||||
log.append("B-after")
|
||||
|
||||
return wrapper
|
||||
|
||||
pm.register(PluginA(), name="PluginA")
|
||||
pm.register(PluginB(), name="PluginB")
|
||||
try:
|
||||
db = datasette.get_database("test")
|
||||
await db.execute_write_fn(
|
||||
lambda conn: conn.execute(
|
||||
"create table if not exists t4 (id integer primary key)"
|
||||
)
|
||||
)
|
||||
assert set(log) == {"A-before", "A-after", "B-before", "B-after"}
|
||||
# Verify proper nesting: each plugin's before/after should be
|
||||
# symmetric around the write
|
||||
a_before = log.index("A-before")
|
||||
a_after = log.index("A-after")
|
||||
b_before = log.index("B-before")
|
||||
b_after = log.index("B-after")
|
||||
if a_before < b_before:
|
||||
assert a_after > b_after, "A is outer so A-after should come after B-after"
|
||||
else:
|
||||
assert b_after > a_after, "B is outer so B-after should come after A-after"
|
||||
finally:
|
||||
pm.unregister(name="PluginA")
|
||||
pm.unregister(name="PluginB")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_wrapper_return_none_skips(datasette):
|
||||
"""Test that returning None from write_wrapper means no wrapping."""
|
||||
log = []
|
||||
|
||||
class Plugin:
|
||||
__name__ = "Plugin"
|
||||
|
||||
@staticmethod
|
||||
@hookimpl
|
||||
def write_wrapper(datasette, database, request, transaction):
|
||||
log.append("hook-called")
|
||||
return None
|
||||
|
||||
pm.register(Plugin(), name="test_skip")
|
||||
try:
|
||||
db = datasette.get_database("test")
|
||||
await db.execute_write_fn(
|
||||
lambda conn: conn.execute(
|
||||
"create table if not exists t5 (id integer primary key)"
|
||||
)
|
||||
)
|
||||
assert log == ["hook-called"]
|
||||
finally:
|
||||
pm.unregister(name="test_skip")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"request_value,transaction_value,expected_request,expected_transaction",
|
||||
(
|
||||
("fake-request", True, "fake-request", True),
|
||||
(None, True, None, True),
|
||||
(None, False, None, False),
|
||||
),
|
||||
ids=["with-request", "request-none-by-default", "transaction-false"],
|
||||
)
|
||||
async def test_write_wrapper_hook_parameters(
|
||||
datasette,
|
||||
request_value,
|
||||
transaction_value,
|
||||
expected_request,
|
||||
expected_transaction,
|
||||
):
|
||||
"""Test that request and transaction parameters are passed through."""
|
||||
captured = {}
|
||||
|
||||
class Plugin:
|
||||
__name__ = "Plugin"
|
||||
|
||||
@staticmethod
|
||||
@hookimpl
|
||||
def write_wrapper(datasette, database, request, transaction):
|
||||
captured["request"] = request
|
||||
captured["database"] = database
|
||||
captured["transaction"] = transaction
|
||||
|
||||
pm.register(Plugin(), name="test_params")
|
||||
try:
|
||||
db = datasette.get_database("test")
|
||||
kwargs = {"transaction": transaction_value}
|
||||
if request_value is not None:
|
||||
kwargs["request"] = request_value
|
||||
await db.execute_write_fn(
|
||||
lambda conn: conn.execute(
|
||||
"create table if not exists t6 (id integer primary key)"
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
assert captured["request"] == expected_request
|
||||
assert captured["database"] == "test"
|
||||
assert captured["transaction"] == expected_transaction
|
||||
finally:
|
||||
pm.unregister(name="test_params")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_wrapper_via_api(tmp_path):
|
||||
"""Test that write_wrapper fires for API write operations."""
|
||||
log = []
|
||||
|
||||
db_path = str(tmp_path / "test.db")
|
||||
ds = Datasette([db_path], pdb=False)
|
||||
ds.root_enabled = True
|
||||
|
||||
class Plugin:
|
||||
__name__ = "Plugin"
|
||||
|
||||
@staticmethod
|
||||
@hookimpl
|
||||
def write_wrapper(datasette, database, request, transaction):
|
||||
if database != "test":
|
||||
return None
|
||||
|
||||
def wrapper(conn):
|
||||
log.append("before")
|
||||
yield
|
||||
log.append("after")
|
||||
|
||||
return wrapper
|
||||
|
||||
pm.register(Plugin(), name="test_api")
|
||||
try:
|
||||
db = ds.get_database("test")
|
||||
await db.execute_write(
|
||||
"create table if not exists api_test (id integer primary key, name text)"
|
||||
)
|
||||
log.clear()
|
||||
|
||||
token = "dstok_{}".format(
|
||||
ds.sign(
|
||||
{"a": "root", "token": "dstok", "t": int(time.time())},
|
||||
namespace="token",
|
||||
)
|
||||
)
|
||||
response = await ds.client.post(
|
||||
"/test/api_test/-/insert",
|
||||
json={"row": {"name": "test"}, "return": True},
|
||||
headers={
|
||||
"Authorization": "Bearer {}".format(token),
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201, response.json()
|
||||
assert log == ["before", "after"]
|
||||
finally:
|
||||
pm.unregister(name="test_api")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_wrapper_change_group_pattern(datasette):
|
||||
"""Test the motivating use case: activating a change group around a write."""
|
||||
db = datasette.get_database("test")
|
||||
|
||||
await db.execute_write(
|
||||
"create table if not exists groups (id integer primary key, current integer)"
|
||||
)
|
||||
await db.execute_write(
|
||||
"create table if not exists data (id integer primary key, value text)"
|
||||
)
|
||||
await db.execute_write("insert into groups (id, current) values (1, null)")
|
||||
|
||||
class Plugin:
|
||||
__name__ = "Plugin"
|
||||
|
||||
@staticmethod
|
||||
@hookimpl
|
||||
def write_wrapper(datasette, database, request, transaction):
|
||||
if request and getattr(request, "group_id", None):
|
||||
group_id = request.group_id
|
||||
|
||||
def wrapper(conn):
|
||||
conn.execute(
|
||||
"update groups set current = 1 where id = ?", [group_id]
|
||||
)
|
||||
yield
|
||||
conn.execute("update groups set current = null where current = 1")
|
||||
|
||||
return wrapper
|
||||
|
||||
pm.register(Plugin(), name="test_change_group")
|
||||
try:
|
||||
|
||||
class FakeRequest:
|
||||
group_id = 1
|
||||
|
||||
await db.execute_write_fn(
|
||||
lambda conn: conn.execute("insert into data (value) values ('test')"),
|
||||
request=FakeRequest(),
|
||||
)
|
||||
|
||||
result = await db.execute("select current from groups where id = 1")
|
||||
assert result.rows[0][0] is None
|
||||
finally:
|
||||
pm.unregister(name="test_change_group")
|
||||
Loading…
Add table
Add a link
Reference in a new issue