Table page partially works on PostgreSQL, refs #670

Simon Willison 2020-02-13 12:43:06 -08:00
commit b87130a036
4 changed files with 115 additions and 64 deletions


@@ -78,6 +78,12 @@ class Database:
         """Executes sql against db_name in a thread"""
         page_size = page_size or self.ds.page_size
+        # Where are we?
+        import io, traceback
+
+        stored_stack = io.StringIO()
+        traceback.print_stack(file=stored_stack)

         def sql_operation_in_thread(conn):
             time_limit_ms = self.ds.sql_time_limit_ms
             if custom_time_limit and custom_time_limit < time_limit_ms:
@@ -114,10 +120,15 @@ class Database:
             else:
                 return Results(rows, False, cursor.description)

-        with trace("sql", database=self.name, sql=sql.strip(), params=params):
-            results = await self.execute_against_connection_in_thread(
-                sql_operation_in_thread
-            )
+        try:
+            with trace("sql", database=self.name, sql=sql.strip(), params=params):
+                results = await self.execute_against_connection_in_thread(
+                    sql_operation_in_thread
+                )
+        except Exception as e:
+            print(e)
+            print(stored_stack.getvalue())
+            raise
         return results

     @property

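The additions to Database.execute above capture the caller's stack at the moment execute() is invoked, then print it if the SQL later fails, so the code path that issued a failing query can be traced. A minimal standalone sketch of the same pattern (the helper name and usage below are illustrative, not part of the commit):

import io
import traceback


def run_with_caller_stack(operation):
    # Record where we were called from, before any work happens
    stored_stack = io.StringIO()
    traceback.print_stack(file=stored_stack)
    try:
        return operation()
    except Exception as e:
        # On failure, show the error plus the stack captured at call time
        print(e)
        print(stored_stack.getvalue())
        raise


# Illustrative usage: prints ZeroDivisionError and the captured stack, then re-raises
# run_with_caller_stack(lambda: 1 / 0)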

@@ -73,7 +73,7 @@ class Facet:
         self,
         ds,
         request,
-        database,
+        db,
         sql=None,
         table=None,
         params=None,
@@ -83,7 +83,7 @@ class Facet:
         assert table or sql, "Must provide either table= or sql="
         self.ds = ds
         self.request = request
-        self.database = database
+        self.db = db
         # For foreign key expansion. Can be None for e.g. canned SQL queries:
         self.table = table
         self.sql = sql or "select * from [{}]".format(table)
@@ -113,17 +113,16 @@ class Facet:
     async def get_columns(self, sql, params=None):
         # Detect column names using the "limit 0" trick
         return (
-            await self.ds.execute(
-                self.database, "select * from ({}) limit 0".format(sql), params or []
+            await self.db.execute(
+                "select * from ({}) as derived limit 0".format(sql), params or []
             )
         ).columns

     async def get_row_count(self):
         if self.row_count is None:
             self.row_count = (
-                await self.ds.execute(
-                    self.database,
-                    "select count(*) from ({})".format(self.sql),
+                await self.db.execute(
+                    "select count(*) from ({}) as derived".format(self.sql),
                     self.params,
                 )
             ).rows[0][0]
@@ -153,8 +152,7 @@ class ColumnFacet(Facet):
             )
             distinct_values = None
             try:
-                distinct_values = await self.ds.execute(
-                    self.database,
+                distinct_values = await self.db.execute(
                     suggested_facet_sql,
                     self.params,
                     truncate=False,
@@ -203,8 +201,7 @@ class ColumnFacet(Facet):
                 col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1
             )
             try:
-                facet_rows_results = await self.ds.execute(
-                    self.database,
+                facet_rows_results = await self.db.execute(
                     facet_sql,
                     self.params,
                     truncate=False,
@@ -225,8 +222,8 @@ class ColumnFacet(Facet):
                 if self.table:
                     # Attempt to expand foreign keys into labels
                     values = [row["value"] for row in facet_rows]
-                    expanded = await self.ds.expand_foreign_keys(
-                        self.database, self.table, column, values
+                    expanded = await self.db.expand_foreign_keys(
+                        self.table, column, values
                     )
                 else:
                     expanded = {}
@@ -285,8 +282,7 @@ class ArrayFacet(Facet):
                 column=escape_sqlite(column), sql=self.sql
             )
             try:
-                results = await self.ds.execute(
-                    self.database,
+                results = await self.db.execute(
                     suggested_facet_sql,
                     self.params,
                     truncate=False,
@@ -298,8 +294,7 @@ class ArrayFacet(Facet):
                 # Now sanity check that first 100 arrays contain only strings
                 first_100 = [
                     v[0]
-                    for v in await self.ds.execute(
-                        self.database,
+                    for v in await self.db.execute(
                         "select {column} from ({sql}) where {column} is not null and json_array_length({column}) > 0 limit 100".format(
                             column=escape_sqlite(column), sql=self.sql
                         ),
@@ -349,8 +344,7 @@ class ArrayFacet(Facet):
                 col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1
             )
             try:
-                facet_rows_results = await self.ds.execute(
-                    self.database,
+                facet_rows_results = await self.db.execute(
                     facet_sql,
                     self.params,
                     truncate=False,
@@ -416,8 +410,7 @@ class DateFacet(Facet):
                 column=escape_sqlite(column), sql=self.sql
             )
             try:
-                results = await self.ds.execute(
-                    self.database,
+                results = await self.db.execute(
                     suggested_facet_sql,
                     self.params,
                     truncate=False,
@@ -462,8 +455,7 @@ class DateFacet(Facet):
                 col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1
             )
             try:
-                facet_rows_results = await self.ds.execute(
-                    self.database,
+                facet_rows_results = await self.db.execute(
                     facet_sql,
                     self.params,
                     truncate=False,

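Besides routing every facet query through self.db.execute(), these changes add an "as derived" alias to each wrapped subquery: PostgreSQL, unlike SQLite, rejects a subquery in FROM that has no alias. A rough illustration using asyncpg directly (the inner query and table name are placeholders; the DSN is the one hard-coded elsewhere in this commit):

import asyncio
import asyncpg


async def demo():
    conn = await asyncpg.connect("postgresql://postgres@localhost/simonwillisonblog")
    inner = "select * from blog_entry"
    # This form fails on PostgreSQL with: subquery in FROM must have an alias
    # await conn.fetchval("select count(*) from ({})".format(inner))
    # Aliasing the subquery makes it valid on both PostgreSQL and SQLite
    count = await conn.fetchval("select count(*) from ({}) as derived".format(inner))
    print(count)
    await conn.close()


# asyncio.run(demo())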

@@ -7,6 +7,10 @@ class PostgresqlResults:
         self.rows = rows
         self.truncated = truncated

+    @property
+    def description(self):
+        return [[c] for c in self.columns]
+
     @property
     def columns(self):
         try:
@@ -24,6 +28,8 @@ class PostgresqlResults:
 class PostgresqlDatabase:
     size = 0
     is_mutable = False
+    is_memory = False
+    hash = None

     def __init__(self, ds, name, dsn):
         self.ds = ds
@@ -65,7 +71,7 @@ class PostgresqlDatabase:
         return counts

     async def table_exists(self, table):
-        raise NotImplementedError
+        return table in await self.table_names()

     async def table_names(self):
         results = await self.execute(
@@ -159,29 +165,41 @@ class PostgresqlDatabase:
         return []

     async def get_table_definition(self, table, type_="table"):
-        table_definition_rows = list(
-            await self.execute(
-                "select sql from sqlite_master where name = :n and type=:t",
-                {"n": table, "t": type_},
-            )
-        )
-        if not table_definition_rows:
-            return None
-        bits = [table_definition_rows[0][0] + ";"]
-        # Add on any indexes
-        index_rows = list(
-            await self.ds.execute(
-                self.name,
-                "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null",
-                {"n": table},
-            )
-        )
-        for index_row in index_rows:
-            bits.append(index_row[0] + ";")
-        return "\n".join(bits)
+        sql = """
+            SELECT
+              'CREATE TABLE ' || relname || E'\n(\n' ||
+              array_to_string(
+                array_agg(
+                  ' ' || column_name || ' ' || type || ' ' || not_null
+                )
+                , E',\n'
+              ) || E'\n);\n'
+            from
+            (
+              SELECT
+                c.relname, a.attname AS column_name,
+                pg_catalog.format_type(a.atttypid, a.atttypmod) as type,
+                case
+                  when a.attnotnull
+                  then 'NOT NULL'
+                  else 'NULL'
+                END as not_null
+              FROM pg_class c,
+                pg_attribute a,
+                pg_type t
+              WHERE c.relname = $1
+                AND a.attnum > 0
+                AND a.attrelid = c.oid
+                AND a.atttypid = t.oid
+              ORDER BY a.attnum
+            ) as tabledefinition
+            group by relname;
+        """
+        return await (await self.connection()).fetchval(sql, table)

     async def get_view_definition(self, view):
-        return await self.get_table_definition(view, "view")
+        # return await self.get_table_definition(view, "view")
+        return []

     def __repr__(self):
         tags = []

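get_table_definition now builds a CREATE TABLE statement from the pg_class/pg_attribute catalogs rather than reading sqlite_master, with asyncpg's fetchval() returning the single generated string. A hedged sketch of querying the same catalogs outside Datasette, reduced to listing column names and types (connection details and table name are assumptions for illustration):

import asyncio
import asyncpg


async def show_columns(dsn, table):
    # Joins the same catalogs the new query uses, but just lists the columns
    conn = await asyncpg.connect(dsn)
    rows = await conn.fetch(
        """
        SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod) AS type
        FROM pg_class c, pg_attribute a
        WHERE c.relname = $1 AND a.attnum > 0 AND a.attrelid = c.oid
        ORDER BY a.attnum
        """,
        table,
    )
    for row in rows:
        print(row["attname"], row["type"])
    await conn.close()


# asyncio.run(show_columns("postgresql://postgres@localhost/simonwillisonblog", "blog_entry"))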

@@ -5,6 +5,7 @@ import json
 import jinja2

 from datasette.plugins import pm
+from datasette.postgresql_database import PostgresqlDatabase
 from datasette.utils import (
     CustomRow,
     QueryInterrupted,
@@ -64,7 +65,12 @@ class Row:
 class RowTableShared(DataView):
     async def sortable_columns_for_table(self, database, table, use_rowid):
-        db = self.ds.databases[database]
+        # db = self.ds.databases[database]
+        db = PostgresqlDatabase(
+            self.ds,
+            "simonwillisonblog",
+            "postgresql://postgres@localhost/simonwillisonblog",
+        )
         table_metadata = self.ds.table_metadata(database, table)
         if "sortable_columns" in table_metadata:
             sortable_columns = set(table_metadata["sortable_columns"])
@@ -77,7 +83,12 @@ class RowTableShared(DataView):
     async def expandable_columns(self, database, table):
         # Returns list of (fk_dict, label_column-or-None) pairs for that table
         expandables = []
-        db = self.ds.databases[database]
+        # db = self.ds.databases[database]
+        db = PostgresqlDatabase(
+            self.ds,
+            "simonwillisonblog",
+            "postgresql://postgres@localhost/simonwillisonblog",
+        )
         for fk in await db.foreign_keys_for_table(table):
             label_column = await db.label_column_for_table(fk["other_table"])
             expandables.append((fk, label_column))
@@ -87,7 +98,12 @@ class RowTableShared(DataView):
         self, database, table, description, rows, link_column=False, truncate_cells=0
     ):
         "Returns columns, rows for specified table - including fancy foreign key treatment"
-        db = self.ds.databases[database]
+        # db = self.ds.databases[database]
+        db = PostgresqlDatabase(
+            self.ds,
+            "simonwillisonblog",
+            "postgresql://postgres@localhost/simonwillisonblog",
+        )
         table_metadata = self.ds.table_metadata(database, table)
         sortable_columns = await self.sortable_columns_for_table(database, table, True)
         columns = [
@@ -228,7 +244,15 @@ class TableView(RowTableShared):
                 editable=False,
                 canned_query=table,
             )
-        db = self.ds.databases[database]
+        # db = self.ds.databases[database]
+        db = PostgresqlDatabase(
+            self.ds,
+            "simonwillisonblog",
+            "postgresql://postgres@localhost/simonwillisonblog",
+        )
+        print("Here we go, db = ", db)
+
         is_view = bool(await db.get_view_definition(table))
         table_exists = bool(await db.table_exists(table))
         if not is_view and not table_exists:
@@ -533,17 +557,13 @@
         if request.raw_args.get("_timelimit"):
             extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"])

-        results = await self.ds.execute(
-            database, sql, params, truncate=True, **extra_args
-        )
+        results = await db.execute(sql, params, truncate=True, **extra_args)

         # Number of filtered rows in whole set:
         filtered_table_rows_count = None
         if count_sql:
             try:
-                count_rows = list(
-                    await self.ds.execute(database, count_sql, from_sql_params)
-                )
+                count_rows = list(await db.execute(count_sql, from_sql_params))
                 filtered_table_rows_count = count_rows[0][0]
             except QueryInterrupted:
                 pass
@@ -566,7 +586,7 @@
                 klass(
                     self.ds,
                     request,
-                    database,
+                    db,
                     sql=sql_no_limit,
                     params=params,
                     table=table,
@@ -584,7 +604,7 @@
             facets_timed_out.extend(instance_facets_timed_out)

         # Figure out columns and rows for the query
-        columns = [r[0] for r in results.description]
+        columns = list(results.rows[0].keys())
         rows = list(results.rows)

         # Expand labeled columns if requested
@@ -781,7 +801,12 @@
     async def data(self, request, database, hash, table, pk_path, default_labels=False):
         pk_values = urlsafe_components(pk_path)
-        db = self.ds.databases[database]
+        # db = self.ds.databases[database]
+        db = PostgresqlDatabase(
+            self.ds,
+            "simonwillisonblog",
+            "postgresql://postgres@localhost/simonwillisonblog",
+        )
         pks = await db.primary_keys(table)
         use_rowid = not pks
         select = "*"
@@ -795,7 +820,7 @@
         params = {}
         for i, pk_value in enumerate(pk_values):
             params["p{}".format(i)] = pk_value
-        results = await self.ds.execute(database, sql, params, truncate=True)
+        results = await db.execute(sql, params, truncate=True)
         columns = [r[0] for r in results.description]
         rows = list(results.rows)
         if not rows:
@@ -860,7 +885,12 @@
     async def foreign_key_tables(self, database, table, pk_values):
         if len(pk_values) != 1:
             return []
-        db = self.ds.databases[database]
+        # db = self.ds.databases[database]
+        db = PostgresqlDatabase(
+            self.ds,
+            "simonwillisonblog",
+            "postgresql://postgres@localhost/simonwillisonblog",
+        )
         all_foreign_keys = await db.get_all_foreign_keys()
         foreign_keys = all_foreign_keys[table]["incoming"]
         if len(foreign_keys) == 0:
@@ -876,7 +906,7 @@
             ]
         )
         try:
-            rows = list(await self.ds.execute(database, sql, {"id": pk_values[0]}))
+            rows = list(await db.execute(sql, {"id": pk_values[0]}))
         except sqlite3.OperationalError:
             # Almost certainly hit the timeout
             return []
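The switch in TableView from `[r[0] for r in results.description]` to `list(results.rows[0].keys())` leans on asyncpg returning Record objects, which expose column names through .keys(); note it only yields names when at least one row came back. A brief sketch (the connection string is the one from the diff, the query is an assumed example):

import asyncio
import asyncpg


async def column_names(dsn, sql):
    conn = await asyncpg.connect(dsn)
    rows = await conn.fetch(sql)
    await conn.close()
    # asyncpg.Record behaves like a mapping, so the first row yields the column names
    return list(rows[0].keys()) if rows else []


# asyncio.run(column_names(
#     "postgresql://postgres@localhost/simonwillisonblog",
#     "select id, title from blog_entry limit 1",
# ))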