diff --git a/datasette/app.py b/datasette/app.py
index e7a5ce48..79692ff3 100644
--- a/datasette/app.py
+++ b/datasette/app.py
@@ -68,9 +68,7 @@ class BaseView(RenderMixin):
         self.ds = datasette
         self.files = datasette.files
         self.jinja_env = datasette.jinja_env
-        self.executor = datasette.executor
         self.page_size = datasette.page_size
-        self.max_returned_rows = datasette.max_returned_rows

     def options(self, request, *args, **kwargs):
         r = response.text('ok')
@@ -91,7 +89,7 @@ class BaseView(RenderMixin):

     async def pks_for_table(self, name, table):
         rows = [
-            row for row in await self.execute(
+            row for row in await self.ds.execute(
                 name,
                 'PRAGMA table_info("{}")'.format(table)
             )
@@ -135,49 +133,6 @@ class BaseView(RenderMixin):
             return name, expected, should_redirect
         return name, expected, None

-    async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None):
-        """Executes sql against db_name in a thread"""
-        def sql_operation_in_thread():
-            conn = getattr(connections, db_name, None)
-            if not conn:
-                info = self.ds.inspect()[db_name]
-                conn = sqlite3.connect(
-                    'file:{}?immutable=1'.format(info['file']),
-                    uri=True,
-                    check_same_thread=False,
-                )
-                self.ds.prepare_connection(conn)
-                setattr(connections, db_name, conn)
-
-            time_limit_ms = self.ds.sql_time_limit_ms
-            if custom_time_limit and custom_time_limit < self.ds.sql_time_limit_ms:
-                time_limit_ms = custom_time_limit
-
-            with sqlite_timelimit(conn, time_limit_ms):
-                try:
-                    cursor = conn.cursor()
-                    cursor.execute(sql, params or {})
-                    if self.max_returned_rows and truncate:
-                        rows = cursor.fetchmany(self.max_returned_rows + 1)
-                        truncated = len(rows) > self.max_returned_rows
-                        rows = rows[:self.max_returned_rows]
-                    else:
-                        rows = cursor.fetchall()
-                        truncated = False
-                except Exception:
-                    print('ERROR: conn={}, sql = {}, params = {}'.format(
-                        conn, repr(sql), params
-                    ))
-                    raise
-                if truncate:
-                    return rows, truncated, cursor.description
-                else:
-                    return rows
-
-        return await asyncio.get_event_loop().run_in_executor(
-            self.executor, sql_operation_in_thread
-        )
-
     def get_templates(self, database, table=None):
         assert NotImplemented

@@ -192,6 +147,7 @@ class BaseView(RenderMixin):
             as_json = kwargs.pop('as_json')
         except KeyError:
             as_json = False
+        table = kwargs.get('table', None)
         extra_template_data = {}
         start = time.time()
         status_code = 200
@@ -229,6 +185,23 @@ class BaseView(RenderMixin):
                     dict(zip(columns, row))
                     for row in rows
                 ]
+            elif '_shape' in request.args:
+                # Re-shape it
+                shape = request.raw_args['_shape']
+                if shape in ('objects', 'object'):
+                    columns = data.get('columns')
+                    rows = data.get('rows')
+                    if rows and columns:
+                        data['rows'] = [
+                            dict(zip(columns, row))
+                            for row in rows
+                        ]
+                    if shape == 'object' and table:
+                        pks = await self.pks_for_table(name, table)
+                        data['rows'] = {
+                            path_from_row_pks(row, pks, not pks): row
+                            for row in data['rows']
+                        }
             headers = {}
             if self.ds.cors:
                 headers['Access-Control-Allow-Origin'] = '*'
@@ -292,7 +265,7 @@ class BaseView(RenderMixin):
         extra_args = {}
         if params.get('_sql_time_limit_ms'):
             extra_args['custom_time_limit'] = int(params['_sql_time_limit_ms'])
-        rows, truncated, description = await self.execute(
+        rows, truncated, description = await self.ds.execute(
             name, sql, params, truncate=True, **extra_args
         )
         columns = [r[0] for r in description]
@@ -326,7 +299,6 @@ class IndexView(RenderMixin):
         self.ds = datasette
         self.files = datasette.files
         self.jinja_env = datasette.jinja_env
-        self.executor = datasette.executor

     async def get(self, request, as_json):
         databases = []
@@ -441,7 +413,7 @@ class RowTableShared(BaseView):
                     placeholders=', '.join(['?'] * len(ids_to_lookup)),
                 )
                 try:
-                    results = await self.execute(database, sql, list(set(ids_to_lookup)))
+                    results = await self.ds.execute(database, sql, list(set(ids_to_lookup)))
                 except sqlite3.OperationalError:
                     # Probably hit the timelimit
                     pass
@@ -504,17 +476,17 @@ class TableView(RowTableShared):
         if canned_query is not None:
             return await self.custom_sql(request, name, hash, canned_query['sql'], editable=False, canned_query=table)
         pks = await self.pks_for_table(name, table)
-        is_view = bool(list(await self.execute(name, "SELECT count(*) from sqlite_master WHERE type = 'view' and name=:n", {
+        is_view = bool(list(await self.ds.execute(name, "SELECT count(*) from sqlite_master WHERE type = 'view' and name=:n", {
             'n': table,
         }))[0][0])
         view_definition = None
         table_definition = None
         if is_view:
-            view_definition = list(await self.execute(name, 'select sql from sqlite_master where name = :n and type="view"', {
+            view_definition = list(await self.ds.execute(name, 'select sql from sqlite_master where name = :n and type="view"', {
                 'n': table,
             }))[0][0]
         else:
-            table_definition = list(await self.execute(name, 'select sql from sqlite_master where name = :n and type="table"', {
+            table_definition = list(await self.ds.execute(name, 'select sql from sqlite_master where name = :n and type="table"', {
                 'n': table,
             }))[0][0]
         use_rowid = not pks and not is_view
@@ -562,7 +534,7 @@ class TableView(RowTableShared):
         # _search support:
         fts_table = None
         fts_sql = detect_fts_sql(table)
-        fts_rows = list(await self.execute(name, fts_sql))
+        fts_rows = list(await self.ds.execute(name, fts_sql))
         if fts_rows:
             fts_table = fts_rows[0][0]

@@ -638,7 +610,7 @@ class TableView(RowTableShared):
        if request.raw_args.get('_sql_time_limit_ms'):
            extra_args['custom_time_limit'] = int(request.raw_args['_sql_time_limit_ms'])

-        rows, truncated, description = await self.execute(
+        rows, truncated, description = await self.ds.execute(
            name, sql, params, truncate=True, **extra_args
        )

@@ -678,7 +650,7 @@ class TableView(RowTableShared):
        # Attempt a full count, if we can do it in < X ms
        if count_sql:
            try:
-                count_rows = list(await self.execute(name, count_sql, params))
+                count_rows = list(await self.ds.execute(name, count_sql, params))
                filtered_table_rows = count_rows[0][0]
            except sqlite3.OperationalError:
                # Almost certainly hit the timeout
@@ -689,7 +661,7 @@ class TableView(RowTableShared):

        async def extra_template():
            display_columns, display_rows = await self.display_columns_and_rows(
-                name, table, description, rows, link_column=not is_view, expand_foreign_keys=True
+                name, table, description, rows, link_column=not is_view, expand_foreign_keys=not is_view
            )
            return {
                'database_hash': hash,
@@ -755,8 +727,8 @@ class RowView(RowTableShared):
        params = {}
        for i, pk_value in enumerate(pk_values):
            params['p{}'.format(i)] = pk_value
-        # rows, truncated, description = await self.execute(name, sql, params, truncate=True)
-        rows, truncated, description = await self.execute(name, sql, params, truncate=True)
+        # rows, truncated, description = await self.ds.execute(name, sql, params, truncate=True)
+        rows, truncated, description = await self.ds.execute(name, sql, params, truncate=True)
        columns = [r[0] for r in description]
        rows = list(rows)
        if not rows:
@@ -813,7 +785,7 @@ class RowView(RowTableShared):
            for fk in foreign_keys
        ])
        try:
-            rows = list(await self.execute(name, sql, {'id': pk_values[0]}))
+            rows = list(await self.ds.execute(name, sql, {'id': pk_values[0]}))
        except sqlite3.OperationalError:
            # Almost certainly hit the timeout
            return []
@@ -826,7 +798,8 @@ class RowView(RowTableShared):

        foreign_key_tables = []
        for fk in foreign_keys:
            count = foreign_table_counts.get((fk['other_table'], fk['other_column'])) or 0
-            foreign_key_tables.append({**fk, **{'count': count}})
+            if count:
+                foreign_key_tables.append({**fk, **{'count': count}})
        return foreign_key_tables

@@ -835,8 +808,8 @@ class Datasette:
            self, files, num_threads=3, cache_headers=True, page_size=100,
            max_returned_rows=1000, sql_time_limit_ms=1000, cors=False,
            inspect_data=None, metadata=None, sqlite_extensions=None,
-            template_dir=None, static_mounts=None):
-        self.files = files
+            template_dir=None, static_mounts=None, memory=None):
+        self.files = list(files)
        self.num_threads = num_threads
        self.executor = futures.ThreadPoolExecutor(
            max_workers=num_threads
@@ -852,6 +825,90 @@ class Datasette:
        self.sqlite_extensions = sqlite_extensions or []
        self.template_dir = template_dir
        self.static_mounts = static_mounts or []
+        self.memory = memory or []
+        for filepath in self.memory:
+            if filepath not in self.files:
+                self.files.append(filepath)
+
+    async def initialize_memory_connections(self):
+        print('initialize_memory_connections')
+        for name, info in self.inspect().items():
+            print(name, info['memory'])
+            if info['memory']:
+                await self.execute(name, 'select 1')
+
+    async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None):
+        """Executes sql against db_name in a thread"""
+        def sql_operation_in_thread():
+            conn = getattr(connections, db_name, None)
+            if not conn:
+                info = self.inspect()[db_name]
+                if info['file'] in self.memory:
+                    # Copy file into an in-memory database
+                    memory_conn = sqlite3.connect(
+                        'file:{}?mode=memory&cache=shared'.format(db_name),
+                        uri=True,
+                        check_same_thread=False,
+                    )
+                    self.prepare_connection(memory_conn)
+                    # Do we need to copy data across?
+                    if not memory_conn.execute("select count(*) from sqlite_master where type='table'").fetchone()[0]:
+                        conn = sqlite3.connect(
+                            'file:{}?immutable=1'.format(info['file']),
+                            uri=True,
+                            check_same_thread=False,
+                        )
+                        self.prepare_connection(conn)
+                        pages_todo = conn.execute('PRAGMA page_count').fetchone()[0]
+                        print('Here we go...')
+                        i = 0
+                        for line in conn.iterdump():
+                            memory_conn.execute(line)
+                            i += 1
+                            if i % 10000 == 0:
+                                pages_done = memory_conn.execute('PRAGMA page_count').fetchone()[0]
+                                print('Done {}/{} {:.2f}%'.format(
+                                    pages_done, pages_todo, (pages_done / pages_todo * 100)
+                                ))
+                        conn.close()
+                    conn = memory_conn
+                else:
+                    conn = sqlite3.connect(
+                        'file:{}?immutable=1'.format(info['file']),
+                        uri=True,
+                        check_same_thread=False,
+                    )
+                    self.prepare_connection(conn)
+                setattr(connections, db_name, conn)
+
+            time_limit_ms = self.sql_time_limit_ms
+            if custom_time_limit and custom_time_limit < self.sql_time_limit_ms:
+                time_limit_ms = custom_time_limit
+
+            with sqlite_timelimit(conn, time_limit_ms):
+                try:
+                    cursor = conn.cursor()
+                    cursor.execute(sql, params or {})
+                    if self.max_returned_rows and truncate:
+                        rows = cursor.fetchmany(self.max_returned_rows + 1)
+                        truncated = len(rows) > self.max_returned_rows
+                        rows = rows[:self.max_returned_rows]
+                    else:
+                        rows = cursor.fetchall()
+                        truncated = False
+                except Exception:
+                    print('ERROR: conn={}, sql = {}, params = {}'.format(
+                        conn, repr(sql), params
+                    ))
+                    raise
+                if truncate:
+                    return rows, truncated, cursor.description
+                else:
+                    return rows
+
+        return await asyncio.get_event_loop().run_in_executor(
+            self.executor, sql_operation_in_thread
+        )

    def app_css_hash(self):
        if not hasattr(self, '_app_css_hash'):
@@ -975,6 +1032,10 @@ class Datasette:
                'views': views,
            }

+        # Annotate in-memory databases
+        memory_names = {Path(filename).stem for filename in self.memory}
+        for name, info in self._inspect.items():
+            info['memory'] = name in memory_names
        return self._inspect

    def app(self):
@@ -1020,4 +1081,6 @@ class Datasette:
            RowView.as_view(self),
            '/<db_name:[^/]+?>/<table:[^/]+?>/<pk_path:[^/]+?><as_json:(\.jsono?)?$>'
        )
+        if self.memory:
+            app.add_task(self.initialize_memory_connections)
        return app
diff --git a/datasette/cli.py b/datasette/cli.py
index 5744d00a..055d6f92 100644
--- a/datasette/cli.py
+++ b/datasette/cli.py
@@ -233,7 +233,8 @@ def package(files, tag, metadata, extra_options, branch, template_dir, static, *
 @click.option('-m', '--metadata', type=click.File(mode='r'), help='Path to JSON file containing license/source metadata')
 @click.option('--template-dir', type=click.Path(exists=True, file_okay=False, dir_okay=True), help='Path to directory containing custom templates')
 @click.option('--static', type=StaticMount(), help='mountpoint:path-to-directory for serving static files', multiple=True)
-def serve(files, host, port, debug, reload, cors, page_size, max_returned_rows, sql_time_limit_ms, sqlite_extensions, inspect_file, metadata, template_dir, static):
+@click.option('--memory', type=click.Path(exists=True), help='database files to load into memory', multiple=True)
+def serve(files, host, port, debug, reload, cors, page_size, max_returned_rows, sql_time_limit_ms, sqlite_extensions, inspect_file, metadata, template_dir, static, memory):
     """Serve up specified SQLite database files with a web UI"""
     if reload:
         import hupper
@@ -262,6 +263,7 @@ def serve(files, host, port, debug, reload, cors, page_size, max_returned_rows,
         sqlite_extensions=sqlite_extensions,
         template_dir=template_dir,
         static_mounts=static,
+        memory=memory,
     )
     # Force initial hashing/table counting
     ds.inspect()
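
Three notes on the behaviour this patch adds, each with a worked example. First, the new _shape parameter: _shape=objects turns each row from an array of values into a column-to-value dictionary, and _shape=object additionally keys the whole result by each row's primary-key path (via path_from_row_pks). Illustrative output for a hypothetical dogs table with a single id primary key; the table, column and value names here are invented for the example:

    GET /mydb/dogs.json?_shape=objects
        "rows": [{"id": 1, "name": "Cleo"}, {"id": 2, "name": "Pancakes"}]

    GET /mydb/dogs.json?_shape=object
        "rows": {"1": {"id": 1, "name": "Cleo"},
                 "2": {"id": 2, "name": "Pancakes"}}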
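
Second, the in-memory branch of the relocated Datasette.execute() combines two SQLite URI features: mode=memory&cache=shared creates a named in-memory database that every connection in the same process can open by name (one connection per executor thread), and immutable=1 opens the on-disk file read-only so its iterdump() output can be replayed to populate the copy. A minimal standalone sketch of that copy step, using example.db as a stand-in filename:

    import sqlite3

    # Named shared-cache in-memory database: all connections opened with
    # this URI in the same process see the same tables.
    memory_conn = sqlite3.connect(
        'file:example?mode=memory&cache=shared',
        uri=True,
        check_same_thread=False,
    )

    # Populate it only if it is still empty; another thread may already
    # have copied the data across.
    if not memory_conn.execute(
        "select count(*) from sqlite_master where type='table'"
    ).fetchone()[0]:
        disk_conn = sqlite3.connect('file:example.db?immutable=1', uri=True)
        # iterdump() yields the SQL statements (schema plus INSERTs) needed
        # to rebuild the database; replay them against the in-memory copy.
        for statement in disk_conn.iterdump():
            memory_conn.execute(statement)
        disk_conn.close()

Note that memory_conn has to stay open: SQLite discards a shared-cache in-memory database as soon as its last connection closes, which is why execute() stashes the connection on the thread-local connections object instead of closing it.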
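
Finally, the new --memory option is repeatable (like --static) and names database files to load into memory at startup; __init__ appends any path given only to --memory onto self.files, and app() registers initialize_memory_connections() as a background task so the copy begins as the server starts. That means 'datasette serve --memory fixtures.db' serves fixtures.db as usual while answering queries from the in-memory copy. The equivalent from Python, with fixtures.db as a stand-in filename:

    from datasette.app import Datasette

    # memory= takes the same file paths as files=; a path listed only in
    # memory= is appended to files by __init__.
    ds = Datasette(['fixtures.db'], memory=['fixtures.db'])
    ds.inspect()  # force initial hashing/table counting, as serve() does
    ds.app().run(host='127.0.0.1', port=8001)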