mirror of
https://github.com/simonw/datasette.git
synced 2026-06-04 08:07:01 +02:00
Database.analyze_sql(sql) method
Experimental, we may need this for the upcoming canned query work so that we can tell if a user should be able to save a writable canned query by confirming they have the right permissions to update the affected tables. Refs #2735
This commit is contained in:
parent
6cafdcb6fa
commit
a855a1acec
4 changed files with 350 additions and 0 deletions
|
|
@ -25,6 +25,7 @@ from .utils import (
|
|||
table_columns,
|
||||
table_column_details,
|
||||
)
|
||||
from .utils.sql_analysis import SQLAnalysis, analyze_sql_tables
|
||||
from .utils.sqlite import sqlite_version
|
||||
from .inspect import inspect_hash
|
||||
|
||||
|
|
@ -301,6 +302,13 @@ class Database:
|
|||
# Threaded mode - send to write thread
|
||||
return await self._send_to_write_thread(fn, isolated_connection=True)
|
||||
|
||||
async def analyze_sql(self, sql, params=None) -> SQLAnalysis:
|
||||
self._check_not_closed()
|
||||
|
||||
return await self.execute_isolated_fn(
|
||||
lambda conn: analyze_sql_tables(conn, sql, params, database_name=self.name)
|
||||
)
|
||||
|
||||
async def execute_write_fn(self, fn, block=True, transaction=True, request=None):
|
||||
self._check_not_closed()
|
||||
pending_events = []
|
||||
|
|
|
|||
99
datasette/utils/sql_analysis.py
Normal file
99
datasette/utils/sql_analysis.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from datasette.utils.sqlite import sqlite3
|
||||
|
||||
SQLTableOperation = Literal["read", "insert", "update", "delete"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SQLTableAccess:
|
||||
operation: SQLTableOperation
|
||||
database: str | None
|
||||
table: str
|
||||
sqlite_schema: str | None
|
||||
columns: tuple[str, ...] = ()
|
||||
source: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SQLAnalysis:
|
||||
table_accesses: tuple[SQLTableAccess, ...]
|
||||
|
||||
|
||||
_ACTION_TO_OPERATION: dict[int, SQLTableOperation] = {
|
||||
sqlite3.SQLITE_READ: "read",
|
||||
sqlite3.SQLITE_INSERT: "insert",
|
||||
sqlite3.SQLITE_UPDATE: "update",
|
||||
sqlite3.SQLITE_DELETE: "delete",
|
||||
}
|
||||
|
||||
|
||||
def analyze_sql_tables(
|
||||
conn,
|
||||
sql: str,
|
||||
params=None,
|
||||
*,
|
||||
database_name: str | None = None,
|
||||
schema_to_database: dict[str, str] | None = None,
|
||||
) -> SQLAnalysis:
|
||||
"""
|
||||
Return tables accessed by a SQL statement according to SQLite's authorizer.
|
||||
|
||||
This function is synchronous and connection-based. It temporarily installs a
|
||||
SQLite authorizer, prepares ``EXPLAIN <sql>``, and returns the table access
|
||||
callbacks observed while SQLite compiles the statement.
|
||||
"""
|
||||
accesses: dict[
|
||||
tuple[SQLTableOperation, str | None, str, str | None, str | None], set[str]
|
||||
] = {}
|
||||
|
||||
def database_for_schema(sqlite_schema):
|
||||
if schema_to_database and sqlite_schema in schema_to_database:
|
||||
return schema_to_database[sqlite_schema]
|
||||
if sqlite_schema == "main" and database_name is not None:
|
||||
return database_name
|
||||
return sqlite_schema
|
||||
|
||||
def authorizer(action, arg1, arg2, sqlite_schema, source):
|
||||
operation = _ACTION_TO_OPERATION.get(action)
|
||||
if operation is None or arg1 is None:
|
||||
return sqlite3.SQLITE_OK
|
||||
|
||||
key = (
|
||||
operation,
|
||||
database_for_schema(sqlite_schema),
|
||||
arg1,
|
||||
sqlite_schema,
|
||||
source,
|
||||
)
|
||||
columns = accesses.setdefault(key, set())
|
||||
if operation in ("read", "update") and arg2 is not None:
|
||||
columns.add(arg2)
|
||||
return sqlite3.SQLITE_OK
|
||||
|
||||
conn.set_authorizer(authorizer)
|
||||
try:
|
||||
conn.execute("EXPLAIN " + sql, params if params is not None else {}).fetchall()
|
||||
finally:
|
||||
conn.set_authorizer(None)
|
||||
|
||||
return SQLAnalysis(
|
||||
table_accesses=tuple(
|
||||
SQLTableAccess(
|
||||
operation=operation,
|
||||
database=database,
|
||||
table=table,
|
||||
sqlite_schema=sqlite_schema,
|
||||
columns=tuple(sorted(columns)),
|
||||
source=source,
|
||||
)
|
||||
for (
|
||||
operation,
|
||||
database,
|
||||
table,
|
||||
sqlite_schema,
|
||||
source,
|
||||
), columns in accesses.items()
|
||||
)
|
||||
)
|
||||
|
|
@ -688,6 +688,54 @@ async def test_execute_isolated(db, disable_threads):
|
|||
assert not await db.execute_isolated_fn(table_exists_checker("created_by_isolated"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_sql():
|
||||
ds = Datasette(memory=True)
|
||||
db = ds.add_memory_database("test_analyze_sql", name="data")
|
||||
await db.execute_write("create table dogs (id integer primary key, name text)")
|
||||
|
||||
analysis = await db.analyze_sql("select name from dogs where id = ?", (1,))
|
||||
|
||||
assert [
|
||||
(
|
||||
access.operation,
|
||||
access.database,
|
||||
access.sqlite_schema,
|
||||
access.table,
|
||||
access.columns,
|
||||
access.source,
|
||||
)
|
||||
for access in analysis.table_accesses
|
||||
] == [
|
||||
("read", "data", "main", "dogs", ("id", "name"), None),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_sql_insert_select():
|
||||
ds = Datasette(memory=True)
|
||||
db = ds.add_memory_database("test_analyze_sql_insert_select", name="data")
|
||||
await db.execute_write("create table dogs (id integer primary key, name text)")
|
||||
await db.execute_write("create table cats (id integer primary key, name text)")
|
||||
|
||||
analysis = await db.analyze_sql("insert into dogs (name) select name from cats")
|
||||
|
||||
assert {
|
||||
(
|
||||
access.operation,
|
||||
access.database,
|
||||
access.sqlite_schema,
|
||||
access.table,
|
||||
access.columns,
|
||||
access.source,
|
||||
)
|
||||
for access in analysis.table_accesses
|
||||
} == {
|
||||
("insert", "data", "main", "dogs", (), None),
|
||||
("read", "data", "main", "cats", ("name",), None),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mtime_ns(db):
|
||||
assert isinstance(db.mtime_ns, int)
|
||||
|
|
|
|||
195
tests/test_utils_sql_analysis.py
Normal file
195
tests/test_utils_sql_analysis.py
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
import pytest
|
||||
|
||||
from datasette.utils.sqlite import sqlite3
|
||||
from datasette.utils.sql_analysis import analyze_sql_tables
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conn():
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.executescript("""
|
||||
create table dogs (id integer primary key, name text, age integer);
|
||||
create table cats (id integer primary key, name text);
|
||||
create table log (message text);
|
||||
create view dog_names as select id, name from dogs;
|
||||
create trigger dogs_after_insert after insert on dogs begin
|
||||
update cats set name = new.name where id = new.id;
|
||||
insert into log (message) values (new.name);
|
||||
end;
|
||||
create trigger dog_names_instead_of_update instead of update on dog_names begin
|
||||
update dogs set name = new.name where id = old.id;
|
||||
end;
|
||||
""")
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def as_tuples(analysis):
|
||||
return [
|
||||
(
|
||||
access.operation,
|
||||
access.database,
|
||||
access.sqlite_schema,
|
||||
access.table,
|
||||
access.columns,
|
||||
access.source,
|
||||
)
|
||||
for access in analysis.table_accesses
|
||||
]
|
||||
|
||||
|
||||
def test_analyze_select_tables(conn):
|
||||
analysis = analyze_sql_tables(
|
||||
conn,
|
||||
"select dogs.name, cats.name from dogs join cats on dogs.id = cats.id where dogs.age > ?",
|
||||
(2,),
|
||||
database_name="data",
|
||||
)
|
||||
|
||||
assert set(as_tuples(analysis)) == {
|
||||
("read", "data", "main", "cats", ("id", "name"), None),
|
||||
("read", "data", "main", "dogs", ("age", "id", "name"), None),
|
||||
}
|
||||
|
||||
|
||||
def test_analyze_uses_sqlite_schema_as_default_database(conn):
|
||||
analysis = analyze_sql_tables(conn, "select name from dogs")
|
||||
|
||||
assert set(as_tuples(analysis)) == {
|
||||
("read", "main", "main", "dogs", ("name",), None),
|
||||
}
|
||||
|
||||
|
||||
def test_analyze_insert_tables(conn):
|
||||
analysis = analyze_sql_tables(
|
||||
conn,
|
||||
"insert into dogs (name, age) values (:name, :age)",
|
||||
{"name": "Cleo", "age": 4},
|
||||
database_name="data",
|
||||
)
|
||||
|
||||
assert set(as_tuples(analysis)) == {
|
||||
("insert", "data", "main", "dogs", (), None),
|
||||
("read", "data", "main", "dogs", ("id", "name"), "dogs_after_insert"),
|
||||
("update", "data", "main", "cats", ("name",), "dogs_after_insert"),
|
||||
("read", "data", "main", "cats", ("id",), "dogs_after_insert"),
|
||||
("insert", "data", "main", "log", (), "dogs_after_insert"),
|
||||
}
|
||||
|
||||
|
||||
def test_analyze_update_tables(conn):
|
||||
analysis = analyze_sql_tables(
|
||||
conn,
|
||||
"update dogs set age = age + 1 where name = ?",
|
||||
("Cleo",),
|
||||
database_name="data",
|
||||
)
|
||||
|
||||
assert set(as_tuples(analysis)) == {
|
||||
("update", "data", "main", "dogs", ("age",), None),
|
||||
("read", "data", "main", "dogs", ("age", "name"), None),
|
||||
}
|
||||
|
||||
|
||||
def test_analyze_delete_tables(conn):
|
||||
analysis = analyze_sql_tables(
|
||||
conn,
|
||||
"delete from dogs where name = ?",
|
||||
("Cleo",),
|
||||
database_name="data",
|
||||
)
|
||||
|
||||
assert set(as_tuples(analysis)) == {
|
||||
("delete", "data", "main", "dogs", (), None),
|
||||
("read", "data", "main", "dogs", ("name",), None),
|
||||
}
|
||||
|
||||
|
||||
def test_analyze_insert_select_with_cte(conn):
|
||||
analysis = analyze_sql_tables(
|
||||
conn,
|
||||
"""
|
||||
with old_dogs as (
|
||||
select name from dogs where age > :age
|
||||
)
|
||||
insert into cats (name)
|
||||
select name from old_dogs
|
||||
""",
|
||||
{"age": 10},
|
||||
database_name="data",
|
||||
)
|
||||
|
||||
assert set(as_tuples(analysis)) == {
|
||||
("insert", "data", "main", "cats", (), None),
|
||||
("read", "data", "main", "dogs", ("age", "name"), "old_dogs"),
|
||||
}
|
||||
|
||||
|
||||
def test_analyze_view_with_instead_of_trigger(conn):
|
||||
analysis = analyze_sql_tables(
|
||||
conn,
|
||||
"update dog_names set name = :name where id = :id",
|
||||
{"name": "Zelda", "id": 1},
|
||||
database_name="data",
|
||||
)
|
||||
|
||||
assert set(as_tuples(analysis)) == {
|
||||
("update", "data", "main", "dog_names", ("name",), None),
|
||||
("read", "data", "main", "dogs", ("id", "name"), "dog_names"),
|
||||
("read", "data", "main", "dog_names", ("id", "name"), "dog_names"),
|
||||
(
|
||||
"read",
|
||||
"data",
|
||||
"main",
|
||||
"dog_names",
|
||||
("id", "name"),
|
||||
"dog_names_instead_of_update",
|
||||
),
|
||||
("update", "data", "main", "dogs", ("name",), "dog_names_instead_of_update"),
|
||||
("read", "data", "main", "dogs", ("id",), "dog_names_instead_of_update"),
|
||||
}
|
||||
|
||||
|
||||
def test_analyze_attached_database_tables(conn):
|
||||
conn.execute("attach database ':memory:' as extra")
|
||||
conn.execute("create table extra.people (id integer primary key, name text)")
|
||||
|
||||
analysis = analyze_sql_tables(
|
||||
conn,
|
||||
"insert into extra.people (name) select name from dogs",
|
||||
database_name="data",
|
||||
schema_to_database={"extra": "extra_db"},
|
||||
)
|
||||
|
||||
assert set(as_tuples(analysis)) == {
|
||||
("insert", "extra_db", "extra", "people", (), None),
|
||||
("read", "data", "main", "dogs", ("name",), None),
|
||||
}
|
||||
|
||||
|
||||
def test_analyze_invalid_sql_cleans_up_authorizer(conn):
|
||||
with pytest.raises(sqlite3.OperationalError):
|
||||
analyze_sql_tables(conn, "insert into missing_table values (1)")
|
||||
|
||||
conn.execute("select name from dogs").fetchall()
|
||||
|
||||
|
||||
def test_analyze_clears_authorizer_on_error():
|
||||
class FakeConnection:
|
||||
def __init__(self):
|
||||
self.authorizers = []
|
||||
|
||||
def set_authorizer(self, authorizer):
|
||||
self.authorizers.append(authorizer)
|
||||
|
||||
def execute(self, sql, params):
|
||||
raise sqlite3.OperationalError("bad SQL")
|
||||
|
||||
conn = FakeConnection()
|
||||
|
||||
with pytest.raises(sqlite3.OperationalError):
|
||||
analyze_sql_tables(conn, "bad SQL")
|
||||
|
||||
assert conn.authorizers[-1] is None
|
||||
Loading…
Add table
Add a link
Reference in a new issue