From 32a2f5793a5276c1033bb96b79b9ee1c0e748219 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 13 Feb 2020 12:15:37 -0800 Subject: [PATCH 1/2] Start of PostgreSQL prototype, refs #670 This prototype demonstrates the database page working against a hard-coded connection string to a PostgreSQL database. It lists tables and their columns and their row count., --- datasette/postgresql_database.py | 199 +++++++++++++++++++++++++++++++ datasette/views/database.py | 8 +- 2 files changed, 206 insertions(+), 1 deletion(-) create mode 100644 datasette/postgresql_database.py diff --git a/datasette/postgresql_database.py b/datasette/postgresql_database.py new file mode 100644 index 00000000..7dc1f5b9 --- /dev/null +++ b/datasette/postgresql_database.py @@ -0,0 +1,199 @@ +from .utils import Results +import asyncpg + + +class PostgresqlResults: + def __init__(self, rows, truncated): + self.rows = rows + self.truncated = truncated + + @property + def columns(self): + try: + return list(self.rows[0].keys()) + except IndexError: + return [] + + def __iter__(self): + return iter(self.rows) + + def __len__(self): + return len(self.rows) + + +class PostgresqlDatabase: + size = 0 + is_mutable = False + + def __init__(self, ds, name, dsn): + self.ds = ds + self.name = name + self.dsn = dsn + self._connection = None + + async def connection(self): + if self._connection is None: + self._connection = await asyncpg.connect(self.dsn) + return self._connection + + async def execute( + self, + sql, + params=None, + truncate=False, + custom_time_limit=None, + page_size=None, + log_sql_errors=True, + ): + """Executes sql against db_name in a thread""" + print(sql, params) + rows = await (await self.connection()).fetch(sql) + # Annoyingly if there are 0 results we cannot use the equivalent + # of SQLite cursor.description to figure out what the columns + # should have been. I haven't found a workaround for that yet + # return Results(rows, truncated, cursor.description) + return PostgresqlResults(rows, truncated=False) + + async def table_counts(self, limit=10): + # Try to get counts for each table, TODO: $limit ms timeout for each count + counts = {} + for table in await self.table_names(): + table_count = await (await self.connection()).fetchval( + "select count(*) from {}".format(table) + ) + counts[table] = table_count + return counts + + async def table_exists(self, table): + raise NotImplementedError + + async def table_names(self): + results = await self.execute( + "select tablename from pg_catalog.pg_tables where schemaname not in ('pg_catalog', 'information_schema')" + ) + return [r[0] for r in results.rows] + + async def table_columns(self, table): + sql = """SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = '{}' + """.format( + table + ) + results = await self.execute(sql) + return [r[0] for r in results.rows] + + async def primary_keys(self, table): + sql = """ + SELECT a.attname + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = i.indrelid + AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = '{}'::regclass + AND i.indisprimary;""".format( + table + ) + results = await self.execute(sql) + return [r[0] for r in results.rows] + + async def fts_table(self, table): + return None + # return await self.execute_against_connection_in_thread( + # lambda conn: detect_fts(conn, table) + # ) + + async def label_column_for_table(self, table): + explicit_label_column = self.ds.table_metadata(self.name, table).get( + "label_column" + ) + if explicit_label_column: + return explicit_label_column + # If a table has two columns, one of which is ID, then label_column is the other one + column_names = await self.execute_against_connection_in_thread( + lambda conn: table_columns(conn, table) + ) + # Is there a name or title column? + name_or_title = [c for c in column_names if c in ("name", "title")] + if name_or_title: + return name_or_title[0] + if ( + column_names + and len(column_names) == 2 + and ("id" in column_names or "pk" in column_names) + ): + return [c for c in column_names if c not in ("id", "pk")][0] + # Couldn't find a label: + return None + + async def foreign_keys_for_table(self, table): + # return await self.execute_against_connection_in_thread( + # lambda conn: get_outbound_foreign_keys(conn, table) + # ) + return [] + + async def hidden_table_names(self): + # Just the metadata.json ones: + hidden_tables = [] + db_metadata = self.ds.metadata(database=self.name) + if "tables" in db_metadata: + hidden_tables += [ + t + for t in db_metadata["tables"] + if db_metadata["tables"][t].get("hidden") + ] + return hidden_tables + + async def view_names(self): + # results = await self.execute("select name from sqlite_master where type='view'") + return [] + + async def get_all_foreign_keys(self): + # return await self.execute_against_connection_in_thread(get_all_foreign_keys) + return {t: [] for t in await self.table_names()} + + async def get_outbound_foreign_keys(self, table): + # return await self.execute_against_connection_in_thread( + # lambda conn: get_outbound_foreign_keys(conn, table) + # ) + return [] + + async def get_table_definition(self, table, type_="table"): + table_definition_rows = list( + await self.execute( + "select sql from sqlite_master where name = :n and type=:t", + {"n": table, "t": type_}, + ) + ) + if not table_definition_rows: + return None + bits = [table_definition_rows[0][0] + ";"] + # Add on any indexes + index_rows = list( + await self.ds.execute( + self.name, + "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null", + {"n": table}, + ) + ) + for index_row in index_rows: + bits.append(index_row[0] + ";") + return "\n".join(bits) + + async def get_view_definition(self, view): + return await self.get_table_definition(view, "view") + + def __repr__(self): + tags = [] + if self.is_mutable: + tags.append("mutable") + if self.is_memory: + tags.append("memory") + if self.hash: + tags.append("hash={}".format(self.hash)) + if self.size is not None: + tags.append("size={}".format(self.size)) + tags_str = "" + if tags: + tags_str = " ({})".format(", ".join(tags)) + return "".format(self.name, tags_str) diff --git a/datasette/views/database.py b/datasette/views/database.py index 31d6af59..49c06469 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -2,6 +2,7 @@ import os from datasette.utils import to_css_class, validate_sql_select from datasette.utils.asgi import AsgiFileDownload +from datasette.postgresql_database import PostgresqlDatabase from .base import DatasetteError, DataView @@ -22,7 +23,12 @@ class DatabaseView(DataView): request, database, hash, sql, _size=_size, metadata=metadata ) - db = self.ds.databases[database] + # db = self.ds.databases[database] + db = PostgresqlDatabase( + self.ds, + "simonwillisonblog", + "postgresql://postgres@localhost/simonwillisonblog", + ) table_counts = await db.table_counts(5) views = await db.view_names() From b87130a036752821353fb251abfbaea2c70edb13 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 13 Feb 2020 12:43:06 -0800 Subject: [PATCH 2/2] Table page partially works on PostgreSQL, refs #670 --- datasette/database.py | 19 +++++++--- datasette/facets.py | 38 ++++++++------------ datasette/postgresql_database.py | 60 ++++++++++++++++++++----------- datasette/views/table.py | 62 +++++++++++++++++++++++--------- 4 files changed, 115 insertions(+), 64 deletions(-) diff --git a/datasette/database.py b/datasette/database.py index 875dac20..bbace949 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -78,6 +78,12 @@ class Database: """Executes sql against db_name in a thread""" page_size = page_size or self.ds.page_size + # Where are we? + import io, traceback + + stored_stack = io.StringIO() + traceback.print_stack(file=stored_stack) + def sql_operation_in_thread(conn): time_limit_ms = self.ds.sql_time_limit_ms if custom_time_limit and custom_time_limit < time_limit_ms: @@ -114,10 +120,15 @@ class Database: else: return Results(rows, False, cursor.description) - with trace("sql", database=self.name, sql=sql.strip(), params=params): - results = await self.execute_against_connection_in_thread( - sql_operation_in_thread - ) + try: + with trace("sql", database=self.name, sql=sql.strip(), params=params): + results = await self.execute_against_connection_in_thread( + sql_operation_in_thread + ) + except Exception as e: + print(e) + print(stored_stack.getvalue()) + raise return results @property diff --git a/datasette/facets.py b/datasette/facets.py index 18558754..031c164e 100644 --- a/datasette/facets.py +++ b/datasette/facets.py @@ -73,7 +73,7 @@ class Facet: self, ds, request, - database, + db, sql=None, table=None, params=None, @@ -83,7 +83,7 @@ class Facet: assert table or sql, "Must provide either table= or sql=" self.ds = ds self.request = request - self.database = database + self.db = db # For foreign key expansion. Can be None for e.g. canned SQL queries: self.table = table self.sql = sql or "select * from [{}]".format(table) @@ -113,17 +113,16 @@ class Facet: async def get_columns(self, sql, params=None): # Detect column names using the "limit 0" trick return ( - await self.ds.execute( - self.database, "select * from ({}) limit 0".format(sql), params or [] + await self.db.execute( + "select * from ({}) as derived limit 0".format(sql), params or [] ) ).columns async def get_row_count(self): if self.row_count is None: self.row_count = ( - await self.ds.execute( - self.database, - "select count(*) from ({})".format(self.sql), + await self.db.execute( + "select count(*) from ({}) as derived".format(self.sql), self.params, ) ).rows[0][0] @@ -153,8 +152,7 @@ class ColumnFacet(Facet): ) distinct_values = None try: - distinct_values = await self.ds.execute( - self.database, + distinct_values = await self.db.execute( suggested_facet_sql, self.params, truncate=False, @@ -203,8 +201,7 @@ class ColumnFacet(Facet): col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1 ) try: - facet_rows_results = await self.ds.execute( - self.database, + facet_rows_results = await self.db.execute( facet_sql, self.params, truncate=False, @@ -225,8 +222,8 @@ class ColumnFacet(Facet): if self.table: # Attempt to expand foreign keys into labels values = [row["value"] for row in facet_rows] - expanded = await self.ds.expand_foreign_keys( - self.database, self.table, column, values + expanded = await self.db.expand_foreign_keys( + self.table, column, values ) else: expanded = {} @@ -285,8 +282,7 @@ class ArrayFacet(Facet): column=escape_sqlite(column), sql=self.sql ) try: - results = await self.ds.execute( - self.database, + results = await self.db.execute( suggested_facet_sql, self.params, truncate=False, @@ -298,8 +294,7 @@ class ArrayFacet(Facet): # Now sanity check that first 100 arrays contain only strings first_100 = [ v[0] - for v in await self.ds.execute( - self.database, + for v in await self.db.execute( "select {column} from ({sql}) where {column} is not null and json_array_length({column}) > 0 limit 100".format( column=escape_sqlite(column), sql=self.sql ), @@ -349,8 +344,7 @@ class ArrayFacet(Facet): col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1 ) try: - facet_rows_results = await self.ds.execute( - self.database, + facet_rows_results = await self.db.execute( facet_sql, self.params, truncate=False, @@ -416,8 +410,7 @@ class DateFacet(Facet): column=escape_sqlite(column), sql=self.sql ) try: - results = await self.ds.execute( - self.database, + results = await self.db.execute( suggested_facet_sql, self.params, truncate=False, @@ -462,8 +455,7 @@ class DateFacet(Facet): col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1 ) try: - facet_rows_results = await self.ds.execute( - self.database, + facet_rows_results = await self.db.execute( facet_sql, self.params, truncate=False, diff --git a/datasette/postgresql_database.py b/datasette/postgresql_database.py index 7dc1f5b9..4fd2ca1b 100644 --- a/datasette/postgresql_database.py +++ b/datasette/postgresql_database.py @@ -7,6 +7,10 @@ class PostgresqlResults: self.rows = rows self.truncated = truncated + @property + def description(self): + return [[c] for c in self.columns] + @property def columns(self): try: @@ -24,6 +28,8 @@ class PostgresqlResults: class PostgresqlDatabase: size = 0 is_mutable = False + is_memory = False + hash = None def __init__(self, ds, name, dsn): self.ds = ds @@ -65,7 +71,7 @@ class PostgresqlDatabase: return counts async def table_exists(self, table): - raise NotImplementedError + return table in await self.table_names() async def table_names(self): results = await self.execute( @@ -159,29 +165,41 @@ class PostgresqlDatabase: return [] async def get_table_definition(self, table, type_="table"): - table_definition_rows = list( - await self.execute( - "select sql from sqlite_master where name = :n and type=:t", - {"n": table, "t": type_}, + sql = """ + SELECT + 'CREATE TABLE ' || relname || E'\n(\n' || + array_to_string( + array_agg( + ' ' || column_name || ' ' || type || ' '|| not_null ) - ) - if not table_definition_rows: - return None - bits = [table_definition_rows[0][0] + ";"] - # Add on any indexes - index_rows = list( - await self.ds.execute( - self.name, - "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null", - {"n": table}, - ) - ) - for index_row in index_rows: - bits.append(index_row[0] + ";") - return "\n".join(bits) + , E',\n' + ) || E'\n);\n' + from + ( + SELECT + c.relname, a.attname AS column_name, + pg_catalog.format_type(a.atttypid, a.atttypmod) as type, + case + when a.attnotnull + then 'NOT NULL' + else 'NULL' + END as not_null + FROM pg_class c, + pg_attribute a, + pg_type t + WHERE c.relname = $1 + AND a.attnum > 0 + AND a.attrelid = c.oid + AND a.atttypid = t.oid + ORDER BY a.attnum + ) as tabledefinition + group by relname; + """ + return await (await self.connection()).fetchval(sql, table) async def get_view_definition(self, view): - return await self.get_table_definition(view, "view") + # return await self.get_table_definition(view, "view") + return [] def __repr__(self): tags = [] diff --git a/datasette/views/table.py b/datasette/views/table.py index 54839344..d77a1259 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -5,6 +5,7 @@ import json import jinja2 from datasette.plugins import pm +from datasette.postgresql_database import PostgresqlDatabase from datasette.utils import ( CustomRow, QueryInterrupted, @@ -64,7 +65,12 @@ class Row: class RowTableShared(DataView): async def sortable_columns_for_table(self, database, table, use_rowid): - db = self.ds.databases[database] + # db = self.ds.databases[database] + db = PostgresqlDatabase( + self.ds, + "simonwillisonblog", + "postgresql://postgres@localhost/simonwillisonblog", + ) table_metadata = self.ds.table_metadata(database, table) if "sortable_columns" in table_metadata: sortable_columns = set(table_metadata["sortable_columns"]) @@ -77,7 +83,12 @@ class RowTableShared(DataView): async def expandable_columns(self, database, table): # Returns list of (fk_dict, label_column-or-None) pairs for that table expandables = [] - db = self.ds.databases[database] + # db = self.ds.databases[database] + db = PostgresqlDatabase( + self.ds, + "simonwillisonblog", + "postgresql://postgres@localhost/simonwillisonblog", + ) for fk in await db.foreign_keys_for_table(table): label_column = await db.label_column_for_table(fk["other_table"]) expandables.append((fk, label_column)) @@ -87,7 +98,12 @@ class RowTableShared(DataView): self, database, table, description, rows, link_column=False, truncate_cells=0 ): "Returns columns, rows for specified table - including fancy foreign key treatment" - db = self.ds.databases[database] + # db = self.ds.databases[database] + db = PostgresqlDatabase( + self.ds, + "simonwillisonblog", + "postgresql://postgres@localhost/simonwillisonblog", + ) table_metadata = self.ds.table_metadata(database, table) sortable_columns = await self.sortable_columns_for_table(database, table, True) columns = [ @@ -228,7 +244,15 @@ class TableView(RowTableShared): editable=False, canned_query=table, ) - db = self.ds.databases[database] + # db = self.ds.databases[database] + db = PostgresqlDatabase( + self.ds, + "simonwillisonblog", + "postgresql://postgres@localhost/simonwillisonblog", + ) + + print("Here we go, db = ", db) + is_view = bool(await db.get_view_definition(table)) table_exists = bool(await db.table_exists(table)) if not is_view and not table_exists: @@ -533,17 +557,13 @@ class TableView(RowTableShared): if request.raw_args.get("_timelimit"): extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"]) - results = await self.ds.execute( - database, sql, params, truncate=True, **extra_args - ) + results = await db.execute(sql, params, truncate=True, **extra_args) # Number of filtered rows in whole set: filtered_table_rows_count = None if count_sql: try: - count_rows = list( - await self.ds.execute(database, count_sql, from_sql_params) - ) + count_rows = list(await db.execute(count_sql, from_sql_params)) filtered_table_rows_count = count_rows[0][0] except QueryInterrupted: pass @@ -566,7 +586,7 @@ class TableView(RowTableShared): klass( self.ds, request, - database, + db, sql=sql_no_limit, params=params, table=table, @@ -584,7 +604,7 @@ class TableView(RowTableShared): facets_timed_out.extend(instance_facets_timed_out) # Figure out columns and rows for the query - columns = [r[0] for r in results.description] + columns = list(results.rows[0].keys()) rows = list(results.rows) # Expand labeled columns if requested @@ -781,7 +801,12 @@ class RowView(RowTableShared): async def data(self, request, database, hash, table, pk_path, default_labels=False): pk_values = urlsafe_components(pk_path) - db = self.ds.databases[database] + # db = self.ds.databases[database] + db = PostgresqlDatabase( + self.ds, + "simonwillisonblog", + "postgresql://postgres@localhost/simonwillisonblog", + ) pks = await db.primary_keys(table) use_rowid = not pks select = "*" @@ -795,7 +820,7 @@ class RowView(RowTableShared): params = {} for i, pk_value in enumerate(pk_values): params["p{}".format(i)] = pk_value - results = await self.ds.execute(database, sql, params, truncate=True) + results = await db.execute(sql, params, truncate=True) columns = [r[0] for r in results.description] rows = list(results.rows) if not rows: @@ -860,7 +885,12 @@ class RowView(RowTableShared): async def foreign_key_tables(self, database, table, pk_values): if len(pk_values) != 1: return [] - db = self.ds.databases[database] + # db = self.ds.databases[database] + db = PostgresqlDatabase( + self.ds, + "simonwillisonblog", + "postgresql://postgres@localhost/simonwillisonblog", + ) all_foreign_keys = await db.get_all_foreign_keys() foreign_keys = all_foreign_keys[table]["incoming"] if len(foreign_keys) == 0: @@ -876,7 +906,7 @@ class RowView(RowTableShared): ] ) try: - rows = list(await self.ds.execute(database, sql, {"id": pk_values[0]})) + rows = list(await db.execute(sql, {"id": pk_values[0]})) except sqlite3.OperationalError: # Almost certainly hit the timeout return []