Compare commits

...

3 commits

Author SHA1 Message Date
Simon Willison
dcc5cd425a Moved expand_foreign_keys method to Database
Refs #620

Also fixed a few places that were calling ds.execute() instead of
db.execute()
2019-11-17 22:33:52 -08:00
Simon Willison
eb61845141 Format with black 2019-11-17 22:24:55 -08:00
Simon Willison
c2779e5af0 Experimental WIP for #620
Example URL: /fixtures/facetable?_fk.on_earth=facet_cities.id
2019-11-16 23:10:59 -08:00
4 changed files with 82 additions and 57 deletions

View file

@ -352,43 +352,6 @@ class Datasette:
log_sql_errors=log_sql_errors, log_sql_errors=log_sql_errors,
) )
async def expand_foreign_keys(self, database, table, column, values):
"Returns dict mapping (column, value) -> label"
labeled_fks = {}
db = self.databases[database]
foreign_keys = await db.foreign_keys_for_table(table)
# Find the foreign_key for this column
try:
fk = [
foreign_key
for foreign_key in foreign_keys
if foreign_key["column"] == column
][0]
except IndexError:
return {}
label_column = await db.label_column_for_table(fk["other_table"])
if not label_column:
return {(fk["column"], value): str(value) for value in values}
labeled_fks = {}
sql = """
select {other_column}, {label_column}
from {other_table}
where {other_column} in ({placeholders})
""".format(
other_column=escape_sqlite(fk["other_column"]),
label_column=escape_sqlite(label_column),
other_table=escape_sqlite(fk["other_table"]),
placeholders=", ".join(["?"] * len(set(values))),
)
try:
results = await self.execute(database, sql, list(set(values)))
except QueryInterrupted:
pass
else:
for id, value in results:
labeled_fks[(fk["column"], id)] = value
return labeled_fks
def absolute_url(self, request, path): def absolute_url(self, request, path):
url = urllib.parse.urljoin(request.url, path) url = urllib.parse.urljoin(request.url, path)
if url.startswith("http://") and self.config("force_https_urls"): if url.startswith("http://") and self.config("force_https_urls"):

View file

