From 4284c74bc133ab494bf4b6dcd4a20b97b05ebb83 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 19 Dec 2023 10:51:03 -0800 Subject: [PATCH] db.execute_isolated_fn() method (#2220) Closes #2218 --- datasette/database.py | 61 ++++++++++++++++++++++++------ docs/internals.rst | 19 +++++++++- tests/test_internals_database.py | 65 ++++++++++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 12 deletions(-) diff --git a/datasette/database.py b/datasette/database.py index cb01301e..f2c980d7 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -159,6 +159,26 @@ class Database: kwargs["count"] = count return results + async def execute_isolated_fn(self, fn): + # Open a new connection just for the duration of this function + # blocking the write queue to avoid any writes occurring during it + if self.ds.executor is None: + # non-threaded mode + isolated_connection = self.connect(write=True) + try: + result = fn(isolated_connection) + finally: + isolated_connection.close() + try: + self._all_file_connections.remove(isolated_connection) + except ValueError: + # Was probably a memory connection + pass + return result + else: + # Threaded mode - send to write thread + return await self._send_to_write_thread(fn, isolated_connection=True) + async def execute_write_fn(self, fn, block=True): if self.ds.executor is None: # non-threaded mode @@ -166,9 +186,10 @@ class Database: self._write_connection = self.connect(write=True) self.ds._prepare_connection(self._write_connection, self.name) return fn(self._write_connection) + else: + return await self._send_to_write_thread(fn, block) - # threaded mode - task_id = uuid.uuid5(uuid.NAMESPACE_DNS, "datasette.io") + async def _send_to_write_thread(self, fn, block=True, isolated_connection=False): if self._write_queue is None: self._write_queue = queue.Queue() if self._write_thread is None: @@ -176,8 +197,9 @@ class Database: target=self._execute_writes, daemon=True ) self._write_thread.start() + task_id = 
uuid.uuid5(uuid.NAMESPACE_DNS, "datasette.io") reply_queue = janus.Queue() - self._write_queue.put(WriteTask(fn, task_id, reply_queue)) + self._write_queue.put(WriteTask(fn, task_id, reply_queue, isolated_connection)) if block: result = await reply_queue.async_q.get() if isinstance(result, Exception): @@ -202,12 +224,28 @@ class Database: if conn_exception is not None: result = conn_exception else: - try: - result = task.fn(conn) - except Exception as e: - sys.stderr.write("{}\n".format(e)) - sys.stderr.flush() - result = e + if task.isolated_connection: + isolated_connection = self.connect(write=True) + try: + result = task.fn(isolated_connection) + except Exception as e: + sys.stderr.write("{}\n".format(e)) + sys.stderr.flush() + result = e + finally: + isolated_connection.close() + try: + self._all_file_connections.remove(isolated_connection) + except ValueError: + # Was probably a memory connection + pass + else: + try: + result = task.fn(conn) + except Exception as e: + sys.stderr.write("{}\n".format(e)) + sys.stderr.flush() + result = e task.reply_queue.sync_q.put(result) async def execute_fn(self, fn): @@ -515,12 +553,13 @@ class Database: class WriteTask: - __slots__ = ("fn", "task_id", "reply_queue") + __slots__ = ("fn", "task_id", "reply_queue", "isolated_connection") - def __init__(self, fn, task_id, reply_queue): + def __init__(self, fn, task_id, reply_queue, isolated_connection): self.fn = fn self.task_id = task_id self.reply_queue = reply_queue + self.isolated_connection = isolated_connection class QueryInterrupted(Exception): diff --git a/docs/internals.rst b/docs/internals.rst index 649ca35d..d269bc7d 100644 --- a/docs/internals.rst +++ b/docs/internals.rst @@ -1017,7 +1017,7 @@ Like ``execute_write()`` but uses the ``sqlite3`` `conn.executemany() ` but executes the provided function in an entirely isolated SQLite connection, which is opened, used and then closed again in a single call to this method. 
+ +The :ref:`prepare_connection() <plugin_hook_prepare_connection>` plugin hook is not executed against this connection. + +This allows plugins to execute database operations that might conflict with how database connections are usually configured. For example, running a ``VACUUM`` operation while bypassing any restrictions placed by the `datasette-sqlite-authorizer <https://github.com/datasette/datasette-sqlite-authorizer>`__ plugin. + +Plugins can also use this method to load potentially dangerous SQLite extensions, use them to perform an operation and then have them safely unloaded at the end of the call, without risk of exposing them to other connections. + +Functions run using ``execute_isolated_fn()`` share the same queue as ``execute_write_fn()``, which guarantees that no writes can be executed at the same time as the isolated function is executing. + +The return value of the function will be returned by this method. Any exceptions raised by the function will be raised out of the ``await`` line as well. + .. _database_close: db.close() diff --git a/tests/test_internals_database.py b/tests/test_internals_database.py index 647ae7bd..e0511100 100644 --- a/tests/test_internals_database.py +++ b/tests/test_internals_database.py @@ -1,6 +1,7 @@ """ Tests for the datasette.database.Database class """ +from datasette.app import Datasette from datasette.database import Database, Results, MultipleValues from datasette.utils.sqlite import sqlite3 from datasette.utils import Column @@ -519,6 +520,70 @@ async def test_execute_write_fn_connection_exception(tmpdir, app_client): app_client.ds.remove_database("immutable-db") +def table_exists(conn, name): + return bool( + conn.execute( + """ + with all_tables as ( + select name from sqlite_master where type = 'table' + union all + select name from temp.sqlite_master where type = 'table' + ) + select 1 from all_tables where name = ? 
+ """, + (name,), + ).fetchall(), + ) + + +def table_exists_checker(name): + def inner(conn): + return table_exists(conn, name) + + return inner + + +@pytest.mark.asyncio +@pytest.mark.parametrize("disable_threads", (False, True)) +async def test_execute_isolated(db, disable_threads): + if disable_threads: + ds = Datasette(memory=True, settings={"num_sql_threads": 0}) + db = ds.add_database(Database(ds, memory_name="test_num_sql_threads_zero")) + + # Create temporary table in write + await db.execute_write( + "create temporary table created_by_write (id integer primary key)" + ) + # Should stay visible to write connection + assert await db.execute_write_fn(table_exists_checker("created_by_write")) + + def create_shared_table(conn): + conn.execute("create table shared (id integer primary key)") + # And a temporary table that should not continue to exist + conn.execute( + "create temporary table created_by_isolated (id integer primary key)" + ) + assert table_exists(conn, "created_by_isolated") + # Also confirm that created_by_write does not exist + return table_exists(conn, "created_by_write") + + # shared should not exist + assert not await db.execute_fn(table_exists_checker("shared")) + + # Create it using isolated + created_by_write_exists = await db.execute_isolated_fn(create_shared_table) + assert not created_by_write_exists + + # shared SHOULD exist now + assert await db.execute_fn(table_exists_checker("shared")) + + # created_by_isolated should not exist, even in write connection + assert not await db.execute_write_fn(table_exists_checker("created_by_isolated")) + + # ... and a second call to isolated should not see that connection either + assert not await db.execute_isolated_fn(table_exists_checker("created_by_isolated")) + + @pytest.mark.asyncio async def test_mtime_ns(db): assert isinstance(db.mtime_ns, int)