Compare commits

...

2 commits

Author SHA1 Message Date
Simon Willison
b87130a036 Table page partially works on PostgreSQL, refs #670 2020-02-13 12:43:06 -08:00
Simon Willison
32a2f5793a Start of PostgreSQL prototype, refs #670
This prototype demonstrates the database page working against a
hard-coded connection string to a PostgreSQL database. It lists
tables and their columns and their row count.
2020-02-13 12:15:41 -08:00
5 changed files with 300 additions and 44 deletions

View file

@ -78,6 +78,12 @@ class Database:
"""Executes sql against db_name in a thread"""
page_size = page_size or self.ds.page_size
# Where are we?
import io, traceback
stored_stack = io.StringIO()
traceback.print_stack(file=stored_stack)
def sql_operation_in_thread(conn):
time_limit_ms = self.ds.sql_time_limit_ms
if custom_time_limit and custom_time_limit < time_limit_ms:
@ -114,10 +120,15 @@ class Database:
else:
return Results(rows, False, cursor.description)
with trace("sql", database=self.name, sql=sql.strip(), params=params):
results = await self.execute_against_connection_in_thread(
sql_operation_in_thread
)
try:
with trace("sql", database=self.name, sql=sql.strip(), params=params):
results = await self.execute_against_connection_in_thread(
sql_operation_in_thread
)
except Exception as e:
print(e)
print(stored_stack.getvalue())
raise
return results
@property

View file

