diff --git a/datasette/database.py b/datasette/database.py index 66d50ffa..e7e9527e 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -25,6 +25,7 @@ from .utils import ( table_columns, table_column_details, ) +from .utils.sql_analysis import SQLAnalysis, analyze_sql_tables from .utils.sqlite import sqlite_version from .inspect import inspect_hash @@ -301,6 +302,13 @@ class Database: # Threaded mode - send to write thread return await self._send_to_write_thread(fn, isolated_connection=True) + async def analyze_sql(self, sql, params=None) -> SQLAnalysis: + self._check_not_closed() + + return await self.execute_isolated_fn( + lambda conn: analyze_sql_tables(conn, sql, params, database_name=self.name) + ) + async def execute_write_fn(self, fn, block=True, transaction=True, request=None): self._check_not_closed() pending_events = [] diff --git a/datasette/utils/sql_analysis.py b/datasette/utils/sql_analysis.py new file mode 100644 index 00000000..b5317b62 --- /dev/null +++ b/datasette/utils/sql_analysis.py @@ -0,0 +1,99 @@ +from dataclasses import dataclass +from typing import Literal + +from datasette.utils.sqlite import sqlite3 + +SQLTableOperation = Literal["read", "insert", "update", "delete"] + + +@dataclass(frozen=True) +class SQLTableAccess: + operation: SQLTableOperation + database: str | None + table: str + sqlite_schema: str | None + columns: tuple[str, ...] = () + source: str | None = None + + +@dataclass(frozen=True) +class SQLAnalysis: + table_accesses: tuple[SQLTableAccess, ...] + + +_ACTION_TO_OPERATION: dict[int, SQLTableOperation] = { + sqlite3.SQLITE_READ: "read", + sqlite3.SQLITE_INSERT: "insert", + sqlite3.SQLITE_UPDATE: "update", + sqlite3.SQLITE_DELETE: "delete", +} + + +def analyze_sql_tables( + conn, + sql: str, + params=None, + *, + database_name: str | None = None, + schema_to_database: dict[str, str] | None = None, +) -> SQLAnalysis: + """ + Return tables accessed by a SQL statement according to SQLite's authorizer. + + This function is synchronous and connection-based. It temporarily installs a + SQLite authorizer, prepares ``EXPLAIN ``, and returns the table access + callbacks observed while SQLite compiles the statement. + """ + accesses: dict[ + tuple[SQLTableOperation, str | None, str, str | None, str | None], set[str] + ] = {} + + def database_for_schema(sqlite_schema): + if schema_to_database and sqlite_schema in schema_to_database: + return schema_to_database[sqlite_schema] + if sqlite_schema == "main" and database_name is not None: + return database_name + return sqlite_schema + + def authorizer(action, arg1, arg2, sqlite_schema, source): + operation = _ACTION_TO_OPERATION.get(action) + if operation is None or arg1 is None: + return sqlite3.SQLITE_OK + + key = ( + operation, + database_for_schema(sqlite_schema), + arg1, + sqlite_schema, + source, + ) + columns = accesses.setdefault(key, set()) + if operation in ("read", "update") and arg2 is not None: + columns.add(arg2) + return sqlite3.SQLITE_OK + + conn.set_authorizer(authorizer) + try: + conn.execute("EXPLAIN " + sql, params if params is not None else {}).fetchall() + finally: + conn.set_authorizer(None) + + return SQLAnalysis( + table_accesses=tuple( + SQLTableAccess( + operation=operation, + database=database, + table=table, + sqlite_schema=sqlite_schema, + columns=tuple(sorted(columns)), + source=source, + ) + for ( + operation, + database, + table, + sqlite_schema, + source, + ), columns in accesses.items() + ) + ) diff --git a/tests/test_internals_database.py b/tests/test_internals_database.py index 75ae8d39..5481a398 100644 --- a/tests/test_internals_database.py +++ b/tests/test_internals_database.py @@ -688,6 +688,54 @@ async def test_execute_isolated(db, disable_threads): assert not await db.execute_isolated_fn(table_exists_checker("created_by_isolated")) +@pytest.mark.asyncio +async def test_analyze_sql(): + ds = Datasette(memory=True) + db = ds.add_memory_database("test_analyze_sql", name="data") + await db.execute_write("create table dogs (id integer primary key, name text)") + + analysis = await db.analyze_sql("select name from dogs where id = ?", (1,)) + + assert [ + ( + access.operation, + access.database, + access.sqlite_schema, + access.table, + access.columns, + access.source, + ) + for access in analysis.table_accesses + ] == [ + ("read", "data", "main", "dogs", ("id", "name"), None), + ] + + +@pytest.mark.asyncio +async def test_analyze_sql_insert_select(): + ds = Datasette(memory=True) + db = ds.add_memory_database("test_analyze_sql_insert_select", name="data") + await db.execute_write("create table dogs (id integer primary key, name text)") + await db.execute_write("create table cats (id integer primary key, name text)") + + analysis = await db.analyze_sql("insert into dogs (name) select name from cats") + + assert { + ( + access.operation, + access.database, + access.sqlite_schema, + access.table, + access.columns, + access.source, + ) + for access in analysis.table_accesses + } == { + ("insert", "data", "main", "dogs", (), None), + ("read", "data", "main", "cats", ("name",), None), + } + + @pytest.mark.asyncio async def test_mtime_ns(db): assert isinstance(db.mtime_ns, int) diff --git a/tests/test_utils_sql_analysis.py b/tests/test_utils_sql_analysis.py new file mode 100644 index 00000000..c82fb04f --- /dev/null +++ b/tests/test_utils_sql_analysis.py @@ -0,0 +1,195 @@ +import pytest + +from datasette.utils.sqlite import sqlite3 +from datasette.utils.sql_analysis import analyze_sql_tables + + +@pytest.fixture +def conn(): + conn = sqlite3.connect(":memory:") + conn.executescript(""" + create table dogs (id integer primary key, name text, age integer); + create table cats (id integer primary key, name text); + create table log (message text); + create view dog_names as select id, name from dogs; + create trigger dogs_after_insert after insert on dogs begin + update cats set name = new.name where id = new.id; + insert into log (message) values (new.name); + end; + create trigger dog_names_instead_of_update instead of update on dog_names begin + update dogs set name = new.name where id = old.id; + end; + """) + try: + yield conn + finally: + conn.close() + + +def as_tuples(analysis): + return [ + ( + access.operation, + access.database, + access.sqlite_schema, + access.table, + access.columns, + access.source, + ) + for access in analysis.table_accesses + ] + + +def test_analyze_select_tables(conn): + analysis = analyze_sql_tables( + conn, + "select dogs.name, cats.name from dogs join cats on dogs.id = cats.id where dogs.age > ?", + (2,), + database_name="data", + ) + + assert set(as_tuples(analysis)) == { + ("read", "data", "main", "cats", ("id", "name"), None), + ("read", "data", "main", "dogs", ("age", "id", "name"), None), + } + + +def test_analyze_uses_sqlite_schema_as_default_database(conn): + analysis = analyze_sql_tables(conn, "select name from dogs") + + assert set(as_tuples(analysis)) == { + ("read", "main", "main", "dogs", ("name",), None), + } + + +def test_analyze_insert_tables(conn): + analysis = analyze_sql_tables( + conn, + "insert into dogs (name, age) values (:name, :age)", + {"name": "Cleo", "age": 4}, + database_name="data", + ) + + assert set(as_tuples(analysis)) == { + ("insert", "data", "main", "dogs", (), None), + ("read", "data", "main", "dogs", ("id", "name"), "dogs_after_insert"), + ("update", "data", "main", "cats", ("name",), "dogs_after_insert"), + ("read", "data", "main", "cats", ("id",), "dogs_after_insert"), + ("insert", "data", "main", "log", (), "dogs_after_insert"), + } + + +def test_analyze_update_tables(conn): + analysis = analyze_sql_tables( + conn, + "update dogs set age = age + 1 where name = ?", + ("Cleo",), + database_name="data", + ) + + assert set(as_tuples(analysis)) == { + ("update", "data", "main", "dogs", ("age",), None), + ("read", "data", "main", "dogs", ("age", "name"), None), + } + + +def test_analyze_delete_tables(conn): + analysis = analyze_sql_tables( + conn, + "delete from dogs where name = ?", + ("Cleo",), + database_name="data", + ) + + assert set(as_tuples(analysis)) == { + ("delete", "data", "main", "dogs", (), None), + ("read", "data", "main", "dogs", ("name",), None), + } + + +def test_analyze_insert_select_with_cte(conn): + analysis = analyze_sql_tables( + conn, + """ + with old_dogs as ( + select name from dogs where age > :age + ) + insert into cats (name) + select name from old_dogs + """, + {"age": 10}, + database_name="data", + ) + + assert set(as_tuples(analysis)) == { + ("insert", "data", "main", "cats", (), None), + ("read", "data", "main", "dogs", ("age", "name"), "old_dogs"), + } + + +def test_analyze_view_with_instead_of_trigger(conn): + analysis = analyze_sql_tables( + conn, + "update dog_names set name = :name where id = :id", + {"name": "Zelda", "id": 1}, + database_name="data", + ) + + assert set(as_tuples(analysis)) == { + ("update", "data", "main", "dog_names", ("name",), None), + ("read", "data", "main", "dogs", ("id", "name"), "dog_names"), + ("read", "data", "main", "dog_names", ("id", "name"), "dog_names"), + ( + "read", + "data", + "main", + "dog_names", + ("id", "name"), + "dog_names_instead_of_update", + ), + ("update", "data", "main", "dogs", ("name",), "dog_names_instead_of_update"), + ("read", "data", "main", "dogs", ("id",), "dog_names_instead_of_update"), + } + + +def test_analyze_attached_database_tables(conn): + conn.execute("attach database ':memory:' as extra") + conn.execute("create table extra.people (id integer primary key, name text)") + + analysis = analyze_sql_tables( + conn, + "insert into extra.people (name) select name from dogs", + database_name="data", + schema_to_database={"extra": "extra_db"}, + ) + + assert set(as_tuples(analysis)) == { + ("insert", "extra_db", "extra", "people", (), None), + ("read", "data", "main", "dogs", ("name",), None), + } + + +def test_analyze_invalid_sql_cleans_up_authorizer(conn): + with pytest.raises(sqlite3.OperationalError): + analyze_sql_tables(conn, "insert into missing_table values (1)") + + conn.execute("select name from dogs").fetchall() + + +def test_analyze_clears_authorizer_on_error(): + class FakeConnection: + def __init__(self): + self.authorizers = [] + + def set_authorizer(self, authorizer): + self.authorizers.append(authorizer) + + def execute(self, sql, params): + raise sqlite3.OperationalError("bad SQL") + + conn = FakeConnection() + + with pytest.raises(sqlite3.OperationalError): + analyze_sql_tables(conn, "bad SQL") + + assert conn.authorizers[-1] is None