From b87130a036752821353fb251abfbaea2c70edb13 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 13 Feb 2020 12:43:06 -0800 Subject: [PATCH] Table page partially works on PostgreSQL, refs #670 --- datasette/database.py | 19 +++++++--- datasette/facets.py | 38 ++++++++------------ datasette/postgresql_database.py | 60 ++++++++++++++++++++----------- datasette/views/table.py | 62 +++++++++++++++++++++++--------- 4 files changed, 115 insertions(+), 64 deletions(-) diff --git a/datasette/database.py b/datasette/database.py index 875dac20..bbace949 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -78,6 +78,12 @@ class Database: """Executes sql against db_name in a thread""" page_size = page_size or self.ds.page_size + # Where are we? + import io, traceback + + stored_stack = io.StringIO() + traceback.print_stack(file=stored_stack) + def sql_operation_in_thread(conn): time_limit_ms = self.ds.sql_time_limit_ms if custom_time_limit and custom_time_limit < time_limit_ms: @@ -114,10 +120,15 @@ class Database: else: return Results(rows, False, cursor.description) - with trace("sql", database=self.name, sql=sql.strip(), params=params): - results = await self.execute_against_connection_in_thread( - sql_operation_in_thread - ) + try: + with trace("sql", database=self.name, sql=sql.strip(), params=params): + results = await self.execute_against_connection_in_thread( + sql_operation_in_thread + ) + except Exception as e: + print(e) + print(stored_stack.getvalue()) + raise return results @property diff --git a/datasette/facets.py b/datasette/facets.py index 18558754..031c164e 100644 --- a/datasette/facets.py +++ b/datasette/facets.py @@ -73,7 +73,7 @@ class Facet: self, ds, request, - database, + db, sql=None, table=None, params=None, @@ -83,7 +83,7 @@ class Facet: assert table or sql, "Must provide either table= or sql=" self.ds = ds self.request = request - self.database = database + self.db = db # For foreign key expansion. Can be None for e.g. canned SQL queries: self.table = table self.sql = sql or "select * from [{}]".format(table) @@ -113,17 +113,16 @@ class Facet: async def get_columns(self, sql, params=None): # Detect column names using the "limit 0" trick return ( - await self.ds.execute( - self.database, "select * from ({}) limit 0".format(sql), params or [] + await self.db.execute( + "select * from ({}) as derived limit 0".format(sql), params or [] ) ).columns async def get_row_count(self): if self.row_count is None: self.row_count = ( - await self.ds.execute( - self.database, - "select count(*) from ({})".format(self.sql), + await self.db.execute( + "select count(*) from ({}) as derived".format(self.sql), self.params, ) ).rows[0][0] @@ -153,8 +152,7 @@ class ColumnFacet(Facet): ) distinct_values = None try: - distinct_values = await self.ds.execute( - self.database, + distinct_values = await self.db.execute( suggested_facet_sql, self.params, truncate=False, @@ -203,8 +201,7 @@ class ColumnFacet(Facet): col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1 ) try: - facet_rows_results = await self.ds.execute( - self.database, + facet_rows_results = await self.db.execute( facet_sql, self.params, truncate=False, @@ -225,8 +222,8 @@ class ColumnFacet(Facet): if self.table: # Attempt to expand foreign keys into labels values = [row["value"] for row in facet_rows] - expanded = await self.ds.expand_foreign_keys( - self.database, self.table, column, values + expanded = await self.db.expand_foreign_keys( + self.table, column, values ) else: expanded = {} @@ -285,8 +282,7 @@ class ArrayFacet(Facet): column=escape_sqlite(column), sql=self.sql ) try: - results = await self.ds.execute( - self.database, + results = await self.db.execute( suggested_facet_sql, self.params, truncate=False, @@ -298,8 +294,7 @@ class ArrayFacet(Facet): # Now sanity check that first 100 arrays contain only strings first_100 = [ v[0] - for v in await self.ds.execute( - self.database, + for v in await self.db.execute( "select {column} from ({sql}) where {column} is not null and json_array_length({column}) > 0 limit 100".format( column=escape_sqlite(column), sql=self.sql ), @@ -349,8 +344,7 @@ class ArrayFacet(Facet): col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1 ) try: - facet_rows_results = await self.ds.execute( - self.database, + facet_rows_results = await self.db.execute( facet_sql, self.params, truncate=False, @@ -416,8 +410,7 @@ class DateFacet(Facet): column=escape_sqlite(column), sql=self.sql ) try: - results = await self.ds.execute( - self.database, + results = await self.db.execute( suggested_facet_sql, self.params, truncate=False, @@ -462,8 +455,7 @@ class DateFacet(Facet): col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1 ) try: - facet_rows_results = await self.ds.execute( - self.database, + facet_rows_results = await self.db.execute( facet_sql, self.params, truncate=False, diff --git a/datasette/postgresql_database.py b/datasette/postgresql_database.py index 7dc1f5b9..4fd2ca1b 100644 --- a/datasette/postgresql_database.py +++ b/datasette/postgresql_database.py @@ -7,6 +7,10 @@ class PostgresqlResults: self.rows = rows self.truncated = truncated + @property + def description(self): + return [[c] for c in self.columns] + @property def columns(self): try: @@ -24,6 +28,8 @@ class PostgresqlResults: class PostgresqlDatabase: size = 0 is_mutable = False + is_memory = False + hash = None def __init__(self, ds, name, dsn): self.ds = ds @@ -65,7 +71,7 @@ class PostgresqlDatabase: return counts async def table_exists(self, table): - raise NotImplementedError + return table in await self.table_names() async def table_names(self): results = await self.execute( @@ -159,29 +165,41 @@ class PostgresqlDatabase: return [] async def get_table_definition(self, table, type_="table"): - table_definition_rows = list( - await self.execute( - "select sql from sqlite_master where name = :n and type=:t", - {"n": table, "t": type_}, + sql = """ + SELECT + 'CREATE TABLE ' || relname || E'\n(\n' || + array_to_string( + array_agg( + ' ' || column_name || ' ' || type || ' '|| not_null ) - ) - if not table_definition_rows: - return None - bits = [table_definition_rows[0][0] + ";"] - # Add on any indexes - index_rows = list( - await self.ds.execute( - self.name, - "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null", - {"n": table}, - ) - ) - for index_row in index_rows: - bits.append(index_row[0] + ";") - return "\n".join(bits) + , E',\n' + ) || E'\n);\n' + from + ( + SELECT + c.relname, a.attname AS column_name, + pg_catalog.format_type(a.atttypid, a.atttypmod) as type, + case + when a.attnotnull + then 'NOT NULL' + else 'NULL' + END as not_null + FROM pg_class c, + pg_attribute a, + pg_type t + WHERE c.relname = $1 + AND a.attnum > 0 + AND a.attrelid = c.oid + AND a.atttypid = t.oid + ORDER BY a.attnum + ) as tabledefinition + group by relname; + """ + return await (await self.connection()).fetchval(sql, table) async def get_view_definition(self, view): - return await self.get_table_definition(view, "view") + # return await self.get_table_definition(view, "view") + return [] def __repr__(self): tags = [] diff --git a/datasette/views/table.py b/datasette/views/table.py index 54839344..d77a1259 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -5,6 +5,7 @@ import json import jinja2 from datasette.plugins import pm +from datasette.postgresql_database import PostgresqlDatabase from datasette.utils import ( CustomRow, QueryInterrupted, @@ -64,7 +65,12 @@ class Row: class RowTableShared(DataView): async def sortable_columns_for_table(self, database, table, use_rowid): - db = self.ds.databases[database] + # db = self.ds.databases[database] + db = PostgresqlDatabase( + self.ds, + "simonwillisonblog", + "postgresql://postgres@localhost/simonwillisonblog", + ) table_metadata = self.ds.table_metadata(database, table) if "sortable_columns" in table_metadata: sortable_columns = set(table_metadata["sortable_columns"]) @@ -77,7 +83,12 @@ class RowTableShared(DataView): async def expandable_columns(self, database, table): # Returns list of (fk_dict, label_column-or-None) pairs for that table expandables = [] - db = self.ds.databases[database] + # db = self.ds.databases[database] + db = PostgresqlDatabase( + self.ds, + "simonwillisonblog", + "postgresql://postgres@localhost/simonwillisonblog", + ) for fk in await db.foreign_keys_for_table(table): label_column = await db.label_column_for_table(fk["other_table"]) expandables.append((fk, label_column)) @@ -87,7 +98,12 @@ class RowTableShared(DataView): self, database, table, description, rows, link_column=False, truncate_cells=0 ): "Returns columns, rows for specified table - including fancy foreign key treatment" - db = self.ds.databases[database] + # db = self.ds.databases[database] + db = PostgresqlDatabase( + self.ds, + "simonwillisonblog", + "postgresql://postgres@localhost/simonwillisonblog", + ) table_metadata = self.ds.table_metadata(database, table) sortable_columns = await self.sortable_columns_for_table(database, table, True) columns = [ @@ -228,7 +244,15 @@ class TableView(RowTableShared): editable=False, canned_query=table, ) - db = self.ds.databases[database] + # db = self.ds.databases[database] + db = PostgresqlDatabase( + self.ds, + "simonwillisonblog", + "postgresql://postgres@localhost/simonwillisonblog", + ) + + print("Here we go, db = ", db) + is_view = bool(await db.get_view_definition(table)) table_exists = bool(await db.table_exists(table)) if not is_view and not table_exists: @@ -533,17 +557,13 @@ class TableView(RowTableShared): if request.raw_args.get("_timelimit"): extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"]) - results = await self.ds.execute( - database, sql, params, truncate=True, **extra_args - ) + results = await db.execute(sql, params, truncate=True, **extra_args) # Number of filtered rows in whole set: filtered_table_rows_count = None if count_sql: try: - count_rows = list( - await self.ds.execute(database, count_sql, from_sql_params) - ) + count_rows = list(await db.execute(count_sql, from_sql_params)) filtered_table_rows_count = count_rows[0][0] except QueryInterrupted: pass @@ -566,7 +586,7 @@ class TableView(RowTableShared): klass( self.ds, request, - database, + db, sql=sql_no_limit, params=params, table=table, @@ -584,7 +604,7 @@ class TableView(RowTableShared): facets_timed_out.extend(instance_facets_timed_out) # Figure out columns and rows for the query - columns = [r[0] for r in results.description] + columns = list(results.rows[0].keys()) rows = list(results.rows) # Expand labeled columns if requested @@ -781,7 +801,12 @@ class RowView(RowTableShared): async def data(self, request, database, hash, table, pk_path, default_labels=False): pk_values = urlsafe_components(pk_path) - db = self.ds.databases[database] + # db = self.ds.databases[database] + db = PostgresqlDatabase( + self.ds, + "simonwillisonblog", + "postgresql://postgres@localhost/simonwillisonblog", + ) pks = await db.primary_keys(table) use_rowid = not pks select = "*" @@ -795,7 +820,7 @@ class RowView(RowTableShared): params = {} for i, pk_value in enumerate(pk_values): params["p{}".format(i)] = pk_value - results = await self.ds.execute(database, sql, params, truncate=True) + results = await db.execute(sql, params, truncate=True) columns = [r[0] for r in results.description] rows = list(results.rows) if not rows: @@ -860,7 +885,12 @@ class RowView(RowTableShared): async def foreign_key_tables(self, database, table, pk_values): if len(pk_values) != 1: return [] - db = self.ds.databases[database] + # db = self.ds.databases[database] + db = PostgresqlDatabase( + self.ds, + "simonwillisonblog", + "postgresql://postgres@localhost/simonwillisonblog", + ) all_foreign_keys = await db.get_all_foreign_keys() foreign_keys = all_foreign_keys[table]["incoming"] if len(foreign_keys) == 0: @@ -876,7 +906,7 @@ class RowView(RowTableShared): ] ) try: - rows = list(await self.ds.execute(database, sql, {"id": pk_values[0]})) + rows = list(await db.execute(sql, {"id": pk_values[0]})) except sqlite3.OperationalError: # Almost certainly hit the timeout return []