@ -73,7 +73,7 @@ class Facet:
self,
ds,
request,
database,
db,
sql=None,
table=None,
params=None,
@ -83,7 +83,7 @@ class Facet:
assert table or sql, "Must provide either table= or sql="
self.ds = ds
self.request = request
self.database = database
self.db = db
# For foreign key expansion. Can be None for e.g. canned SQL queries:
self.table = table
self.sql = sql or "select * from [{}]".format(table)
@ -113,17 +113,16 @@ class Facet:
async def get_columns(self, sql, params=None):
# Detect column names using the "limit 0" trick
return (
await self.ds.execute(
self.database, "select * from ({}) limit 0".format(sql), params or []
await self.db.execute(
"select * from ({}) as derived limit 0".format(sql), params or []
)
).columns
async def get_row_count(self):
if self.row_count is None:
self.row_count = (
await self.ds.execute(
self.database,
"select count(*) from ({})".format(self.sql),
await self.db.execute(
"select count(*) from ({}) as derived".format(self.sql),
self.params,
)
).rows[0][0]
@ -153,8 +152,7 @@ class ColumnFacet(Facet):
)
distinct_values = None
try:
distinct_values = await self.ds.execute(
self.database,
distinct_values = await self.db.execute(
suggested_facet_sql,
self.params,
truncate=False,
@ -203,8 +201,7 @@ class ColumnFacet(Facet):
col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1
)
try:
facet_rows_results = await self.ds.execute(
self.database,
facet_rows_results = await self.db.execute(
facet_sql,
self.params,
truncate=False,
@ -225,8 +222,8 @@ class ColumnFacet(Facet):
if self.table:
# Attempt to expand foreign keys into labels
values = [row["value"] for row in facet_rows]
expanded = await self.ds.expand_foreign_keys(
self.database, self.table, column, values
expanded = await self.db.expand_foreign_keys(
self.table, column, values
)
else:
expanded = {}
@ -285,8 +282,7 @@ class ArrayFacet(Facet):
column=escape_sqlite(column), sql=self.sql
)
try:
results = await self.ds.execute(
self.database,
results = await self.db.execute(
suggested_facet_sql,
self.params,
truncate=False,
@ -298,8 +294,7 @@ class ArrayFacet(Facet):
# Now sanity check that first 100 arrays contain only strings
first_100 = [
v[0]
for v in await self.ds.execute(
self.database,
for v in await self.db.execute(
"select {column} from ({sql}) where {column} is not null and json_array_length({column}) > 0 limit 100".format(
column=escape_sqlite(column), sql=self.sql
),
@ -349,8 +344,7 @@ class ArrayFacet(Facet):
col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1
)
try:
facet_rows_results = await self.ds.execute(
self.database,
facet_rows_results = await self.db.execute(
facet_sql,
self.params,
truncate=False,
@ -416,8 +410,7 @@ class DateFacet(Facet):
column=escape_sqlite(column), sql=self.sql
)
try:
results = await self.ds.execute(
self.database,
results = await self.db.execute(
suggested_facet_sql,
self.params,
truncate=False,
@ -462,8 +455,7 @@ class DateFacet(Facet):
col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1
)
try:
facet_rows_results = await self.ds.execute(
self.database,
facet_rows_results = await self.db.execute(
facet_sql,
self.params,
truncate=False,

View file

@ -0,0 +1,217 @@
from .utils import Results
import asyncpg
class PostgresqlResults:
    """Minimal results wrapper mirroring the SQLite ``Results`` interface:
    ``rows``, ``truncated``, ``columns`` and a DB-API-style ``description``.
    """

    def __init__(self, rows, truncated):
        # rows: sequence of mapping-style records (e.g. asyncpg.Record)
        self.rows = rows
        self.truncated = truncated

    @property
    def description(self):
        # DB-API cursor.description look-alike: one [name] entry per column
        return [[name] for name in self.columns]

    @property
    def columns(self):
        # Column names are derived from the first row; no rows, no names
        if not self.rows:
            return []
        return list(self.rows[0].keys())

    def __iter__(self):
        return iter(self.rows)

    def __len__(self):
        return len(self.rows)
class PostgresqlDatabase:
    """Prototype PostgreSQL backend exposing the same interface as the
    SQLite ``Database`` class (``execute``, ``table_names``,
    ``table_counts``, ``primary_keys``, ...).

    A single asyncpg connection is created lazily from the DSN and cached
    on the instance.
    """

    # Attributes the SQLite Database class exposes; fixed placeholder
    # values for this prototype.
    size = 0
    is_mutable = False
    is_memory = False
    hash = None

    def __init__(self, ds, name, dsn):
        self.ds = ds  # Datasette application instance (metadata lookups)
        self.name = name  # database name, used for metadata lookups
        self.dsn = dsn  # PostgreSQL connection string
        self._connection = None  # lazily created asyncpg connection

    async def connection(self):
        """Return the cached asyncpg connection, connecting on first use."""
        if self._connection is None:
            self._connection = await asyncpg.connect(self.dsn)
        return self._connection

    async def execute(
        self,
        sql,
        params=None,
        truncate=False,
        custom_time_limit=None,
        page_size=None,
        log_sql_errors=True,
    ):
        """Execute ``sql`` and return a :class:`PostgresqlResults`.

        NOTE(review): ``params``, ``truncate``, ``custom_time_limit``,
        ``page_size`` and ``log_sql_errors`` are currently ignored —
        callers pass SQLite-style parameters while asyncpg expects
        ``$1``-style positional arguments. TODO: translate/forward params.
        """
        print(sql, params)  # prototype debugging
        rows = await (await self.connection()).fetch(sql)
        # Annoyingly if there are 0 results we cannot use the equivalent
        # of SQLite cursor.description to figure out what the columns
        # should have been; PostgresqlResults falls back to the first row.
        return PostgresqlResults(rows, truncated=False)

    async def table_counts(self, limit=10):
        """Return ``{table_name: row_count}`` for every visible table.

        TODO: enforce a per-count timeout of ``limit`` ms like the SQLite
        implementation does; ``limit`` is currently ignored.
        """
        counts = {}
        connection = await self.connection()
        for table in await self.table_names():
            # Quote the identifier: names come from pg_tables but may be
            # mixed case or reserved words.
            counts[table] = await connection.fetchval(
                'select count(*) from "{}"'.format(table)
            )
        return counts

    async def table_exists(self, table):
        """True if ``table`` is one of the visible table names."""
        return table in await self.table_names()

    async def table_names(self):
        """List user table names (excludes system schemas)."""
        results = await self.execute(
            "select tablename from pg_catalog.pg_tables where schemaname not in ('pg_catalog', 'information_schema')"
        )
        return [r[0] for r in results.rows]

    async def table_columns(self, table):
        """List column names for ``table`` in the public schema."""
        sql = """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_schema = 'public'
            AND table_name = $1
        """
        # Parameterized ($1) instead of interpolating the table name into
        # the SQL string — matches get_table_definition and avoids injection.
        rows = await (await self.connection()).fetch(sql, table)
        return [r[0] for r in rows]

    async def primary_keys(self, table):
        """List primary key column names for ``table``."""
        sql = """
            SELECT a.attname
            FROM pg_index i
            JOIN pg_attribute a ON a.attrelid = i.indrelid
                AND a.attnum = ANY(i.indkey)
            WHERE i.indrelid = $1::regclass
            AND i.indisprimary
        """
        # Parameterized ($1) instead of string interpolation.
        rows = await (await self.connection()).fetch(sql, table)
        return [r[0] for r in rows]

    async def fts_table(self, table):
        # Full-text search is not implemented for PostgreSQL yet.
        return None
        # return await self.execute_against_connection_in_thread(
        #     lambda conn: detect_fts(conn, table)
        # )

    async def label_column_for_table(self, table):
        """Pick a human-readable label column for ``table``.

        Preference order: explicit ``label_column`` from metadata, then a
        ``name``/``title`` column, then the non-id column of a two-column
        (id, other) table. Returns None when no candidate is found.
        """
        explicit_label_column = self.ds.table_metadata(self.name, table).get(
            "label_column"
        )
        if explicit_label_column:
            return explicit_label_column
        # BUG FIX: the original called self.execute_against_connection_in_thread
        # (a SQLite Database method that does not exist on this class);
        # use our own table_columns() instead.
        column_names = await self.table_columns(table)
        # Is there a name or title column?
        name_or_title = [c for c in column_names if c in ("name", "title")]
        if name_or_title:
            return name_or_title[0]
        # If a table has two columns, one of which is ID, then label_column
        # is the other one
        if (
            column_names
            and len(column_names) == 2
            and ("id" in column_names or "pk" in column_names)
        ):
            return [c for c in column_names if c not in ("id", "pk")][0]
        # Couldn't find a label:
        return None

    async def foreign_keys_for_table(self, table):
        # TODO: port get_outbound_foreign_keys to PostgreSQL catalogs.
        return []

    async def hidden_table_names(self):
        """Tables marked ``"hidden": true`` in metadata.json (only source
        of hidden tables for the PostgreSQL prototype)."""
        hidden_tables = []
        db_metadata = self.ds.metadata(database=self.name)
        if "tables" in db_metadata:
            hidden_tables += [
                t
                for t in db_metadata["tables"]
                if db_metadata["tables"][t].get("hidden")
            ]
        return hidden_tables

    async def view_names(self):
        # TODO: query pg_catalog.pg_views for user-defined views.
        return []

    async def get_all_foreign_keys(self):
        # TODO: port get_all_foreign_keys; empty relationships for now.
        return {t: [] for t in await self.table_names()}

    async def get_outbound_foreign_keys(self, table):
        # TODO: port get_outbound_foreign_keys to PostgreSQL catalogs.
        return []

    async def get_table_definition(self, table, type_="table"):
        """Reconstruct a CREATE TABLE statement from the system catalogs."""
        sql = """
        SELECT
          'CREATE TABLE ' || relname || E'\n(\n' ||
          array_to_string(
            array_agg(
              '    ' || column_name || ' ' ||  type || ' '|| not_null
            )
            , E',\n'
          ) || E'\n);\n'
        from
        (
          SELECT
            c.relname, a.attname AS column_name,
            pg_catalog.format_type(a.atttypid, a.atttypmod) as type,
            case
              when a.attnotnull
            then 'NOT NULL'
            else 'NULL'
            END as not_null
          FROM pg_class c,
           pg_attribute a,
           pg_type t
           WHERE c.relname = $1
           AND a.attnum > 0
           AND a.attrelid = c.oid
           AND a.atttypid = t.oid
         ORDER BY a.attnum
        ) as tabledefinition
        group by relname;
        """
        return await (await self.connection()).fetchval(sql, table)

    async def get_view_definition(self, view):
        # TODO: implement via pg_get_viewdef(); falsy placeholder keeps
        # callers (bool(await ...)) treating everything as a table.
        return []

    def __repr__(self):
        tags = []
        if self.is_mutable:
            tags.append("mutable")
        if self.is_memory:
            tags.append("memory")
        if self.hash:
            tags.append("hash={}".format(self.hash))
        if self.size is not None:
            tags.append("size={}".format(self.size))
        tags_str = ""
        if tags:
            tags_str = " ({})".format(", ".join(tags))
        return "<Database: {}{}>".format(self.name, tags_str)

View file

@ -2,6 +2,7 @@ import os
from datasette.utils import to_css_class, validate_sql_select
from datasette.utils.asgi import AsgiFileDownload
from datasette.postgresql_database import PostgresqlDatabase
from .base import DatasetteError, DataView
@ -22,7 +23,12 @@ class DatabaseView(DataView):
request, database, hash, sql, _size=_size, metadata=metadata
)
db = self.ds.databases[database]
# db = self.ds.databases[database]
db = PostgresqlDatabase(
self.ds,
"simonwillisonblog",
"postgresql://postgres@localhost/simonwillisonblog",
)
table_counts = await db.table_counts(5)
views = await db.view_names()

View file

@ -5,6 +5,7 @@ import json
import jinja2
from datasette.plugins import pm
from datasette.postgresql_database import PostgresqlDatabase
from datasette.utils import (
CustomRow,
QueryInterrupted,
@ -64,7 +65,12 @@ class Row:
class RowTableShared(DataView):
async def sortable_columns_for_table(self, database, table, use_rowid):
db = self.ds.databases[database]
# db = self.ds.databases[database]
db = PostgresqlDatabase(
self.ds,
"simonwillisonblog",
"postgresql://postgres@localhost/simonwillisonblog",
)
table_metadata = self.ds.table_metadata(database, table)
if "sortable_columns" in table_metadata:
sortable_columns = set(table_metadata["sortable_columns"])
@ -77,7 +83,12 @@ class RowTableShared(DataView):
async def expandable_columns(self, database, table):
# Returns list of (fk_dict, label_column-or-None) pairs for that table
expandables = []
db = self.ds.databases[database]
# db = self.ds.databases[database]
db = PostgresqlDatabase(
self.ds,
"simonwillisonblog",
"postgresql://postgres@localhost/simonwillisonblog",
)
for fk in await db.foreign_keys_for_table(table):
label_column = await db.label_column_for_table(fk["other_table"])
expandables.append((fk, label_column))
@ -87,7 +98,12 @@ class RowTableShared(DataView):
self, database, table, description, rows, link_column=False, truncate_cells=0
):
"Returns columns, rows for specified table - including fancy foreign key treatment"
db = self.ds.databases[database]
# db = self.ds.databases[database]
db = PostgresqlDatabase(
self.ds,
"simonwillisonblog",
"postgresql://postgres@localhost/simonwillisonblog",
)
table_metadata = self.ds.table_metadata(database, table)
sortable_columns = await self.sortable_columns_for_table(database, table, True)
columns = [
@ -228,7 +244,15 @@ class TableView(RowTableShared):
editable=False,
canned_query=table,
)
db = self.ds.databases[database]
# db = self.ds.databases[database]
db = PostgresqlDatabase(
self.ds,
"simonwillisonblog",
"postgresql://postgres@localhost/simonwillisonblog",
)
print("Here we go, db = ", db)
is_view = bool(await db.get_view_definition(table))
table_exists = bool(await db.table_exists(table))
if not is_view and not table_exists:
@ -533,17 +557,13 @@ class TableView(RowTableShared):
if request.raw_args.get("_timelimit"):
extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"])
results = await self.ds.execute(
database, sql, params, truncate=True, **extra_args
)
results = await db.execute(sql, params, truncate=True, **extra_args)
# Number of filtered rows in whole set:
filtered_table_rows_count = None
if count_sql:
try:
count_rows = list(
await self.ds.execute(database, count_sql, from_sql_params)
)
count_rows = list(await db.execute(count_sql, from_sql_params))
filtered_table_rows_count = count_rows[0][0]
except QueryInterrupted:
pass
@ -566,7 +586,7 @@ class TableView(RowTableShared):
klass(
self.ds,
request,
database,
db,
sql=sql_no_limit,
params=params,
table=table,
@ -584,7 +604,7 @@ class TableView(RowTableShared):
facets_timed_out.extend(instance_facets_timed_out)
# Figure out columns and rows for the query
columns = [r[0] for r in results.description]
columns = list(results.rows[0].keys())
rows = list(results.rows)
# Expand labeled columns if requested
@ -781,7 +801,12 @@ class RowView(RowTableShared):
async def data(self, request, database, hash, table, pk_path, default_labels=False):
pk_values = urlsafe_components(pk_path)
db = self.ds.databases[database]
# db = self.ds.databases[database]
db = PostgresqlDatabase(
self.ds,
"simonwillisonblog",
"postgresql://postgres@localhost/simonwillisonblog",
)
pks = await db.primary_keys(table)
use_rowid = not pks
select = "*"
@ -795,7 +820,7 @@ class RowView(RowTableShared):
params = {}
for i, pk_value in enumerate(pk_values):
params["p{}".format(i)] = pk_value
results = await self.ds.execute(database, sql, params, truncate=True)
results = await db.execute(sql, params, truncate=True)
columns = [r[0] for r in results.description]
rows = list(results.rows)
if not rows:
@ -860,7 +885,12 @@ class RowView(RowTableShared):
async def foreign_key_tables(self, database, table, pk_values):
if len(pk_values) != 1:
return []
db = self.ds.databases[database]
# db = self.ds.databases[database]
db = PostgresqlDatabase(
self.ds,
"simonwillisonblog",
"postgresql://postgres@localhost/simonwillisonblog",
)
all_foreign_keys = await db.get_all_foreign_keys()
foreign_keys = all_foreign_keys[table]["incoming"]
if len(foreign_keys) == 0:
@ -876,7 +906,7 @@ class RowView(RowTableShared):
]
)
try:
rows = list(await self.ds.execute(database, sql, {"id": pk_values[0]}))
rows = list(await db.execute(sql, {"id": pk_values[0]}))
except sqlite3.OperationalError:
# Almost certainly hit the timeout
return []