mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
Start of PostgreSQL prototype, refs #670
This prototype demonstrates the database page working against a hard-coded connection string to a PostgreSQL database. It lists tables and their columns and their row count.,
This commit is contained in:
parent
0091dfe3e5
commit
32a2f5793a
2 changed files with 206 additions and 1 deletions
199
datasette/postgresql_database.py
Normal file
199
datasette/postgresql_database.py
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
from .utils import Results
|
||||
import asyncpg
|
||||
|
||||
|
||||
class PostgresqlResults:
|
||||
def __init__(self, rows, truncated):
|
||||
self.rows = rows
|
||||
self.truncated = truncated
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
try:
|
||||
return list(self.rows[0].keys())
|
||||
except IndexError:
|
||||
return []
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.rows)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.rows)
|
||||
|
||||
|
||||
class PostgresqlDatabase:
|
||||
size = 0
|
||||
is_mutable = False
|
||||
|
||||
def __init__(self, ds, name, dsn):
|
||||
self.ds = ds
|
||||
self.name = name
|
||||
self.dsn = dsn
|
||||
self._connection = None
|
||||
|
||||
async def connection(self):
|
||||
if self._connection is None:
|
||||
self._connection = await asyncpg.connect(self.dsn)
|
||||
return self._connection
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
sql,
|
||||
params=None,
|
||||
truncate=False,
|
||||
custom_time_limit=None,
|
||||
page_size=None,
|
||||
log_sql_errors=True,
|
||||
):
|
||||
"""Executes sql against db_name in a thread"""
|
||||
print(sql, params)
|
||||
rows = await (await self.connection()).fetch(sql)
|
||||
# Annoyingly if there are 0 results we cannot use the equivalent
|
||||
# of SQLite cursor.description to figure out what the columns
|
||||
# should have been. I haven't found a workaround for that yet
|
||||
# return Results(rows, truncated, cursor.description)
|
||||
return PostgresqlResults(rows, truncated=False)
|
||||
|
||||
async def table_counts(self, limit=10):
|
||||
# Try to get counts for each table, TODO: $limit ms timeout for each count
|
||||
counts = {}
|
||||
for table in await self.table_names():
|
||||
table_count = await (await self.connection()).fetchval(
|
||||
"select count(*) from {}".format(table)
|
||||
)
|
||||
counts[table] = table_count
|
||||
return counts
|
||||
|
||||
async def table_exists(self, table):
|
||||
raise NotImplementedError
|
||||
|
||||
async def table_names(self):
|
||||
results = await self.execute(
|
||||
"select tablename from pg_catalog.pg_tables where schemaname not in ('pg_catalog', 'information_schema')"
|
||||
)
|
||||
return [r[0] for r in results.rows]
|
||||
|
||||
async def table_columns(self, table):
|
||||
sql = """SELECT column_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = '{}'
|
||||
""".format(
|
||||
table
|
||||
)
|
||||
results = await self.execute(sql)
|
||||
return [r[0] for r in results.rows]
|
||||
|
||||
async def primary_keys(self, table):
|
||||
sql = """
|
||||
SELECT a.attname
|
||||
FROM pg_index i
|
||||
JOIN pg_attribute a ON a.attrelid = i.indrelid
|
||||
AND a.attnum = ANY(i.indkey)
|
||||
WHERE i.indrelid = '{}'::regclass
|
||||
AND i.indisprimary;""".format(
|
||||
table
|
||||
)
|
||||
results = await self.execute(sql)
|
||||
return [r[0] for r in results.rows]
|
||||
|
||||
async def fts_table(self, table):
|
||||
return None
|
||||
# return await self.execute_against_connection_in_thread(
|
||||
# lambda conn: detect_fts(conn, table)
|
||||
# )
|
||||
|
||||
async def label_column_for_table(self, table):
|
||||
explicit_label_column = self.ds.table_metadata(self.name, table).get(
|
||||
"label_column"
|
||||
)
|
||||
if explicit_label_column:
|
||||
return explicit_label_column
|
||||
# If a table has two columns, one of which is ID, then label_column is the other one
|
||||
column_names = await self.execute_against_connection_in_thread(
|
||||
lambda conn: table_columns(conn, table)
|
||||
)
|
||||
# Is there a name or title column?
|
||||
name_or_title = [c for c in column_names if c in ("name", "title")]
|
||||
if name_or_title:
|
||||
return name_or_title[0]
|
||||
if (
|
||||
column_names
|
||||
and len(column_names) == 2
|
||||
and ("id" in column_names or "pk" in column_names)
|
||||
):
|
||||
return [c for c in column_names if c not in ("id", "pk")][0]
|
||||
# Couldn't find a label:
|
||||
return None
|
||||
|
||||
async def foreign_keys_for_table(self, table):
|
||||
# return await self.execute_against_connection_in_thread(
|
||||
# lambda conn: get_outbound_foreign_keys(conn, table)
|
||||
# )
|
||||
return []
|
||||
|
||||
async def hidden_table_names(self):
|
||||
# Just the metadata.json ones:
|
||||
hidden_tables = []
|
||||
db_metadata = self.ds.metadata(database=self.name)
|
||||
if "tables" in db_metadata:
|
||||
hidden_tables += [
|
||||
t
|
||||
for t in db_metadata["tables"]
|
||||
if db_metadata["tables"][t].get("hidden")
|
||||
]
|
||||
return hidden_tables
|
||||
|
||||
async def view_names(self):
|
||||
# results = await self.execute("select name from sqlite_master where type='view'")
|
||||
return []
|
||||
|
||||
async def get_all_foreign_keys(self):
|
||||
# return await self.execute_against_connection_in_thread(get_all_foreign_keys)
|
||||
return {t: [] for t in await self.table_names()}
|
||||
|
||||
async def get_outbound_foreign_keys(self, table):
|
||||
# return await self.execute_against_connection_in_thread(
|
||||
# lambda conn: get_outbound_foreign_keys(conn, table)
|
||||
# )
|
||||
return []
|
||||
|
||||
async def get_table_definition(self, table, type_="table"):
|
||||
table_definition_rows = list(
|
||||
await self.execute(
|
||||
"select sql from sqlite_master where name = :n and type=:t",
|
||||
{"n": table, "t": type_},
|
||||
)
|
||||
)
|
||||
if not table_definition_rows:
|
||||
return None
|
||||
bits = [table_definition_rows[0][0] + ";"]
|
||||
# Add on any indexes
|
||||
index_rows = list(
|
||||
await self.ds.execute(
|
||||
self.name,
|
||||
"select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null",
|
||||
{"n": table},
|
||||
)
|
||||
)
|
||||
for index_row in index_rows:
|
||||
bits.append(index_row[0] + ";")
|
||||
return "\n".join(bits)
|
||||
|
||||
async def get_view_definition(self, view):
|
||||
return await self.get_table_definition(view, "view")
|
||||
|
||||
def __repr__(self):
|
||||
tags = []
|
||||
if self.is_mutable:
|
||||
tags.append("mutable")
|
||||
if self.is_memory:
|
||||
tags.append("memory")
|
||||
if self.hash:
|
||||
tags.append("hash={}".format(self.hash))
|
||||
if self.size is not None:
|
||||
tags.append("size={}".format(self.size))
|
||||
tags_str = ""
|
||||
if tags:
|
||||
tags_str = " ({})".format(", ".join(tags))
|
||||
return "<Database: {}{}>".format(self.name, tags_str)
|
||||
|
|
@ -2,6 +2,7 @@ import os
|
|||
|
||||
from datasette.utils import to_css_class, validate_sql_select
|
||||
from datasette.utils.asgi import AsgiFileDownload
|
||||
from datasette.postgresql_database import PostgresqlDatabase
|
||||
|
||||
from .base import DatasetteError, DataView
|
||||
|
||||
|
|
@ -22,7 +23,12 @@ class DatabaseView(DataView):
|
|||
request, database, hash, sql, _size=_size, metadata=metadata
|
||||
)
|
||||
|
||||
db = self.ds.databases[database]
|
||||
# db = self.ds.databases[database]
|
||||
db = PostgresqlDatabase(
|
||||
self.ds,
|
||||
"simonwillisonblog",
|
||||
"postgresql://postgres@localhost/simonwillisonblog",
|
||||
)
|
||||
|
||||
table_counts = await db.table_counts(5)
|
||||
views = await db.view_names()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue