From 1db116e20eda43c95d3c60a82548e355862f7212 Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Thu, 12 Aug 2021 17:47:40 -0700
Subject: [PATCH] WIP extra query column information, refs #1293
---
datasette/templates/query.html | 2 ++
datasette/utils/__init__.py | 41 ++++++++++++++++++++++++++++++++++
datasette/views/database.py | 10 +++++++++
3 files changed, 53 insertions(+)
diff --git a/datasette/templates/query.html b/datasette/templates/query.html
index 75f7f1b1..9fe1d4f5 100644
--- a/datasette/templates/query.html
+++ b/datasette/templates/query.html
@@ -67,6 +67,8 @@
+extra_column_info: {{ extra_column_info }}
+
{% if display_rows %}
This data as {% for name, url in renderers.items() %}{{ name }}{{ ", " if not loop.last }}{% endfor %}, CSV
diff --git a/datasette/utils/__init__.py b/datasette/utils/__init__.py
index 70ac8976..69c72566 100644
--- a/datasette/utils/__init__.py
+++ b/datasette/utils/__init__.py
@@ -1089,3 +1089,44 @@ async def derive_named_parameters(db, sql):
return [row["p4"].lstrip(":") for row in results if row["opcode"] == "Variable"]
except sqlite3.DatabaseError:
return possible_params
+
+
+def columns_for_query(conn, sql, params=None):
+    """
+    Given a SQLite connection ``conn`` and a SQL query ``sql``,
+    returns a list of ``(table_name, column_name)`` pairs, one
+    per returned column. ``(None, None)`` if no table and column
+    could be derived.
+
+    ``params`` are the bind parameters for ``sql`` (a sequence or
+    dict, ``None`` meaning no parameters). The mapping is a best-effort
+    reading of the ``explain`` bytecode: a result register that was not
+    filled directly from a table cursor (expressions, sorter output,
+    index-only reads) is reported as ``(None, None)``. A rowid read is
+    reported under the table's ``INTEGER PRIMARY KEY`` column when it
+    has one, otherwise as ``rowid``. Returns ``[]`` if the bytecode
+    contains no ``ResultRow`` opcode.
+    """
+    # Explain rows are (addr, opcode, p1, p2, p3, p4, p5, comment). Reading
+    # them by position works for plain tuple rows as well as sqlite3.Row.
+    # sqlite3 rejects None as parameters, so fall back to an empty sequence.
+    rows = conn.execute("explain " + sql, params or ()).fetchall()
+    # OpenRead: p1 is the cursor number, p2 the root page being opened
+    rootpage_by_cursor = {r[2]: r[3] for r in rows if r[1] == "OpenRead"}
+    rootpages = sorted(set(rootpage_by_cursor.values()))
+    names_by_rootpage = {}
+    if rootpages:
+        names_by_rootpage = {
+            rootpage: name
+            for rootpage, name in conn.execute(
+                "select rootpage, name from sqlite_master "
+                "where type = 'table' and rootpage in ({})".format(
+                    ", ".join("?" * len(rootpages))
+                ),
+                rootpages,
+            )
+        }
+    # Map each destination register to (table, cid); cid None marks the rowid.
+    # Column: p1 cursor, p2 column index, p3 destination register
+    # Rowid:  p1 cursor, p2 destination register
+    source_by_register = {}
+    for row in rows:
+        opcode, cursor = row[1], row[2]
+        if opcode not in ("Column", "Rowid"):
+            continue
+        table = names_by_rootpage.get(rootpage_by_cursor.get(cursor))
+        if table is None:
+            # Cursor is not a plain table (sorter, pseudo-table, index...)
+            continue
+        if opcode == "Column":
+            source_by_register[row[4]] = (table, row[3])
+        else:
+            source_by_register[row[3]] = (table, None)
+    # Only the first ResultRow is inspected: p1 first register, p2 count
+    result_rows = [r for r in rows if r[1] == "ResultRow"]
+    if not result_rows:
+        return []
+    first, count = result_rows[0][2], result_rows[0][3]
+    column_names = {}
+    for table in set(names_by_rootpage.values()):
+        # Quote the identifier so names with spaces or quotes are safe
+        quoted = '"{}"'.format(table.replace('"', '""'))
+        # table_xinfo rows: (cid, name, type, notnull, dflt_value, pk, hidden)
+        xinfo = conn.execute("pragma table_xinfo({})".format(quoted)).fetchall()
+        for info in xinfo:
+            column_names[(table, info[0])] = info[1]
+        pks = [info for info in xinfo if info[5]]
+        if len(pks) == 1 and (pks[0][2] or "").upper() == "INTEGER":
+            # A lone INTEGER PRIMARY KEY column is an alias for the rowid
+            column_names[(table, None)] = pks[0][1]
+        else:
+            column_names[(table, None)] = "rowid"
+    final_output = []
+    for register in range(first, first + count):
+        table, cid = source_by_register.get(register, (None, None))
+        name = column_names.get((table, cid))
+        final_output.append((table, name) if name is not None else (None, None))
+    return final_output
diff --git a/datasette/views/database.py b/datasette/views/database.py
index 7c36034c..7b1f1923 100644
--- a/datasette/views/database.py
+++ b/datasette/views/database.py
@@ -10,6 +10,7 @@ import markupsafe
from datasette.utils import (
await_me_maybe,
check_visibility,
+ columns_for_query,
derive_named_parameters,
to_css_class,
validate_sql_select,
@@ -248,6 +249,8 @@ class QueryView(DataView):
query_error = None
+ extra_column_info = None
+
# Execute query - as write or as read
if write:
if request.method == "POST":
@@ -334,6 +337,12 @@ class QueryView(DataView):
database, sql, params_for_query, truncate=True, **extra_args
)
columns = [r[0] for r in results.description]
+
+ # Try to figure out extra column information
+ db = self.ds.get_database(database)
+ extra_column_info = await db.execute_fn(
+ lambda conn: columns_for_query(conn, sql, params_for_query)
+ )
except sqlite3.DatabaseError as e:
query_error = e
results = None
@@ -462,6 +471,7 @@ class QueryView(DataView):
"show_hide_text": show_hide_text,
"show_hide_hidden": markupsafe.Markup(show_hide_hidden),
"hide_sql": hide_sql,
+ "extra_column_info": extra_column_info,
}
return (