import asyncio import atexit from collections import namedtuple import inspect import os from pathlib import Path import queue import sqlite_utils import sys import tempfile import threading import uuid from .tracer import trace from .utils import ( call_with_supported_arguments, detect_fts, detect_primary_keys, detect_spatialite, get_all_foreign_keys, get_outbound_foreign_keys, md5_not_usedforsecurity, sqlite_timelimit, sqlite3, table_columns, table_column_details, ) from .utils.sql_analysis import SQLAnalysis, analyze_sql_tables from .utils.sqlite import sqlite_version from .inspect import inspect_hash connections = threading.local() AttachedDatabase = namedtuple("AttachedDatabase", ("seq", "name", "file")) class DatasetteClosedError(RuntimeError): """Raised when using a Datasette or Database instance after close().""" _SHUTDOWN = object() class Database: # For table counts stop at this many rows: count_limit = 10000 _thread_local_id_counter = 1 def __init__( self, ds, path=None, is_mutable=True, is_memory=False, memory_name=None, mode=None, is_temp_disk=False, ): self.name = None self._thread_local_id = f"x{self._thread_local_id_counter}" Database._thread_local_id_counter += 1 self.route = None self.ds = ds self.path = path self.is_mutable = is_mutable self.is_memory = is_memory self.memory_name = memory_name self.is_temp_disk = is_temp_disk if memory_name is not None: self.is_memory = True if is_temp_disk: fd, temp_path = tempfile.mkstemp(suffix=".db", prefix="datasette_temp_") os.close(fd) self.path = temp_path self.is_mutable = True self.mode = "rwc" self._wal_enabled = False atexit.register(self._cleanup_temp_file) else: self._wal_enabled = False self.cached_hash = None self.cached_size = None self._cached_table_counts = None self._write_thread = None self._write_queue = None self._closed = False self._pending_execute_futures = set() self._pending_execute_futures_lock = threading.Lock() # These are used when in non-threaded mode: self._read_connection = None self._write_connection = None # This is used to track all file connections so they can be closed self._all_file_connections = [] if not is_temp_disk: self.mode = mode def _check_not_closed(self): if self._closed: raise DatasetteClosedError( "Database {!r} has been closed".format(self.name) ) def _remove_pending_execute_future(self, future): with self._pending_execute_futures_lock: self._pending_execute_futures.discard(future) @property def cached_table_counts(self): if self._cached_table_counts is not None: return self._cached_table_counts # Maybe use self.ds.inspect_data to populate cached_table_counts if self.ds.inspect_data and self.ds.inspect_data.get(self.name): self._cached_table_counts = { key: value["count"] for key, value in self.ds.inspect_data[self.name]["tables"].items() } return self._cached_table_counts @property def color(self): if self.hash: return self.hash[:6] return md5_not_usedforsecurity(self.name)[:6] def suggest_name(self): if self.is_temp_disk: return "_temp_disk" if self.path: return Path(self.path).stem elif self.memory_name: return self.memory_name else: return "db" def connect(self, write=False): extra_kwargs = {} if write: extra_kwargs["isolation_level"] = "IMMEDIATE" if self.memory_name: uri = "file:{}?mode=memory&cache=shared".format(self.memory_name) conn = sqlite3.connect( uri, uri=True, check_same_thread=False, **extra_kwargs ) if not write: conn.execute("PRAGMA query_only=1") return conn if self.is_memory: return sqlite3.connect(":memory:", uri=True) # mode=ro or immutable=1? if self.is_mutable: qs = "?mode=ro" if self.ds.nolock: qs += "&nolock=1" else: qs = "?immutable=1" assert not (write and not self.is_mutable) if write: qs = "" if self.mode is not None: qs = f"?mode={self.mode}" conn = sqlite3.connect( f"file:{self.path}{qs}", uri=True, check_same_thread=False, **extra_kwargs ) self._all_file_connections.append(conn) if self.is_temp_disk and not self._wal_enabled: conn.execute("PRAGMA journal_mode=WAL") self._wal_enabled = True return conn def close(self): """Release all resources held by this database. Idempotent. After close() further calls to execute()/execute_fn()/ execute_write()/execute_write_fn() raise DatasetteClosedError. """ if self._closed: return with self._pending_execute_futures_lock: if self._closed: return self._closed = True pending_execute_futures = tuple(self._pending_execute_futures) # Shut down the write thread, if any, via a sentinel. The thread # drains any writes already queued before the sentinel and then # closes its own write connection and returns. write_thread = self._write_thread if write_thread is not None and self._write_queue is not None: self._write_queue.put(_SHUTDOWN) write_thread.join(timeout=10) if write_thread.is_alive(): sys.stderr.write( "Datasette: write thread for {!r} did not exit within 10s\n".format( self.name ) ) sys.stderr.flush() for future in pending_execute_futures: try: future.result() except Exception: pass # Close anything still tracked in _all_file_connections for connection in self._all_file_connections: try: connection.close() except Exception: pass self._all_file_connections = [] # Drop per-thread cached read connections we can reach try: delattr(connections, self._thread_local_id) except AttributeError: pass # Close non-threaded-mode cached connections if still open if self._read_connection is not None: try: self._read_connection.close() except Exception: pass self._read_connection = None if self._write_connection is not None: try: self._write_connection.close() except Exception: pass self._write_connection = None if self.is_temp_disk: self._cleanup_temp_file() def _cleanup_temp_file(self): if self.is_temp_disk and self.path: for suffix in ("", "-wal", "-shm"): try: os.unlink(self.path + suffix) except OSError: pass async def execute_write(self, sql, params=None, block=True, request=None): self._check_not_closed() def _inner(conn): return conn.execute(sql, params or []) with trace("sql", database=self.name, sql=sql.strip(), params=params): results = await self.execute_write_fn(_inner, block=block, request=request) return results async def execute_write_script(self, sql, block=True, request=None): self._check_not_closed() def _inner(conn): return conn.executescript(sql) with trace("sql", database=self.name, sql=sql.strip(), executescript=True): results = await self.execute_write_fn( _inner, block=block, transaction=False, request=request ) return results async def execute_write_many(self, sql, params_seq, block=True, request=None): self._check_not_closed() def _inner(conn): count = 0 def count_params(params): nonlocal count for param in params: count += 1 yield param return conn.executemany(sql, count_params(params_seq)), count with trace( "sql", database=self.name, sql=sql.strip(), executemany=True ) as kwargs: results, count = await self.execute_write_fn( _inner, block=block, request=request ) kwargs["count"] = count return results async def execute_isolated_fn(self, fn): self._check_not_closed() # Open a new connection just for the duration of this function # blocking the write queue to avoid any writes occurring during it if self.ds.executor is None: # non-threaded mode isolated_connection = self.connect(write=True) try: result = fn(isolated_connection) finally: isolated_connection.close() try: self._all_file_connections.remove(isolated_connection) except ValueError: # Was probably a memory connection pass return result else: # Threaded mode - send to write thread return await self._send_to_write_thread(fn, isolated_connection=True) async def analyze_sql(self, sql, params=None) -> SQLAnalysis: self._check_not_closed() return await self.execute_isolated_fn( lambda conn: analyze_sql_tables(conn, sql, params, database_name=self.name) ) async def execute_write_fn(self, fn, block=True, transaction=True, request=None): self._check_not_closed() pending_events = [] def track_event(event): pending_events.append(event) fn = self._wrap_fn_with_hooks(fn, request, transaction, track_event) if self.ds.executor is None: # non-threaded mode if self._write_connection is None: self._write_connection = self.connect(write=True) self.ds._prepare_connection(self._write_connection, self.name) if transaction: with self._write_connection: result = fn(self._write_connection) else: result = fn(self._write_connection) else: result = await self._send_to_write_thread( fn, block=block, transaction=transaction ) if block: for event in pending_events: await self.ds.track_event(event) else: # For non-blocking writes, spawn a background task to # dispatch events after the write thread completes task_id, reply_future = result async def _dispatch_events_after_write(): try: await reply_future except Exception: # if the write failed, don't emit success events return for event in pending_events: await self.ds.track_event(event) asyncio.ensure_future(_dispatch_events_after_write()) result = task_id return result def _wrap_fn_with_hooks(self, fn, request, transaction, track_event): from .plugins import pm # Wrap fn so it receives track_event if its signature supports it. # Historically fn was called positionally, so any single-parameter # name (conn, connection, db, ...) worked. Preserve that by only # switching to keyword dependency injection when the callback # explicitly opts in by declaring a `track_event` parameter. original_fn = fn if "track_event" in inspect.signature(original_fn).parameters: def fn_with_track_event(conn): return call_with_supported_arguments( original_fn, conn=conn, track_event=track_event ) fn = fn_with_track_event wrappers = pm.hook.write_wrapper( datasette=self.ds, database=self.name, request=request, transaction=transaction, ) wrappers = [w for w in wrappers if w is not None] if not wrappers: return fn # Build the wrapped fn by nesting context manager generators. # The first wrapper returned by pluggy is outermost. for wrapper_factory in reversed(wrappers): fn = _apply_write_wrapper(fn, wrapper_factory, track_event) return fn async def _send_to_write_thread( self, fn, block=True, isolated_connection=False, transaction=True ): if self._write_queue is None: self._write_queue = queue.Queue() if self._write_thread is None: self._write_thread = threading.Thread( target=self._execute_writes, daemon=True ) self._write_thread.name = "_execute_writes for database {}".format( self.name ) self._write_thread.start() task_id = uuid.uuid5(uuid.NAMESPACE_DNS, "datasette.io") loop = asyncio.get_running_loop() reply_future = loop.create_future() self._write_queue.put( WriteTask(fn, task_id, loop, reply_future, isolated_connection, transaction) ) if block: return await reply_future else: return task_id, reply_future def _execute_writes(self): # Infinite looping thread that protects the single write connection # to this database conn_exception = None conn = None try: conn = self.connect(write=True) self.ds._prepare_connection(conn, self.name) except Exception as e: conn_exception = e while True: task = self._write_queue.get() if task is _SHUTDOWN: if conn is not None: try: conn.close() except Exception: pass return exception = None result = None if conn_exception is not None: exception = conn_exception elif task.isolated_connection: isolated_connection = self.connect(write=True) try: result = task.fn(isolated_connection) except Exception as e: sys.stderr.write("{}\n".format(e)) sys.stderr.flush() exception = e finally: isolated_connection.close() try: self._all_file_connections.remove(isolated_connection) except ValueError: # Was probably a memory connection pass else: try: if task.transaction: with conn: result = task.fn(conn) else: result = task.fn(conn) except Exception as e: sys.stderr.write("{}\n".format(e)) sys.stderr.flush() exception = e _deliver_write_result(task, result, exception) async def execute_fn(self, fn): self._check_not_closed() if self.ds.executor is None: # non-threaded mode if self._read_connection is None: self._read_connection = self.connect() self.ds._prepare_connection(self._read_connection, self.name) return fn(self._read_connection) # threaded mode def in_thread(): conn = getattr(connections, self._thread_local_id, None) if not conn: conn = self.connect() self.ds._prepare_connection(conn, self.name) setattr(connections, self._thread_local_id, conn) return fn(conn) with self._pending_execute_futures_lock: self._check_not_closed() future = self.ds.executor.submit(in_thread) self._pending_execute_futures.add(future) future.add_done_callback(self._remove_pending_execute_future) return await asyncio.wrap_future(future) async def execute( self, sql, params=None, truncate=False, custom_time_limit=None, page_size=None, log_sql_errors=True, ): """Executes sql against db_name in a thread""" self._check_not_closed() page_size = page_size or self.ds.page_size def sql_operation_in_thread(conn): time_limit_ms = self.ds.sql_time_limit_ms if custom_time_limit and custom_time_limit < time_limit_ms: time_limit_ms = custom_time_limit with sqlite_timelimit(conn, time_limit_ms): try: cursor = conn.cursor() cursor.execute(sql, params if params is not None else {}) max_returned_rows = self.ds.max_returned_rows if max_returned_rows == page_size: max_returned_rows += 1 if max_returned_rows and truncate: rows = cursor.fetchmany(max_returned_rows + 1) truncated = len(rows) > max_returned_rows rows = rows[:max_returned_rows] else: rows = cursor.fetchall() truncated = False except (sqlite3.OperationalError, sqlite3.DatabaseError) as e: if e.args == ("interrupted",): raise QueryInterrupted(e, sql, params) if log_sql_errors: sys.stderr.write( "ERROR: conn={}, sql = {}, params = {}: {}\n".format( conn, repr(sql), params, e ) ) sys.stderr.flush() raise if truncate: return Results(rows, truncated, cursor.description) else: return Results(rows, False, cursor.description) with trace("sql", database=self.name, sql=sql.strip(), params=params): results = await self.execute_fn(sql_operation_in_thread) return results @property def hash(self): if self.cached_hash is not None: return self.cached_hash elif self.is_mutable or self.is_memory or self.is_temp_disk: return None elif self.ds.inspect_data and self.ds.inspect_data.get(self.name): self.cached_hash = self.ds.inspect_data[self.name]["hash"] return self.cached_hash else: p = Path(self.path) self.cached_hash = inspect_hash(p) return self.cached_hash @property def size(self): if self.cached_size is not None: return self.cached_size elif self.is_memory: return 0 elif self.is_mutable: return Path(self.path).stat().st_size elif self.ds.inspect_data and self.ds.inspect_data.get(self.name): self.cached_size = self.ds.inspect_data[self.name]["size"] return self.cached_size else: self.cached_size = Path(self.path).stat().st_size return self.cached_size async def table_counts(self, limit=10): if not self.is_mutable and self.cached_table_counts is not None: return self.cached_table_counts # Try to get counts for each table, $limit timeout for each count counts = {} for table in await self.table_names(): try: table_count = ( await self.execute( f"select count(*) from (select * from [{table}] limit {self.count_limit + 1})", custom_time_limit=limit, ) ).rows[0][0] counts[table] = table_count # In some cases I saw "SQL Logic Error" here in addition to # QueryInterrupted - so we catch that too: except (QueryInterrupted, sqlite3.OperationalError, sqlite3.DatabaseError): counts[table] = None if not self.is_mutable: self._cached_table_counts = counts return counts @property def mtime_ns(self): if self.is_memory: return None return Path(self.path).stat().st_mtime_ns async def attached_databases(self): # This used to be: # select seq, name, file from pragma_database_list() where seq > 0 # But SQLite prior to 3.16.0 doesn't support pragma functions results = await self.execute("PRAGMA database_list;") # {'seq': 0, 'name': 'main', 'file': ''} return [ AttachedDatabase(*row) for row in results.rows # Filter out the SQLite internal "temp" database, refs #2557 if row["seq"] > 0 and row["name"] != "temp" ] async def table_exists(self, table): results = await self.execute( "select 1 from sqlite_master where type='table' and name=?", params=(table,) ) return bool(results.rows) async def view_exists(self, table): results = await self.execute( "select 1 from sqlite_master where type='view' and name=?", params=(table,) ) return bool(results.rows) async def table_names(self): results = await self.execute( "select name from sqlite_master where type='table' order by name" ) return [r[0] for r in results.rows] async def table_columns(self, table): return await self.execute_fn(lambda conn: table_columns(conn, table)) async def table_column_details(self, table): return await self.execute_fn(lambda conn: table_column_details(conn, table)) async def primary_keys(self, table): return await self.execute_fn(lambda conn: detect_primary_keys(conn, table)) async def fts_table(self, table): return await self.execute_fn(lambda conn: detect_fts(conn, table)) async def label_column_for_table(self, table): explicit_label_column = (await self.ds.table_config(self.name, table)).get( "label_column" ) if explicit_label_column: return explicit_label_column def column_details(conn): # Returns {column_name: (type, is_unique)} db = sqlite_utils.Database(conn) columns = db[table].columns_dict indexes = db[table].indexes details = {} for name in columns: is_unique = any( index for index in indexes if index.columns == [name] and index.unique ) details[name] = (columns[name], is_unique) return details column_details = await self.execute_fn(column_details) # Is there just one unique column that's text? unique_text_columns = [ name for name, (type_, is_unique) in column_details.items() if is_unique and type_ is str ] if len(unique_text_columns) == 1: return unique_text_columns[0] column_names = list(column_details.keys()) # Is there a name or title column? name_or_title = [c for c in column_names if c.lower() in ("name", "title")] if name_or_title: return name_or_title[0] # If a table has two columns, one of which is ID, then label_column is the other one if ( column_names and len(column_names) == 2 and ("id" in column_names or "pk" in column_names) and not set(column_names) == {"id", "pk"} ): return [c for c in column_names if c not in ("id", "pk")][0] # Couldn't find a label: return None async def foreign_keys_for_table(self, table): return await self.execute_fn( lambda conn: get_outbound_foreign_keys(conn, table) ) async def hidden_table_names(self): hidden_tables = [] # Add any tables marked as hidden in config db_config = self.ds.config.get("databases", {}).get(self.name, {}) if "tables" in db_config: hidden_tables += [ t for t in db_config["tables"] if db_config["tables"][t].get("hidden") ] if sqlite_version()[1] >= 37: hidden_tables += [x[0] for x in await self.execute(""" with shadow_tables as ( select name from pragma_table_list where [type] = 'shadow' order by name ), core_tables as ( select name from sqlite_master WHERE name in ('sqlite_stat1', 'sqlite_stat2', 'sqlite_stat3', 'sqlite_stat4') OR substr(name, 1, 1) == '_' ), combined as ( select name from shadow_tables union all select name from core_tables ) select name from combined order by 1 """)] else: hidden_tables += [x[0] for x in await self.execute(""" WITH base AS ( SELECT name FROM sqlite_master WHERE name IN ('sqlite_stat1', 'sqlite_stat2', 'sqlite_stat3', 'sqlite_stat4') OR substr(name, 1, 1) == '_' ), fts_suffixes AS ( SELECT column1 AS suffix FROM (VALUES ('_data'), ('_idx'), ('_docsize'), ('_content'), ('_config')) ), fts5_names AS ( SELECT name FROM sqlite_master WHERE sql LIKE '%VIRTUAL TABLE%USING FTS%' ), fts5_shadow_tables AS ( SELECT printf('%s%s', fts5_names.name, fts_suffixes.suffix) AS name FROM fts5_names JOIN fts_suffixes ), fts3_suffixes AS ( SELECT column1 AS suffix FROM (VALUES ('_content'), ('_segdir'), ('_segments'), ('_stat'), ('_docsize')) ), fts3_names AS ( SELECT name FROM sqlite_master WHERE sql LIKE '%VIRTUAL TABLE%USING FTS3%' OR sql LIKE '%VIRTUAL TABLE%USING FTS4%' ), fts3_shadow_tables AS ( SELECT printf('%s%s', fts3_names.name, fts3_suffixes.suffix) AS name FROM fts3_names JOIN fts3_suffixes ), final AS ( SELECT name FROM base UNION ALL SELECT name FROM fts5_shadow_tables UNION ALL SELECT name FROM fts3_shadow_tables ) SELECT name FROM final ORDER BY 1 """)] # Also hide any FTS tables that have a content= argument hidden_tables += [x[0] for x in await self.execute(""" SELECT name FROM sqlite_master WHERE sql LIKE '%VIRTUAL TABLE%' AND sql LIKE '%USING FTS%' AND sql LIKE '%content=%' """)] has_spatialite = await self.execute_fn(detect_spatialite) if has_spatialite: # Also hide Spatialite internal tables hidden_tables += [ "ElementaryGeometries", "SpatialIndex", "geometry_columns", "spatial_ref_sys", "spatialite_history", "sql_statements_log", "sqlite_sequence", "views_geometry_columns", "virts_geometry_columns", "data_licenses", "KNN", "KNN2", ] + [ r[0] for r in (await self.execute(""" select name from sqlite_master where name like "idx_%" and type = "table" """)).rows ] return hidden_tables async def view_names(self): results = await self.execute("select name from sqlite_master where type='view'") return [r[0] for r in results.rows] async def get_all_foreign_keys(self): return await self.execute_fn(get_all_foreign_keys) async def get_table_definition(self, table, type_="table"): table_definition_rows = list( await self.execute( "select sql from sqlite_master where name = :n and type=:t", {"n": table, "t": type_}, ) ) if not table_definition_rows: return None bits = [table_definition_rows[0][0] + ";"] # Add on any indexes index_rows = list( await self.execute( "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null", {"n": table}, ) ) for index_row in index_rows: bits.append(index_row[0] + ";") return "\n".join(bits) async def get_view_definition(self, view): return await self.get_table_definition(view, "view") def __repr__(self): tags = [] if self.is_mutable: tags.append("mutable") if self.is_memory: tags.append("memory") if self.is_temp_disk: tags.append("temp_disk") if self.hash: tags.append(f"hash={self.hash}") if self.size is not None: tags.append(f"size={self.size}") tags_str = "" if tags: tags_str = f" ({', '.join(tags)})" return f"" def _apply_write_wrapper(fn, wrapper_factory, track_event): """Apply a single write_wrapper context manager around fn. ``wrapper_factory`` is a callable that takes ``(conn)`` and optionally ``track_event``, and returns a generator that yields exactly once. Code before the yield runs before ``fn(conn)``, code after the yield runs after. The result of ``fn(conn)`` is sent into the generator via ``.send()``, and any exception raised by ``fn(conn)`` is thrown via ``.throw()``. """ def wrapped(conn): gen = call_with_supported_arguments( wrapper_factory, conn=conn, track_event=track_event ) # Advance to the yield point (run "before" code) try: next(gen) except StopIteration: # Generator didn't yield — just run fn unchanged return fn(conn) # Execute the actual write try: result = fn(conn) except Exception: # Throw exception into generator so it can handle it try: gen.throw(*sys.exc_info()) except StopIteration: pass # Re-raise the original exception raise else: # Send the result back through the yield try: gen.send(result) except StopIteration: pass return result return wrapped class WriteTask: __slots__ = ( "fn", "task_id", "loop", "reply_future", "isolated_connection", "transaction", ) def __init__( self, fn, task_id, loop, reply_future, isolated_connection, transaction ): self.fn = fn self.task_id = task_id self.loop = loop self.reply_future = reply_future self.isolated_connection = isolated_connection self.transaction = transaction def _deliver_write_result(task, result, exception): # Called from the write thread. Delivers the result back to the # awaiting coroutine on its event loop via call_soon_threadsafe. def _set(): if task.reply_future.done(): # Awaiter was cancelled; nothing to do. return if exception is not None: task.reply_future.set_exception(exception) else: task.reply_future.set_result(result) try: task.loop.call_soon_threadsafe(_set) except RuntimeError: # Event loop has been closed; the awaiter is gone. pass class QueryInterrupted(Exception): def __init__(self, e, sql, params): self.e = e self.sql = sql self.params = params def __str__(self): return "QueryInterrupted: {}".format(self.e) class MultipleValues(Exception): pass class Results: def __init__(self, rows, truncated, description): self.rows = rows self.truncated = truncated self.description = description @property def columns(self): return [d[0] for d in self.description] def first(self): if self.rows: return self.rows[0] else: return None def single_value(self): if self.rows and 1 == len(self.rows) and 1 == len(self.rows[0]): return self.rows[0][0] else: raise MultipleValues def dicts(self): return [dict(row) for row in self.rows] def __iter__(self): return iter(self.rows) def __len__(self): return len(self.rows)