Fix tests I just broke

This commit is contained in:
Simon Willison 2026-05-28 09:03:10 -07:00
commit 2785fd29de
3 changed files with 65 additions and 37 deletions

View file

@ -193,6 +193,10 @@ _AUTHORIZER_ACTION_NAMES = {
}
def _allow_authorizer_action(*args):
return sqlite3.SQLITE_OK
def analyze_sql_tables(
conn,
sql: str,
@ -424,42 +428,59 @@ def analyze_sql_tables(
)
return sqlite3.SQLITE_OK
table_kind_cache: dict[tuple[str | None, str], SQLTableKind | None] = {}
conn.set_authorizer(authorizer)
try:
explain_rows = conn.execute(
"EXPLAIN " + sql, params if params is not None else {}
).fetchall()
# Passing None before these lookups leaves a failing callback installed
# on Python 3.10, so use a permissive callback until they are complete.
conn.set_authorizer(_allow_authorizer_action)
if not operations:
vacuum_row = next((row for row in explain_rows if row[1] == "Vacuum"), None)
if vacuum_row is not None:
schema_by_index = {
row[0]: row[1] for row in conn.execute("PRAGMA database_list")
}
sqlite_schema = schema_by_index.get(vacuum_row[2])
database = database_for_schema(sqlite_schema)
record(
"vacuum",
"database",
database=database,
table=None,
sqlite_schema=sqlite_schema,
target=database,
source=None,
)
else:
record(
"unknown",
"statement",
database=database_name,
table=None,
sqlite_schema=None,
target=None,
source=None,
)
for key in operations:
if (
key.target_type == "table"
and key.operation in {"read", "insert", "update", "delete"}
and key.table is not None
):
cache_key = (key.sqlite_schema, key.table)
if cache_key not in table_kind_cache:
table_kind_cache[cache_key] = sqlite_table_type(
conn, key.table, schema=key.sqlite_schema
)
finally:
conn.set_authorizer(None)
if not operations:
vacuum_row = next((row for row in explain_rows if row[1] == "Vacuum"), None)
if vacuum_row is not None:
schema_by_index = {
row[0]: row[1] for row in conn.execute("PRAGMA database_list")
}
sqlite_schema = schema_by_index.get(vacuum_row[2])
database = database_for_schema(sqlite_schema)
record(
"vacuum",
"database",
database=database,
table=None,
sqlite_schema=sqlite_schema,
target=database,
source=None,
)
else:
record(
"unknown",
"statement",
database=database_name,
table=None,
sqlite_schema=None,
target=None,
source=None,
)
has_schema_operation = any(
key.target_type in {"table", "index", "view", "trigger", "virtual-table"}
and key.operation in {"create", "alter", "drop"}
@ -502,8 +523,6 @@ def analyze_sql_tables(
return True
return False
table_kind_cache: dict[tuple[str | None, str], SQLTableKind | None] = {}
def table_kind_for(key: OperationKey) -> SQLTableKind | None:
if (
key.target_type != "table"
@ -511,12 +530,7 @@ def analyze_sql_tables(
or key.table is None
):
return None
cache_key = (key.sqlite_schema, key.table)
if cache_key not in table_kind_cache:
table_kind_cache[cache_key] = sqlite_table_type(
conn, key.table, schema=key.sqlite_schema
)
return table_kind_cache[cache_key]
return table_kind_cache[(key.sqlite_schema, key.table)]
return SQLAnalysis(
operations=tuple(