@ -10,6 +10,7 @@ from .utils import (
detect_fts, detect_fts,
detect_primary_keys, detect_primary_keys,
detect_spatialite, detect_spatialite,
escape_sqlite,
get_all_foreign_keys, get_all_foreign_keys,
get_outbound_foreign_keys, get_outbound_foreign_keys,
sqlite_timelimit, sqlite_timelimit,
@ -217,6 +218,42 @@ class Database:
lambda conn: get_outbound_foreign_keys(conn, table) lambda conn: get_outbound_foreign_keys(conn, table)
) )
async def expand_foreign_keys(self, table, column, values, fks=None):
"Returns dict mapping (column, value) -> label"
labeled_fks = {}
foreign_keys = fks or await self.foreign_keys_for_table(table)
# Find the foreign_key for this column
try:
fk = [
foreign_key
for foreign_key in foreign_keys
if foreign_key["column"] == column
][0]
except IndexError:
return {}
label_column = await self.label_column_for_table(fk["other_table"])
if not label_column:
return {(fk["column"], value): str(value) for value in values}
labeled_fks = {}
sql = """
select {other_column}, {label_column}
from {other_table}
where {other_column} in ({placeholders})
""".format(
other_column=escape_sqlite(fk["other_column"]),
label_column=escape_sqlite(label_column),
other_table=escape_sqlite(fk["other_table"]),
placeholders=", ".join(["?"] * len(set(values))),
)
try:
results = await self.execute(sql, list(set(values)))
except QueryInterrupted:
pass
else:
for id, value in results:
labeled_fks[(fk["column"], id)] = value
return labeled_fks
async def hidden_table_names(self): async def hidden_table_names(self):
# Mark tables 'hidden' if they relate to FTS virtual tables # Mark tables 'hidden' if they relate to FTS virtual tables
hidden_tables = [ hidden_tables = [

View file

@ -139,6 +139,7 @@ class ColumnFacet(Facet):
facet_size = self.ds.config("default_facet_size") facet_size = self.ds.config("default_facet_size")
suggested_facets = [] suggested_facets = []
already_enabled = [c["config"]["simple"] for c in self.get_configs()] already_enabled = [c["config"]["simple"] for c in self.get_configs()]
database = self.ds.databases[self.database]
for column in columns: for column in columns:
if column in already_enabled: if column in already_enabled:
continue continue
@ -152,8 +153,7 @@ class ColumnFacet(Facet):
) )
distinct_values = None distinct_values = None
try: try:
distinct_values = await self.ds.execute( distinct_values = await database.execute(
self.database,
suggested_facet_sql, suggested_facet_sql,
self.params, self.params,
truncate=False, truncate=False,
@ -182,6 +182,7 @@ class ColumnFacet(Facet):
async def facet_results(self): async def facet_results(self):
facet_results = {} facet_results = {}
facets_timed_out = [] facets_timed_out = []
database = self.ds.databases[self.database]
qs_pairs = self.get_querystring_pairs() qs_pairs = self.get_querystring_pairs()
@ -200,8 +201,7 @@ class ColumnFacet(Facet):
col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1 col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1
) )
try: try:
facet_rows_results = await self.ds.execute( facet_rows_results = await database.execute(
self.database,
facet_sql, facet_sql,
self.params, self.params,
truncate=False, truncate=False,
@ -222,8 +222,8 @@ class ColumnFacet(Facet):
if self.table: if self.table:
# Attempt to expand foreign keys into labels # Attempt to expand foreign keys into labels
values = [row["value"] for row in facet_rows] values = [row["value"] for row in facet_rows]
expanded = await self.ds.expand_foreign_keys( expanded = await database.expand_foreign_keys(
self.database, self.table, column, values self.table, column, values
) )
else: else:
expanded = {} expanded = {}

View file

@ -74,17 +74,42 @@ class RowTableShared(DataView):
sortable_columns.add("rowid") sortable_columns.add("rowid")
return sortable_columns return sortable_columns
async def expandable_columns(self, database, table): async def foreign_keys_for_table(self, request, database, table):
db = self.ds.databases[database]
fks = await db.foreign_keys_for_table(table)
# Handle ?_fk.article_id=articles.id querystring arguments
for key, value in request.args.items():
if key.startswith("_fk."):
value = value[0]
column = key.split("_fk.", 1)[1]
other_table, other_column = value.split(".", 1)
# {'other_table': '...', 'column': '...', 'other_column': 'id'}
expandable_fk = {
"other_table": other_table,
"column": column,
"other_column": other_column,
}
fks.append(expandable_fk)
return fks
async def expandable_columns(self, request, database, table):
# Returns list of (fk_dict, label_column-or-None) pairs for that table # Returns list of (fk_dict, label_column-or-None) pairs for that table
expandables = [] expandables = []
db = self.ds.databases[database] db = self.ds.databases[database]
for fk in await db.foreign_keys_for_table(table): for fk in await self.foreign_keys_for_table(request, database, table):
label_column = await db.label_column_for_table(fk["other_table"]) label_column = await db.label_column_for_table(fk["other_table"])
expandables.append((fk, label_column)) expandables.append((fk, label_column))
return expandables return expandables
async def display_columns_and_rows( async def display_columns_and_rows(
self, database, table, description, rows, link_column=False, truncate_cells=0 self,
request,
database,
table,
description,
rows,
link_column=False,
truncate_cells=0,
): ):
"Returns columns, rows for specified table - including fancy foreign key treatment" "Returns columns, rows for specified table - including fancy foreign key treatment"
db = self.ds.databases[database] db = self.ds.databases[database]
@ -96,7 +121,7 @@ class RowTableShared(DataView):
pks = await db.primary_keys(table) pks = await db.primary_keys(table)
column_to_foreign_key_table = { column_to_foreign_key_table = {
fk["column"]: fk["other_table"] fk["column"]: fk["other_table"]
for fk in await db.foreign_keys_for_table(table) for fk in await self.foreign_keys_for_table(request, database, table)
} }
cell_rows = [] cell_rows = []
@ -533,17 +558,13 @@ class TableView(RowTableShared):
if request.raw_args.get("_timelimit"): if request.raw_args.get("_timelimit"):
extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"]) extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"])
results = await self.ds.execute( results = await db.execute(sql, params, truncate=True, **extra_args)
database, sql, params, truncate=True, **extra_args
)
# Number of filtered rows in whole set: # Number of filtered rows in whole set:
filtered_table_rows_count = None filtered_table_rows_count = None
if count_sql: if count_sql:
try: try:
count_rows = list( count_rows = list(await db.execute(count_sql, from_sql_params))
await self.ds.execute(database, count_sql, from_sql_params)
)
filtered_table_rows_count = count_rows[0][0] filtered_table_rows_count = count_rows[0][0]
except QueryInterrupted: except QueryInterrupted:
pass pass
@ -593,7 +614,7 @@ class TableView(RowTableShared):
# Expand labeled columns if requested # Expand labeled columns if requested
expanded_columns = [] expanded_columns = []
expandable_columns = await self.expandable_columns(database, table) expandable_columns = await self.expandable_columns(request, database, table)
columns_to_expand = None columns_to_expand = None
try: try:
all_labels = value_as_boolean(special_args.get("_labels", "")) all_labels = value_as_boolean(special_args.get("_labels", ""))
@ -618,7 +639,9 @@ class TableView(RowTableShared):
values = [row[column_index] for row in rows] values = [row[column_index] for row in rows]
# Expand them # Expand them
expanded_labels.update( expanded_labels.update(
await self.ds.expand_foreign_keys(database, table, column, values) await db.expand_foreign_keys(
table, column, values, fks=[p[0] for p in expandable_columns],
)
) )
if expanded_labels: if expanded_labels:
# Rewrite the rows # Rewrite the rows
@ -693,6 +716,7 @@ class TableView(RowTableShared):
async def extra_template(): async def extra_template():
display_columns, display_rows = await self.display_columns_and_rows( display_columns, display_rows = await self.display_columns_and_rows(
request,
database, database,
table, table,
results.description, results.description,
@ -799,7 +823,7 @@ class RowView(RowTableShared):
params = {} params = {}
for i, pk_value in enumerate(pk_values): for i, pk_value in enumerate(pk_values):
params["p{}".format(i)] = pk_value params["p{}".format(i)] = pk_value
results = await self.ds.execute(database, sql, params, truncate=True) results = await db.execute(sql, params, truncate=True)
columns = [r[0] for r in results.description] columns = [r[0] for r in results.description]
rows = list(results.rows) rows = list(results.rows)
if not rows: if not rows:
@ -807,6 +831,7 @@ class RowView(RowTableShared):
async def template_data(): async def template_data():
display_columns, display_rows = await self.display_columns_and_rows( display_columns, display_rows = await self.display_columns_and_rows(
request,
database, database,
table, table,
results.description, results.description,
@ -880,7 +905,7 @@ class RowView(RowTableShared):
] ]
) )
try: try:
rows = list(await self.ds.execute(database, sql, {"id": pk_values[0]})) rows = list(await db.execute(sql, {"id": pk_values[0]}))
except sqlite3.OperationalError: except sqlite3.OperationalError:
# Almost certainly hit the timeout # Almost certainly hit the timeout
return [] return []