This commit is contained in:
Simon Willison 2021-08-16 11:36:53 -07:00
commit 91315e07a7

View file

@ -1093,44 +1093,59 @@ async def derive_named_parameters(db, sql):
def columns_for_query(conn, sql, params=None):
"""
Given a SQLite connection ``conn`` and a SQL query ``sql``,
returns a list of ``(table_name, column_name)`` pairs, one
per returned column. ``(None, None)`` if no table and column
could be derived.
Given a SQLite connection ``conn`` and a SQL query ``sql``, returns a list of
``(table_name, column_name)`` pairs corresponding to the columns that would be
returned by that SQL query.
Each pair indicates the source table and column for the returned column, or
``(None, None)`` if no table and column could be derived (e.g. for "select 1")
"""
if sql.lower().strip().startswith("explain"):
return []
rows = conn.execute("explain " + sql, params).fetchall()
opcodes = conn.execute("explain " + sql, params).fetchall()
table_rootpage_by_register = {
r["p1"]: r["p2"] for r in rows if r["opcode"] == "OpenRead"
r["p1"]: r["p2"] for r in opcodes if r["opcode"] == "OpenRead"
}
names_by_rootpage = dict(
conn.execute(
"select rootpage, name from sqlite_master where rootpage in ({})".format(
print(f"{table_rootpage_by_register=}")
names_and_types_by_rootpage = dict(
[(r[0], (r[1], r[2])) for r in conn.execute(
"select rootpage, name, type from sqlite_master where rootpage in ({})".format(
", ".join(map(str, table_rootpage_by_register.values()))
)
)
)]
)
print(f"{names_and_types_by_rootpage=}")
columns_by_column_register = {}
for row in rows:
if row["opcode"] in ("Rowid", "Column"):
addr, opcode, table_id, cid, column_register, p4, p5, comment = row
for opcode_row in opcodes:
if opcode_row["opcode"] in ("Rowid", "Column"):
addr, opcode, table_id, cid, column_register, p4, p5, comment = opcode_row
print(f"{table_id=} {cid=} {column_register=}")
table = None
try:
table = names_by_rootpage[table_rootpage_by_register[table_id]]
table = names_and_types_by_rootpage[table_rootpage_by_register[table_id]][0]
columns_by_column_register[column_register] = (table, cid)
except KeyError:
except KeyError as e:
print(" KeyError")
print(" ", e)
print(" table = names_and_types_by_rootpage[table_rootpage_by_register[table_id]][0]")
print(f" {names_and_types_by_rootpage=} {table_rootpage_by_register=} {table_id=}")
print(" columns_by_column_register[column_register] = (table, cid)")
print(f" {column_register=} = ({table=}, {cid=})")
pass
result_row = [dict(r) for r in rows if r["opcode"] == "ResultRow"][0]
registers = list(range(result_row["p1"], result_row["p1"] + result_row["p2"]))
result_row = [dict(r) for r in opcodes if r["opcode"] == "ResultRow"][0]
result_registers = list(range(result_row["p1"], result_row["p1"] + result_row["p2"]))
print(f"{result_registers=}")
print(f"{columns_by_column_register=}")
all_column_names = {}
for table in names_by_rootpage.values():
for (table, _) in names_and_types_by_rootpage.values():
table_xinfo = conn.execute("pragma table_xinfo({})".format(table)).fetchall()
for row in table_xinfo:
all_column_names[(table, row["cid"])] = row["name"]
for column_info in table_xinfo:
all_column_names[(table, column_info["cid"])] = column_info["name"]
print(f"{all_column_names=}")
final_output = []
for r in registers:
for register in result_registers:
try:
table, cid = columns_by_column_register[r]
table, cid = columns_by_column_register[register]
final_output.append((table, all_column_names[table, cid]))
except KeyError:
final_output.append((None, None))