diff --git a/tests/test_write_wrapper.py b/tests/test_write_wrapper.py index e05a2a9f..38e5c94e 100644 --- a/tests/test_write_wrapper.py +++ b/tests/test_write_wrapper.py @@ -6,6 +6,7 @@ from datasette.app import Datasette from datasette.hookspecs import hookimpl from datasette.plugins import pm import pytest +import sqlite3 import time @@ -385,3 +386,92 @@ async def test_write_wrapper_change_group_pattern(datasette): assert result.rows[0][0] is None finally: pm.unregister(name="test_change_group") + + +WRITE_ACTIONS = ( + sqlite3.SQLITE_INSERT, + sqlite3.SQLITE_UPDATE, + sqlite3.SQLITE_DELETE, +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "actor,table,should_deny", + ( + (None, "protected_table", True), + ({"id": "regular"}, "protected_table", True), + ({"id": "admin"}, "protected_table", False), + (None, "other_table", False), + ({"id": "regular"}, "other_table", False), + ), + ids=[ + "no-actor-protected", + "regular-user-protected", + "admin-protected", + "no-actor-other", + "regular-user-other", + ], +) +async def test_write_wrapper_set_authorizer(datasette, actor, table, should_deny): + """Test the docs example that uses set_authorizer to block writes to a protected table.""" + db = datasette.get_database("test") + await db.execute_write( + "create table if not exists protected_table (id integer primary key, value text)" + ) + await db.execute_write( + "create table if not exists other_table (id integer primary key, value text)" + ) + + class Plugin: + __name__ = "Plugin" + + @staticmethod + @hookimpl + def write_wrapper(datasette, database, request, transaction): + actor = None + if request: + actor = request.actor + if actor and actor.get("id") == "admin": + return None + + def wrapper(conn): + def authorizer(action, arg1, arg2, db_name, trigger): + if action in WRITE_ACTIONS and arg1 == "protected_table": + return sqlite3.SQLITE_DENY + return sqlite3.SQLITE_OK + + conn.set_authorizer(authorizer) + try: + yield + finally: + conn.set_authorizer(None) + + return wrapper + + class FakeRequest: + def __init__(self, actor): + self.actor = actor + + pm.register(Plugin(), name="test_set_authorizer") + try: + request = FakeRequest(actor) + if should_deny: + with pytest.raises(Exception): + await db.execute_write_fn( + lambda conn: conn.execute( + f"insert into {table} (value) values ('test')" + ), + request=request, + ) + else: + await db.execute_write_fn( + lambda conn: conn.execute( + f"insert into {table} (value) values ('test')" + ), + request=request, + ) + result = await db.execute(f"select value from {table} order by rowid desc limit 1") + assert result.rows[0][0] == "test" + finally: + pm.unregister(name="test_set_authorizer")