mirror of
https://github.com/simonw/datasette.git
synced 2026-05-27 20:36:17 +02:00
Closes #1752 AI generated patch explanation: https://gisthost.github.io/?e2b8d9c7666e988b5c003ff5e5ef3098
768 lines
23 KiB
Python
768 lines
23 KiB
Python
"""
|
|
Tests for the write_wrapper plugin hook.
|
|
"""
|
|
|
|
import asyncio
|
|
from dataclasses import dataclass
|
|
from datasette.app import Datasette
|
|
from datasette.events import Event
|
|
from datasette.hookspecs import hookimpl
|
|
from datasette.plugins import pm
|
|
import pytest
|
|
import sqlite3
|
|
import time
|
|
|
|
|
|
@dataclass
class DummyEvent(Event):
    """Minimal concrete ``Event`` used by the track_event tests."""

    # Payload the tests read back from the recorded events.
    message: str

    # Event name reported for this event type.
    name = "dummy"
|
|
|
|
|
|
@pytest.fixture
def datasette(tmp_path):
    """Yield a Datasette instance backed by a fresh on-disk ``test`` database.

    The instance is closed on teardown so its resources are released after
    each test, matching the ``ds_with_event_tracking`` fixture.
    """
    db_path = str(tmp_path / "test.db")
    ds = Datasette([db_path])
    yield ds
    ds.close()
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "use_execute_write",
    (False, True),
    ids=["execute_write_fn", "execute_write"],
)
async def test_write_wrapper_before_and_after(datasette, use_execute_write):
    """Code on both sides of the wrapper's ``yield`` runs, for either write API."""
    log = []
    create_sql = "create table if not exists t (id integer primary key)"

    class Plugin:
        __name__ = "Plugin"

        @staticmethod
        @hookimpl
        def write_wrapper(datasette, database, request, transaction):
            def wrapper(conn):
                log.append("before")
                yield
                log.append("after")

            return wrapper

    pm.register(Plugin(), name="test_before_after")
    try:
        db = datasette.get_database("test")
        if not use_execute_write:
            await db.execute_write_fn(lambda conn: conn.execute(create_sql))
        else:
            await db.execute_write(create_sql)
        assert log == ["before", "after"]
    finally:
        pm.unregister(name="test_before_after")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_write_wrapper_receives_result_via_yield(datasette):
    """The value returned by fn(conn) is sent back into the wrapper's ``yield``.

    fn returns a unique sentinel, so the check is identity with fn's actual
    return value rather than just "something non-None arrived".
    """
    captured = {}
    sentinel = object()

    class Plugin:
        __name__ = "Plugin"

        @staticmethod
        @hookimpl
        def write_wrapper(datasette, database, request, transaction):
            def wrapper(conn):
                captured["result"] = yield

            return wrapper

    def write(conn):
        conn.execute("create table if not exists t2 (id integer primary key)")
        return sentinel

    pm.register(Plugin(), name="test_result")
    try:
        db = datasette.get_database("test")
        await db.execute_write_fn(write)
        assert "result" in captured
        assert captured["result"] is sentinel
    finally:
        pm.unregister(name="test_result")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_write_wrapper_exception_thrown_into_generator(datasette):
    """Exceptions raised by fn(conn) are thrown into the wrapper at its ``yield``."""
    caught = {}

    class Plugin:
        __name__ = "Plugin"

        @staticmethod
        @hookimpl
        def write_wrapper(datasette, database, request, transaction):
            def wrapper(conn):
                try:
                    yield
                except Exception as e:
                    caught["error"] = e

            return wrapper

    def failing_write(conn):
        raise ValueError("deliberate")

    pm.register(Plugin(), name="test_exception")
    try:
        db = datasette.get_database("test")
        # The wrapper swallows the error it receives, but this test expects
        # execute_write_fn to still raise it to the caller.
        with pytest.raises(ValueError, match="deliberate"):
            await db.execute_write_fn(failing_write)
        assert "error" in caught
        assert isinstance(caught["error"], ValueError)
        assert str(caught["error"]) == "deliberate"
    finally:
        pm.unregister(name="test_exception")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_write_wrapper_conn_is_usable(datasette):
    """The wrapper can run its own SQL on conn both before and after the write."""

    class Plugin:
        __name__ = "Plugin"

        @staticmethod
        @hookimpl
        def write_wrapper(datasette, database, request, transaction):
            def wrapper(conn):
                conn.execute("create table if not exists hook_log (msg text)")
                conn.execute("insert into hook_log values ('before')")
                yield
                conn.execute("insert into hook_log values ('after')")

            return wrapper

    pm.register(Plugin(), name="test_conn")
    try:
        db = datasette.get_database("test")
        await db.execute_write_fn(
            lambda conn: conn.execute(
                "create table if not exists t3 (id integer primary key)"
            )
        )
        logged = await db.execute("select msg from hook_log order by rowid")
        assert [entry[0] for entry in logged.rows] == ["before", "after"]
    finally:
        pm.unregister(name="test_conn")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_write_wrapper_multiple_plugins_nest(datasette):
    """Two write_wrapper plugins nest: whichever enters first exits last."""
    log = []

    def make_plugin(label):
        # Build a plugin whose wrapper logs "<label>-before"/"<label>-after".
        class Plugin:
            __name__ = "Plugin" + label

            @staticmethod
            @hookimpl
            def write_wrapper(datasette, database, request, transaction):
                def wrapper(conn):
                    log.append(label + "-before")
                    yield
                    log.append(label + "-after")

                return wrapper

        return Plugin()

    pm.register(make_plugin("A"), name="PluginA")
    pm.register(make_plugin("B"), name="PluginB")
    try:
        db = datasette.get_database("test")
        await db.execute_write_fn(
            lambda conn: conn.execute(
                "create table if not exists t4 (id integer primary key)"
            )
        )
        expected_entries = ("A-before", "A-after", "B-before", "B-after")
        assert set(log) == set(expected_entries)

        # Nesting check: the outer plugin's "after" must follow the inner one's.
        position = {entry: log.index(entry) for entry in expected_entries}
        if position["A-before"] < position["B-before"]:
            outer, inner = "A", "B"
        else:
            outer, inner = "B", "A"
        assert position[outer + "-after"] > position[inner + "-after"], (
            f"{outer} is outer so {outer}-after should come after {inner}-after"
        )
    finally:
        pm.unregister(name="PluginA")
        pm.unregister(name="PluginB")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_write_wrapper_return_none_skips(datasette):
    """A hook returning None is still invoked, but wraps nothing."""
    calls = []

    class Plugin:
        __name__ = "Plugin"

        @staticmethod
        @hookimpl
        def write_wrapper(datasette, database, request, transaction):
            calls.append("hook-called")
            return None

    pm.register(Plugin(), name="test_skip")
    try:
        await datasette.get_database("test").execute_write_fn(
            lambda conn: conn.execute(
                "create table if not exists t5 (id integer primary key)"
            )
        )
        assert calls == ["hook-called"]
    finally:
        pm.unregister(name="test_skip")
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "request_value,transaction_value,expected_request,expected_transaction",
    (
        ("fake-request", True, "fake-request", True),
        (None, True, None, True),
        (None, False, None, False),
    ),
    ids=["with-request", "request-none-by-default", "transaction-false"],
)
async def test_write_wrapper_hook_parameters(
    datasette,
    request_value,
    transaction_value,
    expected_request,
    expected_transaction,
):
    """The hook receives the request, database name and transaction flag.

    When no request is passed to execute_write_fn, the hook sees None.
    """
    captured = {}

    class Plugin:
        __name__ = "Plugin"

        @staticmethod
        @hookimpl
        def write_wrapper(datasette, database, request, transaction):
            # Only record the arguments; the implicit None return wraps nothing.
            captured.update(
                request=request, database=database, transaction=transaction
            )

    pm.register(Plugin(), name="test_params")
    try:
        db = datasette.get_database("test")
        write_kwargs = dict(transaction=transaction_value)
        if request_value is not None:
            # When omitted, the default request value is exercised instead.
            write_kwargs.update(request=request_value)
        await db.execute_write_fn(
            lambda conn: conn.execute(
                "create table if not exists t6 (id integer primary key)"
            ),
            **write_kwargs,
        )
        assert captured["request"] == expected_request
        assert captured["database"] == "test"
        assert captured["transaction"] == expected_transaction
    finally:
        pm.unregister(name="test_params")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_write_wrapper_via_api(tmp_path):
    """Test that write_wrapper fires for API write operations."""
    log = []

    db_path = str(tmp_path / "test.db")
    ds = Datasette([db_path], pdb=False)
    ds.root_enabled = True

    class Plugin:
        __name__ = "Plugin"

        @staticmethod
        @hookimpl
        def write_wrapper(datasette, database, request, transaction):
            # Only wrap writes to the "test" database; writes to any other
            # database are left unwrapped so they do not appear in the log.
            if database != "test":
                return None

            def wrapper(conn):
                log.append("before")
                yield
                log.append("after")

            return wrapper

    pm.register(Plugin(), name="test_api")
    try:
        db = ds.get_database("test")
        await db.execute_write(
            "create table if not exists api_test (id integer primary key, name text)"
        )
        # The setup write above is wrapped too; reset the log so only the
        # API insert is checked below.
        log.clear()

        # Signed API token for the "root" actor.
        token = "dstok_{}".format(
            ds.sign(
                {"a": "root", "token": "dstok", "t": int(time.time())},
                namespace="token",
            )
        )
        response = await ds.client.post(
            "/test/api_test/-/insert",
            json={"row": {"name": "test"}, "return": True},
            headers={
                "Authorization": "Bearer {}".format(token),
                "Content-Type": "application/json",
            },
        )
        assert response.status_code == 201, response.json()
        assert log == ["before", "after"]
    finally:
        pm.unregister(name="test_api")
        # This instance was created here rather than by a fixture, so close it
        # explicitly.
        ds.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_write_wrapper_change_group_pattern(datasette):
    """Test the motivating use case: activating a change group around a write.

    The wrapper marks the request's group as current before the write and
    clears it afterwards. The test checks both halves: the write itself sees
    the group as active, and the group is inactive once the write finishes.
    """
    db = datasette.get_database("test")

    await db.execute_write(
        "create table if not exists groups (id integer primary key, current integer)"
    )
    await db.execute_write(
        "create table if not exists data (id integer primary key, value text)"
    )
    await db.execute_write("insert into groups (id, current) values (1, null)")

    class Plugin:
        __name__ = "Plugin"

        @staticmethod
        @hookimpl
        def write_wrapper(datasette, database, request, transaction):
            if request and getattr(request, "group_id", None):
                group_id = request.group_id

                def wrapper(conn):
                    conn.execute(
                        "update groups set current = 1 where id = ?", [group_id]
                    )
                    yield
                    # Deactivate exactly the group that was activated above.
                    conn.execute(
                        "update groups set current = null where id = ?", [group_id]
                    )

                return wrapper

    # Values of groups.current observed from inside the write.
    seen_during_write = []

    def write(conn):
        conn.execute("insert into data (value) values ('test')")
        seen_during_write.append(
            conn.execute("select current from groups where id = 1").fetchone()[0]
        )

    pm.register(Plugin(), name="test_change_group")
    try:

        class FakeRequest:
            group_id = 1

        await db.execute_write_fn(write, request=FakeRequest())

        # Without this check the test would also pass if the wrapper never
        # ran, since current starts out as null.
        assert seen_during_write == [1]

        result = await db.execute("select current from groups where id = 1")
        assert result.rows[0][0] is None
    finally:
        pm.unregister(name="test_change_group")
