From 6569287d9001fb7e37f99b4bc5eddb78d047fa6f Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 26 May 2019 22:07:27 -0700 Subject: [PATCH] Refactored ConnectedDatabase to datasette/database.py Closes #487 --- datasette/app.py | 226 +----------------------------------- datasette/database.py | 224 +++++++++++++++++++++++++++++++++++ datasette/views/database.py | 1 - 3 files changed, 228 insertions(+), 223 deletions(-) create mode 100644 datasette/database.py diff --git a/datasette/app.py b/datasette/app.py index e39c4097..93527850 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1,5 +1,4 @@ import asyncio -import click import collections import hashlib import json @@ -12,6 +11,7 @@ import urllib.parse from concurrent import futures from pathlib import Path +import click from markupsafe import Markup from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader from sanic import Sanic, response @@ -23,23 +23,19 @@ from .views.index import IndexView from .views.special import JsonDataView from .views.table import RowView, TableView from .renderer import json_renderer +from .database import Database from .utils import ( InterruptedError, Results, - detect_spatialite, escape_css_string, escape_sqlite, - get_all_foreign_keys, - get_outbound_foreign_keys, get_plugins, module_from_path, sqlite3, sqlite_timelimit, - table_columns, to_css_class, ) -from .inspect import inspect_hash, inspect_views, inspect_tables from .tracer import capture_traces, trace from .plugins import pm, DEFAULT_PLUGINS from .version import __version__ @@ -134,219 +130,6 @@ async def favicon(request): return response.text("") -class ConnectedDatabase: - def __init__(self, ds, path=None, is_mutable=False, is_memory=False): - self.ds = ds - self.path = path - self.is_mutable = is_mutable - self.is_memory = is_memory - self.hash = None - self.cached_size = None - self.cached_table_counts = None - if not self.is_mutable: - p = Path(path) - self.hash = inspect_hash(p) - self.cached_size = p.stat().st_size - # Maybe use self.ds.inspect_data to populate cached_table_counts - if self.ds.inspect_data and self.ds.inspect_data.get(self.name): - self.cached_table_counts = { - key: value["count"] - for key, value in self.ds.inspect_data[self.name]["tables"].items() - } - - @property - def size(self): - if self.is_memory: - return 0 - if self.cached_size is not None: - return self.cached_size - else: - return Path(self.path).stat().st_size - - async def table_counts(self, limit=10): - if not self.is_mutable and self.cached_table_counts is not None: - return self.cached_table_counts - # Try to get counts for each table, $limit timeout for each count - counts = {} - for table in await self.table_names(): - try: - table_count = ( - await self.ds.execute( - self.name, - "select count(*) from [{}]".format(table), - custom_time_limit=limit, - ) - ).rows[0][0] - counts[table] = table_count - # In some cases I saw "SQL Logic Error" here in addition to - # InterruptedError - so we catch that too: - except (InterruptedError, sqlite3.OperationalError): - counts[table] = None - if not self.is_mutable: - self.cached_table_counts = counts - return counts - - @property - def mtime_ns(self): - return Path(self.path).stat().st_mtime_ns - - @property - def name(self): - if self.is_memory: - return ":memory:" - else: - return Path(self.path).stem - - async def table_exists(self, table): - results = await self.ds.execute( - self.name, - "select 1 from sqlite_master where type='table' and name=?", - params=(table,), - ) - return bool(results.rows) - - async def table_names(self): - results = await self.ds.execute( - self.name, "select name from sqlite_master where type='table'" - ) - return [r[0] for r in results.rows] - - async def table_columns(self, table): - return await self.ds.execute_against_connection_in_thread( - self.name, lambda conn: table_columns(conn, table) - ) - - async def label_column_for_table(self, table): - explicit_label_column = self.ds.table_metadata(self.name, table).get( - "label_column" - ) - if explicit_label_column: - return explicit_label_column - # If a table has two columns, one of which is ID, then label_column is the other one - column_names = await self.ds.execute_against_connection_in_thread( - self.name, lambda conn: table_columns(conn, table) - ) - # Is there a name or title column? - name_or_title = [c for c in column_names if c in ("name", "title")] - if name_or_title: - return name_or_title[0] - if ( - column_names - and len(column_names) == 2 - and ("id" in column_names or "pk" in column_names) - ): - return [c for c in column_names if c not in ("id", "pk")][0] - # Couldn't find a label: - return None - - async def foreign_keys_for_table(self, table): - return await self.ds.execute_against_connection_in_thread( - self.name, lambda conn: get_outbound_foreign_keys(conn, table) - ) - - async def hidden_table_names(self): - # Mark tables 'hidden' if they relate to FTS virtual tables - hidden_tables = [ - r[0] - for r in ( - await self.ds.execute( - self.name, - """ - select name from sqlite_master - where rootpage = 0 - and sql like '%VIRTUAL TABLE%USING FTS%' - """, - ) - ).rows - ] - has_spatialite = await self.ds.execute_against_connection_in_thread( - self.name, detect_spatialite - ) - if has_spatialite: - # Also hide Spatialite internal tables - hidden_tables += [ - "ElementaryGeometries", - "SpatialIndex", - "geometry_columns", - "spatial_ref_sys", - "spatialite_history", - "sql_statements_log", - "sqlite_sequence", - "views_geometry_columns", - "virts_geometry_columns", - ] + [ - r[0] - for r in ( - await self.ds.execute( - self.name, - """ - select name from sqlite_master - where name like "idx_%" - and type = "table" - """, - ) - ).rows - ] - # Add any from metadata.json - db_metadata = self.ds.metadata(database=self.name) - if "tables" in db_metadata: - hidden_tables += [ - t - for t in db_metadata["tables"] - if db_metadata["tables"][t].get("hidden") - ] - # Also mark as hidden any tables which start with the name of a hidden table - # e.g. "searchable_fts" implies "searchable_fts_content" should be hidden - for table_name in await self.table_names(): - for hidden_table in hidden_tables[:]: - if table_name.startswith(hidden_table): - hidden_tables.append(table_name) - continue - - return hidden_tables - - async def view_names(self): - results = await self.ds.execute( - self.name, "select name from sqlite_master where type='view'" - ) - return [r[0] for r in results.rows] - - async def get_all_foreign_keys(self): - return await self.ds.execute_against_connection_in_thread( - self.name, get_all_foreign_keys - ) - - async def get_table_definition(self, table, type_="table"): - table_definition_rows = list( - await self.ds.execute( - self.name, - "select sql from sqlite_master where name = :n and type=:t", - {"n": table, "t": type_}, - ) - ) - if not table_definition_rows: - return None - return table_definition_rows[0][0] - - async def get_view_definition(self, view): - return await self.get_table_definition(view, "view") - - def __repr__(self): - tags = [] - if self.is_mutable: - tags.append("mutable") - if self.is_memory: - tags.append("memory") - if self.hash: - tags.append("hash={}".format(self.hash)) - if self.size is not None: - tags.append("size={}".format(self.size)) - tags_str = "" - if tags: - tags_str = " ({})".format(", ".join(tags)) - return "".format(self.name, tags_str) - - class Datasette: def __init__( self, @@ -380,9 +163,7 @@ class Datasette: path = None is_memory = True is_mutable = path not in self.immutables - db = ConnectedDatabase( - self, path, is_mutable=is_mutable, is_memory=is_memory - ) + db = Database(self, path, is_mutable=is_mutable, is_memory=is_memory) if db.name in self.databases: raise Exception("Multiple files with same stem: {}".format(db.name)) self.databases[db.name] = db @@ -751,6 +532,7 @@ class Datasette: # Hooks hook_renderers = [] + # pylint: disable=no-member for hook in pm.hook.register_output_renderer(datasette=self): if type(hook) == list: hook_renderers += hook diff --git a/datasette/database.py b/datasette/database.py new file mode 100644 index 00000000..b25765bc --- /dev/null +++ b/datasette/database.py @@ -0,0 +1,224 @@ +from pathlib import Path + +from .utils import ( + InterruptedError, + detect_spatialite, + get_all_foreign_keys, + get_outbound_foreign_keys, + sqlite3, + table_columns, +) +from .inspect import inspect_hash + + +class Database: + def __init__(self, ds, path=None, is_mutable=False, is_memory=False): + self.ds = ds + self.path = path + self.is_mutable = is_mutable + self.is_memory = is_memory + self.hash = None + self.cached_size = None + self.cached_table_counts = None + if not self.is_mutable: + p = Path(path) + self.hash = inspect_hash(p) + self.cached_size = p.stat().st_size + # Maybe use self.ds.inspect_data to populate cached_table_counts + if self.ds.inspect_data and self.ds.inspect_data.get(self.name): + self.cached_table_counts = { + key: value["count"] + for key, value in self.ds.inspect_data[self.name]["tables"].items() + } + + @property + def size(self): + if self.is_memory: + return 0 + if self.cached_size is not None: + return self.cached_size + else: + return Path(self.path).stat().st_size + + async def table_counts(self, limit=10): + if not self.is_mutable and self.cached_table_counts is not None: + return self.cached_table_counts + # Try to get counts for each table, $limit timeout for each count + counts = {} + for table in await self.table_names(): + try: + table_count = ( + await self.ds.execute( + self.name, + "select count(*) from [{}]".format(table), + custom_time_limit=limit, + ) + ).rows[0][0] + counts[table] = table_count + # In some cases I saw "SQL Logic Error" here in addition to + # InterruptedError - so we catch that too: + except (InterruptedError, sqlite3.OperationalError): + counts[table] = None + if not self.is_mutable: + self.cached_table_counts = counts + return counts + + @property + def mtime_ns(self): + return Path(self.path).stat().st_mtime_ns + + @property + def name(self): + if self.is_memory: + return ":memory:" + else: + return Path(self.path).stem + + async def table_exists(self, table): + results = await self.ds.execute( + self.name, + "select 1 from sqlite_master where type='table' and name=?", + params=(table,), + ) + return bool(results.rows) + + async def table_names(self): + results = await self.ds.execute( + self.name, "select name from sqlite_master where type='table'" + ) + return [r[0] for r in results.rows] + + async def table_columns(self, table): + return await self.ds.execute_against_connection_in_thread( + self.name, lambda conn: table_columns(conn, table) + ) + + async def label_column_for_table(self, table): + explicit_label_column = self.ds.table_metadata(self.name, table).get( + "label_column" + ) + if explicit_label_column: + return explicit_label_column + # If a table has two columns, one of which is ID, then label_column is the other one + column_names = await self.ds.execute_against_connection_in_thread( + self.name, lambda conn: table_columns(conn, table) + ) + # Is there a name or title column? + name_or_title = [c for c in column_names if c in ("name", "title")] + if name_or_title: + return name_or_title[0] + if ( + column_names + and len(column_names) == 2 + and ("id" in column_names or "pk" in column_names) + ): + return [c for c in column_names if c not in ("id", "pk")][0] + # Couldn't find a label: + return None + + async def foreign_keys_for_table(self, table): + return await self.ds.execute_against_connection_in_thread( + self.name, lambda conn: get_outbound_foreign_keys(conn, table) + ) + + async def hidden_table_names(self): + # Mark tables 'hidden' if they relate to FTS virtual tables + hidden_tables = [ + r[0] + for r in ( + await self.ds.execute( + self.name, + """ + select name from sqlite_master + where rootpage = 0 + and sql like '%VIRTUAL TABLE%USING FTS%' + """, + ) + ).rows + ] + has_spatialite = await self.ds.execute_against_connection_in_thread( + self.name, detect_spatialite + ) + if has_spatialite: + # Also hide Spatialite internal tables + hidden_tables += [ + "ElementaryGeometries", + "SpatialIndex", + "geometry_columns", + "spatial_ref_sys", + "spatialite_history", + "sql_statements_log", + "sqlite_sequence", + "views_geometry_columns", + "virts_geometry_columns", + ] + [ + r[0] + for r in ( + await self.ds.execute( + self.name, + """ + select name from sqlite_master + where name like "idx_%" + and type = "table" + """, + ) + ).rows + ] + # Add any from metadata.json + db_metadata = self.ds.metadata(database=self.name) + if "tables" in db_metadata: + hidden_tables += [ + t + for t in db_metadata["tables"] + if db_metadata["tables"][t].get("hidden") + ] + # Also mark as hidden any tables which start with the name of a hidden table + # e.g. "searchable_fts" implies "searchable_fts_content" should be hidden + for table_name in await self.table_names(): + for hidden_table in hidden_tables[:]: + if table_name.startswith(hidden_table): + hidden_tables.append(table_name) + continue + + return hidden_tables + + async def view_names(self): + results = await self.ds.execute( + self.name, "select name from sqlite_master where type='view'" + ) + return [r[0] for r in results.rows] + + async def get_all_foreign_keys(self): + return await self.ds.execute_against_connection_in_thread( + self.name, get_all_foreign_keys + ) + + async def get_table_definition(self, table, type_="table"): + table_definition_rows = list( + await self.ds.execute( + self.name, + "select sql from sqlite_master where name = :n and type=:t", + {"n": table, "t": type_}, + ) + ) + if not table_definition_rows: + return None + return table_definition_rows[0][0] + + async def get_view_definition(self, view): + return await self.get_table_definition(view, "view") + + def __repr__(self): + tags = [] + if self.is_mutable: + tags.append("mutable") + if self.is_memory: + tags.append("memory") + if self.hash: + tags.append("hash={}".format(self.hash)) + if self.size is not None: + tags.append("size={}".format(self.size)) + tags_str = "" + if tags: + tags_str = " ({})".format(", ".join(tags)) + return "".format(self.name, tags_str) diff --git a/datasette/views/database.py b/datasette/views/database.py index d7b07762..7172b346 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -5,7 +5,6 @@ from sanic import response from datasette.utils import ( detect_fts, detect_primary_keys, - get_all_foreign_keys, to_css_class, validate_sql_select, )