Refactor to remove RowTableShared class, closes #1719

Refs #1715
This commit is contained in:
Simon Willison 2022-04-25 11:33:35 -07:00
commit 579f59dcec
3 changed files with 329 additions and 315 deletions

View file

@ -40,7 +40,8 @@ from .views.special import (
PermissionsDebugView,
MessagesDebugView,
)
from .views.table import RowView, TableView
from .views.table import TableView
from .views.row import RowView
from .renderer import json_renderer
from .url_builder import Urls
from .database import Database, QueryInterrupted

142
datasette/views/row.py Normal file
View file

@ -0,0 +1,142 @@
from datasette.utils.asgi import NotFound
from datasette.database import QueryInterrupted
from .base import DataView
from datasette.utils import (
tilde_decode,
urlsafe_components,
to_css_class,
escape_sqlite,
)
from .table import _sql_params_pks, display_columns_and_rows
class RowView(DataView):
name = "row"
async def data(self, request, default_labels=False):
database_route = tilde_decode(request.url_vars["database"])
table = tilde_decode(request.url_vars["table"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
database = db.name
await self.ds.ensure_permissions(
request.actor,
[
("view-table", (database, table)),
("view-database", database),
"view-instance",
],
)
pk_values = urlsafe_components(request.url_vars["pks"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
database = db.name
sql, params, pks = await _sql_params_pks(db, table, pk_values)
results = await db.execute(sql, params, truncate=True)
columns = [r[0] for r in results.description]
rows = list(results.rows)
if not rows:
raise NotFound(f"Record not found: {pk_values}")
async def template_data():
display_columns, display_rows = await display_columns_and_rows(
self.ds,
database,
table,
results.description,
rows,
link_column=False,
truncate_cells=0,
)
for column in display_columns:
column["sortable"] = False
return {
"foreign_key_tables": await self.foreign_key_tables(
database, table, pk_values
),
"display_columns": display_columns,
"display_rows": display_rows,
"custom_table_templates": [
f"_table-{to_css_class(database)}-{to_css_class(table)}.html",
f"_table-row-{to_css_class(database)}-{to_css_class(table)}.html",
"_table.html",
],
"metadata": (self.ds.metadata("databases") or {})
.get(database, {})
.get("tables", {})
.get(table, {}),
}
data = {
"database": database,
"table": table,
"rows": rows,
"columns": columns,
"primary_keys": pks,
"primary_key_values": pk_values,
"units": self.ds.table_metadata(database, table).get("units", {}),
}
if "foreign_key_tables" in (request.args.get("_extras") or "").split(","):
data["foreign_key_tables"] = await self.foreign_key_tables(
database, table, pk_values
)
return (
data,
template_data,
(
f"row-{to_css_class(database)}-{to_css_class(table)}.html",
"row.html",
),
)
async def foreign_key_tables(self, database, table, pk_values):
if len(pk_values) != 1:
return []
db = self.ds.databases[database]
all_foreign_keys = await db.get_all_foreign_keys()
foreign_keys = all_foreign_keys[table]["incoming"]
if len(foreign_keys) == 0:
return []
sql = "select " + ", ".join(
[
"(select count(*) from {table} where {column}=:id)".format(
table=escape_sqlite(fk["other_table"]),
column=escape_sqlite(fk["other_column"]),
)
for fk in foreign_keys
]
)
try:
rows = list(await db.execute(sql, {"id": pk_values[0]}))
except QueryInterrupted:
# Almost certainly hit the timeout
return []
foreign_table_counts = dict(
zip(
[(fk["other_table"], fk["other_column"]) for fk in foreign_keys],
list(rows[0]),
)
)
foreign_key_tables = []
for fk in foreign_keys:
count = (
foreign_table_counts.get((fk["other_table"], fk["other_column"])) or 0
)
key = fk["other_column"]
if key.startswith("_"):
key += "__exact"
link = "{}?{}={}".format(
self.ds.urls.table(database, fk["other_table"]),
key,
",".join(pk_values),
)
foreign_key_tables.append({**fk, **{"count": count, "link": link}})
return foreign_key_tables

View file

@ -1,4 +1,3 @@
import urllib
import itertools
import json
@ -9,7 +8,6 @@ from datasette.database import QueryInterrupted
from datasette.utils import (
await_me_maybe,
CustomRow,
MultiParams,
append_querystring,
compound_keys_after_sql,
format_bytes,
@ -21,7 +19,6 @@ from datasette.utils import (
is_url,
path_from_row_pks,
path_with_added_args,
path_with_format,
path_with_removed_args,
path_with_replaced_args,
to_css_class,
@ -68,7 +65,9 @@ class Row:
return json.dumps(d, default=repr, indent=2)
class RowTableShared(DataView):
class TableView(DataView):
name = "table"
async def sortable_columns_for_table(self, database, table, use_rowid):
db = self.ds.databases[database]
table_metadata = self.ds.table_metadata(database, table)
@ -89,193 +88,6 @@ class RowTableShared(DataView):
expandables.append((fk, label_column))
return expandables
async def display_columns_and_rows(
self, database, table, description, rows, link_column=False, truncate_cells=0
):
"""Returns columns, rows for specified table - including fancy foreign key treatment"""
db = self.ds.databases[database]
table_metadata = self.ds.table_metadata(database, table)
column_descriptions = table_metadata.get("columns") or {}
column_details = {col.name: col for col in await db.table_column_details(table)}
sortable_columns = await self.sortable_columns_for_table(database, table, True)
pks = await db.primary_keys(table)
pks_for_display = pks
if not pks_for_display:
pks_for_display = ["rowid"]
columns = []
for r in description:
if r[0] == "rowid" and "rowid" not in column_details:
type_ = "integer"
notnull = 0
else:
type_ = column_details[r[0]].type
notnull = column_details[r[0]].notnull
columns.append(
{
"name": r[0],
"sortable": r[0] in sortable_columns,
"is_pk": r[0] in pks_for_display,
"type": type_,
"notnull": notnull,
"description": column_descriptions.get(r[0]),
}
)
column_to_foreign_key_table = {
fk["column"]: fk["other_table"]
for fk in await db.foreign_keys_for_table(table)
}
cell_rows = []
base_url = self.ds.setting("base_url")
for row in rows:
cells = []
# Unless we are a view, the first column is a link - either to the rowid
# or to the simple or compound primary key
if link_column:
is_special_link_column = len(pks) != 1
pk_path = path_from_row_pks(row, pks, not pks, False)
cells.append(
{
"column": pks[0] if len(pks) == 1 else "Link",
"value_type": "pk",
"is_special_link_column": is_special_link_column,
"raw": pk_path,
"value": markupsafe.Markup(
'<a href="{table_path}/{flat_pks_quoted}">{flat_pks}</a>'.format(
base_url=base_url,
table_path=self.ds.urls.table(database, table),
flat_pks=str(markupsafe.escape(pk_path)),
flat_pks_quoted=path_from_row_pks(row, pks, not pks),
)
),
}
)
for value, column_dict in zip(row, columns):
column = column_dict["name"]
if link_column and len(pks) == 1 and column == pks[0]:
# If there's a simple primary key, don't repeat the value as it's
# already shown in the link column.
continue
# First let the plugins have a go
# pylint: disable=no-member
plugin_display_value = None
for candidate in pm.hook.render_cell(
value=value,
column=column,
table=table,
database=database,
datasette=self.ds,
):
candidate = await await_me_maybe(candidate)
if candidate is not None:
plugin_display_value = candidate
break
if plugin_display_value:
display_value = plugin_display_value
elif isinstance(value, bytes):
formatted = format_bytes(len(value))
display_value = markupsafe.Markup(
'<a class="blob-download" href="{}"{}>&lt;Binary:&nbsp;{:,}&nbsp;byte{}&gt;</a>'.format(
self.ds.urls.row_blob(
database,
table,
path_from_row_pks(row, pks, not pks),
column,
),
' title="{}"'.format(formatted)
if "bytes" not in formatted
else "",
len(value),
"" if len(value) == 1 else "s",
)
)
elif isinstance(value, dict):
# It's an expanded foreign key - display link to other row
label = value["label"]
value = value["value"]
# The table we link to depends on the column
other_table = column_to_foreign_key_table[column]
link_template = (
LINK_WITH_LABEL if (label != value) else LINK_WITH_VALUE
)
display_value = markupsafe.Markup(
link_template.format(
database=database,
base_url=base_url,
table=tilde_encode(other_table),
link_id=tilde_encode(str(value)),
id=str(markupsafe.escape(value)),
label=str(markupsafe.escape(label)) or "-",
)
)
elif value in ("", None):
display_value = markupsafe.Markup("&nbsp;")
elif is_url(str(value).strip()):
display_value = markupsafe.Markup(
'<a href="{url}">{url}</a>'.format(
url=markupsafe.escape(value.strip())
)
)
elif column in table_metadata.get("units", {}) and value != "":
# Interpret units using pint
value = value * ureg(table_metadata["units"][column])
# Pint uses floating point which sometimes introduces errors in the compact
# representation, which we have to round off to avoid ugliness. In the vast
# majority of cases this rounding will be inconsequential. I hope.
value = round(value.to_compact(), 6)
display_value = markupsafe.Markup(
f"{value:~P}".replace(" ", "&nbsp;")
)
else:
display_value = str(value)
if truncate_cells and len(display_value) > truncate_cells:
display_value = display_value[:truncate_cells] + "\u2026"
cells.append(
{
"column": column,
"value": display_value,
"raw": value,
"value_type": "none"
if value is None
else str(type(value).__name__),
}
)
cell_rows.append(Row(cells))
if link_column:
# Add the link column header.
# If it's a simple primary key, we have to remove and re-add that column name at
# the beginning of the header row.
first_column = None
if len(pks) == 1:
columns = [col for col in columns if col["name"] != pks[0]]
first_column = {
"name": pks[0],
"sortable": len(pks) == 1,
"is_pk": True,
"type": column_details[pks[0]].type,
"notnull": column_details[pks[0]].notnull,
}
else:
first_column = {
"name": "Link",
"sortable": False,
"is_pk": False,
"type": "",
"notnull": 0,
}
columns = [first_column] + columns
return columns, cell_rows
class TableView(RowTableShared):
name = "table"
async def post(self, request):
database_route = tilde_decode(request.url_vars["database"])
try:
@ -807,13 +619,17 @@ class TableView(RowTableShared):
async def extra_template():
nonlocal sort
display_columns, display_rows = await self.display_columns_and_rows(
display_columns, display_rows = await display_columns_and_rows(
self.ds,
database,
table,
results.description,
rows,
link_column=not is_view,
truncate_cells=self.ds.setting("truncate_cells_html"),
sortable_columns=await self.sortable_columns_for_table(
database, table, use_rowid=True
),
)
metadata = (
(self.ds.metadata("databases") or {})
@ -948,132 +764,187 @@ async def _sql_params_pks(db, table, pk_values):
return sql, params, pks
class RowView(RowTableShared):
name = "row"
async def display_columns_and_rows(
datasette,
database,
table,
description,
rows,
link_column=False,
truncate_cells=0,
sortable_columns=None,
):
"""Returns columns, rows for specified table - including fancy foreign key treatment"""
sortable_columns = sortable_columns or set()
db = datasette.databases[database]
table_metadata = datasette.table_metadata(database, table)
column_descriptions = table_metadata.get("columns") or {}
column_details = {col.name: col for col in await db.table_column_details(table)}
pks = await db.primary_keys(table)
pks_for_display = pks
if not pks_for_display:
pks_for_display = ["rowid"]
async def data(self, request, default_labels=False):
database_route = tilde_decode(request.url_vars["database"])
table = tilde_decode(request.url_vars["table"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
database = db.name
await self.ds.ensure_permissions(
request.actor,
[
("view-table", (database, table)),
("view-database", database),
"view-instance",
],
)
pk_values = urlsafe_components(request.url_vars["pks"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
database = db.name
sql, params, pks = await _sql_params_pks(db, table, pk_values)
results = await db.execute(sql, params, truncate=True)
columns = [r[0] for r in results.description]
rows = list(results.rows)
if not rows:
raise NotFound(f"Record not found: {pk_values}")
async def template_data():
display_columns, display_rows = await self.display_columns_and_rows(
database,
table,
results.description,
rows,
link_column=False,
truncate_cells=0,
)
for column in display_columns:
column["sortable"] = False
return {
"foreign_key_tables": await self.foreign_key_tables(
database, table, pk_values
),
"display_columns": display_columns,
"display_rows": display_rows,
"custom_table_templates": [
f"_table-{to_css_class(database)}-{to_css_class(table)}.html",
f"_table-row-{to_css_class(database)}-{to_css_class(table)}.html",
"_table.html",
],
"metadata": (self.ds.metadata("databases") or {})
.get(database, {})
.get("tables", {})
.get(table, {}),
columns = []
for r in description:
if r[0] == "rowid" and "rowid" not in column_details:
type_ = "integer"
notnull = 0
else:
type_ = column_details[r[0]].type
notnull = column_details[r[0]].notnull
columns.append(
{
"name": r[0],
"sortable": r[0] in sortable_columns,
"is_pk": r[0] in pks_for_display,
"type": type_,
"notnull": notnull,
"description": column_descriptions.get(r[0]),
}
data = {
"database": database,
"table": table,
"rows": rows,
"columns": columns,
"primary_keys": pks,
"primary_key_values": pk_values,
"units": self.ds.table_metadata(database, table).get("units", {}),
}
if "foreign_key_tables" in (request.args.get("_extras") or "").split(","):
data["foreign_key_tables"] = await self.foreign_key_tables(
database, table, pk_values
)
return (
data,
template_data,
(
f"row-{to_css_class(database)}-{to_css_class(table)}.html",
"row.html",
),
)
async def foreign_key_tables(self, database, table, pk_values):
if len(pk_values) != 1:
return []
db = self.ds.databases[database]
all_foreign_keys = await db.get_all_foreign_keys()
foreign_keys = all_foreign_keys[table]["incoming"]
if len(foreign_keys) == 0:
return []
column_to_foreign_key_table = {
fk["column"]: fk["other_table"] for fk in await db.foreign_keys_for_table(table)
}
sql = "select " + ", ".join(
[
"(select count(*) from {table} where {column}=:id)".format(
table=escape_sqlite(fk["other_table"]),
column=escape_sqlite(fk["other_column"]),
cell_rows = []
base_url = datasette.setting("base_url")
for row in rows:
cells = []
# Unless we are a view, the first column is a link - either to the rowid
# or to the simple or compound primary key
if link_column:
is_special_link_column = len(pks) != 1
pk_path = path_from_row_pks(row, pks, not pks, False)
cells.append(
{
"column": pks[0] if len(pks) == 1 else "Link",
"value_type": "pk",
"is_special_link_column": is_special_link_column,
"raw": pk_path,
"value": markupsafe.Markup(
'<a href="{table_path}/{flat_pks_quoted}">{flat_pks}</a>'.format(
base_url=base_url,
table_path=datasette.urls.table(database, table),
flat_pks=str(markupsafe.escape(pk_path)),
flat_pks_quoted=path_from_row_pks(row, pks, not pks),
)
),
}
)
for value, column_dict in zip(row, columns):
column = column_dict["name"]
if link_column and len(pks) == 1 and column == pks[0]:
# If there's a simple primary key, don't repeat the value as it's
# already shown in the link column.
continue
# First let the plugins have a go
# pylint: disable=no-member
plugin_display_value = None
for candidate in pm.hook.render_cell(
value=value,
column=column,
table=table,
database=database,
datasette=datasette,
):
candidate = await await_me_maybe(candidate)
if candidate is not None:
plugin_display_value = candidate
break
if plugin_display_value:
display_value = plugin_display_value
elif isinstance(value, bytes):
formatted = format_bytes(len(value))
display_value = markupsafe.Markup(
'<a class="blob-download" href="{}"{}>&lt;Binary:&nbsp;{:,}&nbsp;byte{}&gt;</a>'.format(
datasette.urls.row_blob(
database,
table,
path_from_row_pks(row, pks, not pks),
column,
),
' title="{}"'.format(formatted)
if "bytes" not in formatted
else "",
len(value),
"" if len(value) == 1 else "s",
)
)
for fk in foreign_keys
]
)
try:
rows = list(await db.execute(sql, {"id": pk_values[0]}))
except QueryInterrupted:
# Almost certainly hit the timeout
return []
elif isinstance(value, dict):
# It's an expanded foreign key - display link to other row
label = value["label"]
value = value["value"]
# The table we link to depends on the column
other_table = column_to_foreign_key_table[column]
link_template = LINK_WITH_LABEL if (label != value) else LINK_WITH_VALUE
display_value = markupsafe.Markup(
link_template.format(
database=database,
base_url=base_url,
table=tilde_encode(other_table),
link_id=tilde_encode(str(value)),
id=str(markupsafe.escape(value)),
label=str(markupsafe.escape(label)) or "-",
)
)
elif value in ("", None):
display_value = markupsafe.Markup("&nbsp;")
elif is_url(str(value).strip()):
display_value = markupsafe.Markup(
'<a href="{url}">{url}</a>'.format(
url=markupsafe.escape(value.strip())
)
)
elif column in table_metadata.get("units", {}) and value != "":
# Interpret units using pint
value = value * ureg(table_metadata["units"][column])
# Pint uses floating point which sometimes introduces errors in the compact
# representation, which we have to round off to avoid ugliness. In the vast
# majority of cases this rounding will be inconsequential. I hope.
value = round(value.to_compact(), 6)
display_value = markupsafe.Markup(f"{value:~P}".replace(" ", "&nbsp;"))
else:
display_value = str(value)
if truncate_cells and len(display_value) > truncate_cells:
display_value = display_value[:truncate_cells] + "\u2026"
foreign_table_counts = dict(
zip(
[(fk["other_table"], fk["other_column"]) for fk in foreign_keys],
list(rows[0]),
cells.append(
{
"column": column,
"value": display_value,
"raw": value,
"value_type": "none"
if value is None
else str(type(value).__name__),
}
)
)
foreign_key_tables = []
for fk in foreign_keys:
count = (
foreign_table_counts.get((fk["other_table"], fk["other_column"])) or 0
)
key = fk["other_column"]
if key.startswith("_"):
key += "__exact"
link = "{}?{}={}".format(
self.ds.urls.table(database, fk["other_table"]),
key,
",".join(pk_values),
)
foreign_key_tables.append({**fk, **{"count": count, "link": link}})
return foreign_key_tables
cell_rows.append(Row(cells))
if link_column:
# Add the link column header.
# If it's a simple primary key, we have to remove and re-add that column name at
# the beginning of the header row.
first_column = None
if len(pks) == 1:
columns = [col for col in columns if col["name"] != pks[0]]
first_column = {
"name": pks[0],
"sortable": len(pks) == 1,
"is_pk": True,
"type": column_details[pks[0]].type,
"notnull": column_details[pks[0]].notnull,
}
else:
first_column = {
"name": "Link",
"sortable": False,
"is_pk": False,
"type": "",
"notnull": 0,
}
columns = [first_column] + columns
return columns, cell_rows