|
|
|
|
|
|
# SQLite authorizer action codes for row-modifying statements. The
# set_authorizer test denies these actions on its protected table.
WRITE_ACTIONS = (sqlite3.SQLITE_INSERT, sqlite3.SQLITE_UPDATE, sqlite3.SQLITE_DELETE)
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "actor,table,should_deny",
    (
        (None, "protected_table", True),
        ({"id": "regular"}, "protected_table", True),
        ({"id": "admin"}, "protected_table", False),
        (None, "other_table", False),
        ({"id": "regular"}, "other_table", False),
    ),
    ids=[
        "no-actor-protected",
        "regular-user-protected",
        "admin-protected",
        "no-actor-other",
        "regular-user-other",
    ],
)
async def test_write_wrapper_set_authorizer(datasette, actor, table, should_deny):
    """Test the docs example that uses set_authorizer to block writes to a protected table.

    Denied writes must fail with SQLite's authorization error and leave the
    table empty. Allowed writes must insert the row.
    """
    db = datasette.get_database("test")
    await db.execute_write(
        "create table if not exists protected_table (id integer primary key, value text)"
    )
    await db.execute_write(
        "create table if not exists other_table (id integer primary key, value text)"
    )

    class Plugin:
        __name__ = "Plugin"

        @staticmethod
        @hookimpl
        def write_wrapper(datasette, database, request, transaction):
            actor = None
            if request:
                actor = request.actor
            # The admin actor bypasses the restriction entirely.
            if actor and actor.get("id") == "admin":
                return None

            def wrapper(conn):
                def authorizer(action, arg1, arg2, db_name, trigger):
                    # For INSERT/UPDATE/DELETE actions, arg1 is the table name.
                    if action in WRITE_ACTIONS and arg1 == "protected_table":
                        return sqlite3.SQLITE_DENY
                    return sqlite3.SQLITE_OK

                conn.set_authorizer(authorizer)
                try:
                    yield
                finally:
                    # Install an allow-everything authorizer so the
                    # connection is not left restricted after this write.
                    conn.set_authorizer(lambda *args: sqlite3.SQLITE_OK)

            return wrapper

    class FakeRequest:
        def __init__(self, actor):
            self.actor = actor

    def insert_row(conn):
        return conn.execute(f"insert into {table} (value) values ('test')")

    pm.register(Plugin(), name="test_set_authorizer")
    try:
        request = FakeRequest(actor)
        if should_deny:
            # An authorizer returning SQLITE_DENY makes SQLite fail the
            # statement with "not authorized".
            with pytest.raises(Exception, match="not authorized"):
                await db.execute_write_fn(insert_row, request=request)
            result = await db.execute(f"select count(*) from {table}")
            assert result.rows[0][0] == 0
        else:
            await db.execute_write_fn(insert_row, request=request)
            result = await db.execute(
                f"select value from {table} order by rowid desc limit 1"
            )
            assert result.rows[0][0] == "test"
    finally:
        pm.unregister(name="test_set_authorizer")
|
|
|
|
|
|
# --- Tests for track_event callback ---
|
|
|
|
|
|
@pytest.fixture
def ds_with_event_tracking(tmp_path):
    """Yield a Datasette whose track_event records events into ``_tracked_events``.

    DummyEvent is registered by assigning ``event_classes`` directly, which
    avoids having to call ``invoke_startup``. The instance is closed on
    teardown.
    """
    ds = Datasette([str(tmp_path / "test.db")])
    ds._tracked_events = []
    ds.event_classes = (DummyEvent,)

    async def recording_track_event(event):
        ds._tracked_events.append(event)

    ds.track_event = recording_track_event

    yield ds
    ds.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_track_event_in_write_fn(ds_with_event_tracking):
    """A two-argument fn(conn, track_event) can queue events that fire after commit."""
    ds = ds_with_event_tracking

    def my_write(conn, track_event):
        conn.execute("create table if not exists te1 (id integer primary key)")
        track_event(DummyEvent(actor=None, message="hello"))

    await ds.get_database("test").execute_write_fn(my_write)

    recorded = ds._tracked_events
    assert len(recorded) == 1
    assert recorded[0].message == "hello"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_track_event_discarded_on_exception(ds_with_event_tracking):
    """Queued events are dropped when the write fn raises."""
    ds = ds_with_event_tracking

    def my_write(conn, track_event):
        track_event(DummyEvent(actor=None, message="should not fire"))
        raise ValueError("deliberate error")

    with pytest.raises(ValueError, match="deliberate"):
        await ds.get_database("test").execute_write_fn(my_write)

    assert not ds._tracked_events
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_track_event_existing_fn_signature_still_works(ds_with_event_tracking):
    """A one-argument fn(conn) still works and produces no events."""
    ds = ds_with_event_tracking
    create_sql = "create table if not exists te2 (id integer primary key)"

    await ds.get_database("test").execute_write_fn(
        lambda conn: conn.execute(create_sql)
    )

    # Nothing was queued, and nothing raised.
    assert not ds._tracked_events
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_track_event_in_write_wrapper(ds_with_event_tracking):
    """A wrapper generator taking (conn, track_event) can queue events on both sides of yield."""
    ds = ds_with_event_tracking
    db = ds.get_database("test")

    class Plugin:
        __name__ = "Plugin"

        @staticmethod
        @hookimpl
        def write_wrapper(datasette, database, request, transaction):
            def wrapper(conn, track_event):
                track_event(DummyEvent(actor=None, message="from wrapper before"))
                yield
                track_event(DummyEvent(actor=None, message="from wrapper after"))

            return wrapper

    pm.register(Plugin(), name="test_track_wrapper")
    try:
        await db.execute_write_fn(
            lambda conn: conn.execute(
                "create table if not exists te3 (id integer primary key)"
            )
        )
        messages = [event.message for event in ds._tracked_events]
        assert messages == ["from wrapper before", "from wrapper after"]
    finally:
        pm.unregister(name="test_track_wrapper")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_track_event_shared_between_fn_and_wrapper(ds_with_event_tracking):
    """Events queued by the wrapper and by fn are dispatched in queue order."""
    ds = ds_with_event_tracking
    db = ds.get_database("test")

    class Plugin:
        __name__ = "Plugin"

        @staticmethod
        @hookimpl
        def write_wrapper(datasette, database, request, transaction):
            def wrapper(conn, track_event):
                track_event(DummyEvent(actor=None, message="wrapper-before"))
                yield
                track_event(DummyEvent(actor=None, message="wrapper-after"))

            return wrapper

    def my_write(conn, track_event):
        conn.execute("create table if not exists te4 (id integer primary key)")
        track_event(DummyEvent(actor=None, message="from-fn"))

    pm.register(Plugin(), name="test_track_shared")
    try:
        await db.execute_write_fn(my_write)
        assert [event.message for event in ds._tracked_events] == [
            "wrapper-before",
            "from-fn",
            "wrapper-after",
        ]
    finally:
        pm.unregister(name="test_track_shared")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_track_event_with_block_false(ds_with_event_tracking):
    """Events are also dispatched for non-blocking (block=False) writes."""
    ds = ds_with_event_tracking
    db = ds.get_database("test")

    def my_write(conn, track_event):
        conn.execute("create table if not exists te5 (id integer primary key)")
        track_event(DummyEvent(actor=None, message="non-blocking"))

    task_id = await db.execute_write_fn(my_write, block=False)
    assert task_id is not None

    # The write completes in the background. Poll up to 50 times, sleeping
    # 10ms between checks, until the recorded event appears.
    attempts = 0
    while not ds._tracked_events and attempts < 50:
        await asyncio.sleep(0.01)
        attempts += 1

    assert [event.message for event in ds._tracked_events] == ["non-blocking"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_track_event_with_block_false_discarded_on_exception(
    ds_with_event_tracking,
):
    """A failing non-blocking write drops the events it queued."""
    ds = ds_with_event_tracking
    db = ds.get_database("test")

    def my_write(conn, track_event):
        track_event(DummyEvent(actor=None, message="should not fire"))
        raise ValueError("deliberate error")

    assert await db.execute_write_fn(my_write, block=False) is not None

    # A blocking write queued after the failing one can only complete once
    # that task has finished. One extra loop iteration then lets its
    # event-dispatch task see the exception and exit.
    await db.execute_write_fn(lambda conn: conn.execute("select 1"))
    await asyncio.sleep(0)

    assert ds._tracked_events == []
|
|
|
|
|
|
# --- Tests for RenameTableEvent detection ---
|
|
|
|
|
|
@pytest.fixture
def ds_for_rename(tmp_path):
    """Yield a Datasette that records tracked events for rename detection tests.

    Only RenameTableEvent is registered. The instance is closed on teardown,
    matching the ``ds_with_event_tracking`` fixture.
    """
    from datasette.events import RenameTableEvent

    db_path = str(tmp_path / "test.db")
    ds = Datasette([db_path])
    ds._tracked_events = []
    ds.event_classes = (RenameTableEvent,)

    async def recording_track_event(event):
        ds._tracked_events.append(event)

    # Replace track_event so emitted events are appended to _tracked_events.
    ds.track_event = recording_track_event

    yield ds
    ds.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_rename_table_fires_event(ds_for_rename):
    """ALTER TABLE ... RENAME TO emits exactly one RenameTableEvent."""
    from datasette.events import RenameTableEvent

    ds = ds_for_rename
    db = ds.get_database("test")
    await db.execute_write("create table old_name (id integer primary key)")

    def rename(conn):
        conn.execute("alter table old_name rename to new_name")

    await db.execute_write_fn(rename)

    renames = [
        event for event in ds._tracked_events if isinstance(event, RenameTableEvent)
    ]
    assert len(renames) == 1
    (event,) = renames
    assert event.old_table == "old_name"
    assert event.new_table == "new_name"
    assert event.database == "test"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_no_rename_event_for_regular_writes(ds_for_rename):
    """CREATE and INSERT statements emit no RenameTableEvent."""
    from datasette.events import RenameTableEvent

    ds = ds_for_rename
    db = ds.get_database("test")

    await db.execute_write("create table t (id integer primary key)")
    await db.execute_write_fn(lambda conn: conn.execute("insert into t values (1)"))

    assert not any(
        isinstance(event, RenameTableEvent) for event in ds._tracked_events
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_no_rename_event_on_rollback(ds_for_rename):
    """A rename inside a write that raises emits no RenameTableEvent."""
    from datasette.events import RenameTableEvent

    ds = ds_for_rename
    db = ds.get_database("test")
    await db.execute_write("create table rollback_test (id integer primary key)")

    def rename_then_fail(conn):
        conn.execute("alter table rollback_test rename to renamed")
        raise ValueError("deliberate error")

    with pytest.raises(ValueError, match="deliberate"):
        await db.execute_write_fn(rename_then_fail)

    assert not any(
        isinstance(event, RenameTableEvent) for event in ds._tracked_events
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_multiple_renames_in_one_write(ds_for_rename):
    """Each rename within a single write emits its own RenameTableEvent."""
    from datasette.events import RenameTableEvent

    ds = ds_for_rename
    db = ds.get_database("test")
    for table_name in ("alpha", "beta"):
        await db.execute_write(f"create table {table_name} (id integer primary key)")

    def rename_both(conn):
        conn.execute("alter table alpha rename to alpha2")
        conn.execute("alter table beta rename to beta2")

    await db.execute_write_fn(rename_both)

    renames = [
        event for event in ds._tracked_events if isinstance(event, RenameTableEvent)
    ]
    assert len(renames) == 2
    pairs = set()
    for event in renames:
        pairs.add((event.old_table, event.new_table))
    assert pairs == {("alpha", "alpha2"), ("beta", "beta2")}
|