From 026c84db30bd0a75ecde146a80a5d142078dc299 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 26 May 2019 21:56:43 -0700 Subject: [PATCH] Refactor Datasette methods to ConnectedDatabase Refs #487 --- datasette/app.py | 113 +++++++++++++++++++----------------- datasette/views/base.py | 5 +- datasette/views/database.py | 2 +- datasette/views/index.py | 2 +- datasette/views/table.py | 28 ++++----- 5 files changed, 78 insertions(+), 72 deletions(-) diff --git a/datasette/app.py b/datasette/app.py index 4f57db2d..e39c4097 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -197,12 +197,53 @@ class ConnectedDatabase: else: return Path(self.path).stem + async def table_exists(self, table): + results = await self.ds.execute( + self.name, + "select 1 from sqlite_master where type='table' and name=?", + params=(table,), + ) + return bool(results.rows) + async def table_names(self): results = await self.ds.execute( self.name, "select name from sqlite_master where type='table'" ) return [r[0] for r in results.rows] + async def table_columns(self, table): + return await self.ds.execute_against_connection_in_thread( + self.name, lambda conn: table_columns(conn, table) + ) + + async def label_column_for_table(self, table): + explicit_label_column = self.ds.table_metadata(self.name, table).get( + "label_column" + ) + if explicit_label_column: + return explicit_label_column + # If a table has two columns, one of which is ID, then label_column is the other one + column_names = await self.ds.execute_against_connection_in_thread( + self.name, lambda conn: table_columns(conn, table) + ) + # Is there a name or title column? + name_or_title = [c for c in column_names if c in ("name", "title")] + if name_or_title: + return name_or_title[0] + if ( + column_names + and len(column_names) == 2 + and ("id" in column_names or "pk" in column_names) + ): + return [c for c in column_names if c not in ("id", "pk")][0] + # Couldn't find a label: + return None + + async def foreign_keys_for_table(self, table): + return await self.ds.execute_against_connection_in_thread( + self.name, lambda conn: get_outbound_foreign_keys(conn, table) + ) + async def hidden_table_names(self): # Mark tables 'hidden' if they relate to FTS virtual tables hidden_tables = [ @@ -275,6 +316,21 @@ class ConnectedDatabase: self.name, get_all_foreign_keys ) + async def get_table_definition(self, table, type_="table"): + table_definition_rows = list( + await self.ds.execute( + self.name, + "select sql from sqlite_master where name = :n and type=:t", + {"n": table, "t": type_}, + ) + ) + if not table_definition_rows: + return None + return table_definition_rows[0][0] + + async def get_view_definition(self, view): + return await self.get_table_definition(view, "view") + def __repr__(self): tags = [] if self.is_mutable: @@ -451,21 +507,6 @@ class Datasette: query["name"] = query_name return query - async def get_table_definition(self, database_name, table, type_="table"): - table_definition_rows = list( - await self.execute( - database_name, - "select sql from sqlite_master where name = :n and type=:t", - {"n": table, "t": type_}, - ) - ) - if not table_definition_rows: - return None - return table_definition_rows[0][0] - - def get_view_definition(self, database_name, view): - return self.get_table_definition(database_name, view, "view") - def update_with_inherited_metadata(self, metadata): # Fills in source/license with defaults, if available metadata.update( @@ -494,18 +535,11 @@ class Datasette: # pylint: disable=no-member pm.hook.prepare_connection(conn=conn) - async def table_exists(self, database, table): - results = await self.execute( - database, - "select 1 from sqlite_master where type='table' and name=?", - params=(table,), - ) - return bool(results.rows) - async def expand_foreign_keys(self, database, table, column, values): "Returns dict mapping (column, value) -> label" labeled_fks = {} - foreign_keys = await self.foreign_keys_for_table(database, table) + db = self.databases[database] + foreign_keys = await db.foreign_keys_for_table(table) # Find the foreign_key for this column try: fk = [ @@ -515,7 +549,7 @@ class Datasette: ][0] except IndexError: return {} - label_column = await self.label_column_for_table(database, fk["other_table"]) + label_column = await db.label_column_for_table(fk["other_table"]) if not label_column: return {(fk["column"], value): str(value) for value in values} labeled_fks = {} @@ -631,35 +665,6 @@ class Datasette: .get(table, {}) ) - async def table_columns(self, db_name, table): - return await self.execute_against_connection_in_thread( - db_name, lambda conn: table_columns(conn, table) - ) - - async def foreign_keys_for_table(self, database, table): - return await self.execute_against_connection_in_thread( - database, lambda conn: get_outbound_foreign_keys(conn, table) - ) - - async def label_column_for_table(self, db_name, table): - explicit_label_column = self.table_metadata(db_name, table).get("label_column") - if explicit_label_column: - return explicit_label_column - # If a table has two columns, one of which is ID, then label_column is the other one - column_names = await self.table_columns(db_name, table) - # Is there a name or title column? - name_or_title = [c for c in column_names if c in ("name", "title")] - if name_or_title: - return name_or_title[0] - if ( - column_names - and len(column_names) == 2 - and ("id" in column_names or "pk" in column_names) - ): - return [c for c in column_names if c not in ("id", "pk")][0] - # Couldn't find a label: - return None - async def execute_against_connection_in_thread(self, db_name, fn): def in_thread(): conn = getattr(connections, db_name, None) diff --git a/datasette/views/base.py b/datasette/views/base.py index f81617a0..4294784d 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -183,7 +183,7 @@ class BaseView(RenderMixin): if "table_and_format" in kwargs: async def async_table_exists(t): - return await self.ds.table_exists(name, t) + return await db.table_exists(t) table, _format = await resolve_table_and_format( table_and_format=urllib.parse.unquote_plus( @@ -328,9 +328,10 @@ class BaseView(RenderMixin): if not _format: _format = (args.pop("as_format", None) or "").lstrip(".") if "table_and_format" in args: + db = self.ds.databases[database] async def async_table_exists(t): - return await self.ds.table_exists(database, t) + return await db.table_exists(t) table, _ext_format = await resolve_table_and_format( table_and_format=urllib.parse.unquote_plus(args["table_and_format"]), diff --git a/datasette/views/database.py b/datasette/views/database.py index c5c00bf4..d7b07762 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -36,7 +36,7 @@ class DatabaseView(BaseView): tables = [] for table in table_counts: - table_columns = await self.ds.table_columns(database, table) + table_columns = await db.table_columns(table) tables.append( { "name": table, diff --git a/datasette/views/index.py b/datasette/views/index.py index d3d82bae..276ea1cc 100644 --- a/datasette/views/index.py +++ b/datasette/views/index.py @@ -42,7 +42,7 @@ class IndexView(RenderMixin): table_counts = {} tables = {} for table in table_names: - table_columns = await self.ds.table_columns(name, table) + table_columns = await db.table_columns(table) tables[table] = { "name": table, "columns": table_columns, diff --git a/datasette/views/table.py b/datasette/views/table.py index 97723a50..23777ff8 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -41,11 +41,12 @@ LINK_WITH_VALUE = '{id}' class RowTableShared(BaseView): async def sortable_columns_for_table(self, database, table, use_rowid): + db = self.ds.databases[database] table_metadata = self.ds.table_metadata(database, table) if "sortable_columns" in table_metadata: sortable_columns = set(table_metadata["sortable_columns"]) else: - sortable_columns = set(await self.ds.table_columns(database, table)) + sortable_columns = set(await db.table_columns(table)) if use_rowid: sortable_columns.add("rowid") return sortable_columns @@ -53,10 +54,9 @@ class RowTableShared(BaseView): async def expandable_columns(self, database, table): # Returns list of (fk_dict, label_column-or-None) pairs for that table expandables = [] - for fk in await self.ds.foreign_keys_for_table(database, table): - label_column = await self.ds.label_column_for_table( - database, fk["other_table"] - ) + db = self.ds.databases[database] + for fk in await db.foreign_keys_for_table(table): + label_column = await db.label_column_for_table(fk["other_table"]) expandables.append((fk, label_column)) return expandables @@ -64,6 +64,7 @@ class RowTableShared(BaseView): self, database, table, description, rows, link_column=False, truncate_cells=0 ): "Returns columns, rows for specified table - including fancy foreign key treatment" + db = self.ds.databases[database] table_metadata = self.ds.table_metadata(database, table) sortable_columns = await self.sortable_columns_for_table(database, table, True) columns = [ @@ -74,7 +75,7 @@ class RowTableShared(BaseView): ) column_to_foreign_key_table = { fk["column"]: fk["other_table"] - for fk in await self.ds.foreign_keys_for_table(database, table) + for fk in await db.foreign_keys_for_table(table) } cell_rows = [] @@ -206,11 +207,12 @@ class TableView(RowTableShared): editable=False, canned_query=table, ) - - is_view = bool(await self.ds.get_view_definition(database, table)) - table_exists = bool(await self.ds.table_exists(database, table)) + db = self.ds.databases[database] + is_view = bool(await db.get_view_definition(table)) + table_exists = bool(await db.table_exists(table)) if not is_view and not table_exists: raise NotFound("Table not found: {}".format(table)) + pks = await self.ds.execute_against_connection_in_thread( database, lambda conn: detect_primary_keys(conn, table) ) @@ -352,9 +354,7 @@ class TableView(RowTableShared): # More complex: search against specific columns for i, (key, search_text) in enumerate(search_args.items()): search_col = key.split("_search_", 1)[1] - if search_col not in await self.ds.table_columns( - database, fts_table - ): + if search_col not in await db.table_columns(fts_table): raise DatasetteError("Cannot search by that column", status=400) where_clauses.append( @@ -739,8 +739,8 @@ class TableView(RowTableShared): "_rows_and_columns.html", ], "metadata": metadata, - "view_definition": await self.ds.get_view_definition(database, table), - "table_definition": await self.ds.get_table_definition(database, table), + "view_definition": await db.get_view_definition(table), + "table_definition": await db.get_table_definition(table), } return (