WIP: --memory option for loading entire database into :memory:

Simon Willison 2018-01-09 20:45:11 -08:00
commit b053fa4a5d
No known key found for this signature in database
GPG key ID: 17E2DEA2588B7F52
2 changed files with 128 additions and 63 deletions

@@ -68,9 +68,7 @@ class BaseView(RenderMixin):
self.ds = datasette
self.files = datasette.files
self.jinja_env = datasette.jinja_env
self.executor = datasette.executor
self.page_size = datasette.page_size
self.max_returned_rows = datasette.max_returned_rows
def options(self, request, *args, **kwargs):
r = response.text('ok')
@@ -91,7 +89,7 @@ class BaseView(RenderMixin):
async def pks_for_table(self, name, table):
rows = [
row for row in await self.execute(
row for row in await self.ds.execute(
name,
'PRAGMA table_info("{}")'.format(table)
)
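
The call above relies on PRAGMA table_info, which returns one row per column as (cid, name, type, notnull, dflt_value, pk); pk is the column's 1-based position within the primary key, or 0. A minimal standalone sketch of extracting primary key names the same way (a hypothetical helper, not code from this diff):

    import sqlite3

    def pks_for_table(conn, table):
        # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
        rows = conn.execute('PRAGMA table_info("{}")'.format(table)).fetchall()
        pk_rows = sorted((row for row in rows if row[5]), key=lambda row: row[5])
        return [row[1] for row in pk_rows]

    conn = sqlite3.connect(':memory:')
    conn.execute('create table pets (a text, b text, c text, primary key (a, b))')
    print(pks_for_table(conn, 'pets'))  # ['a', 'b']
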
@@ -135,49 +133,6 @@ class BaseView(RenderMixin):
return name, expected, should_redirect
return name, expected, None
async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None):
"""Executes sql against db_name in a thread"""
def sql_operation_in_thread():
conn = getattr(connections, db_name, None)
if not conn:
info = self.ds.inspect()[db_name]
conn = sqlite3.connect(
'file:{}?immutable=1'.format(info['file']),
uri=True,
check_same_thread=False,
)
self.ds.prepare_connection(conn)
setattr(connections, db_name, conn)
time_limit_ms = self.ds.sql_time_limit_ms
if custom_time_limit and custom_time_limit < self.ds.sql_time_limit_ms:
time_limit_ms = custom_time_limit
with sqlite_timelimit(conn, time_limit_ms):
try:
cursor = conn.cursor()
cursor.execute(sql, params or {})
if self.max_returned_rows and truncate:
rows = cursor.fetchmany(self.max_returned_rows + 1)
truncated = len(rows) > self.max_returned_rows
rows = rows[:self.max_returned_rows]
else:
rows = cursor.fetchall()
truncated = False
except Exception:
print('ERROR: conn={}, sql = {}, params = {}'.format(
conn, repr(sql), params
))
raise
if truncate:
return rows, truncated, cursor.description
else:
return rows
return await asyncio.get_event_loop().run_in_executor(
self.executor, sql_operation_in_thread
)
def get_templates(self, database, table=None):
assert NotImplemented
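
Both the removed method above and its replacement on Datasette (further down) wrap queries in sqlite_timelimit, which comes from datasette's utils and is not shown in this diff. A minimal sketch of how such a context manager can be built on SQLite's progress handler; this is an assumption about the approach, not the exact library code:

    import time
    from contextlib import contextmanager

    @contextmanager
    def sqlite_timelimit(conn, ms):
        deadline = time.time() + ms / 1000
        # SQLite calls the handler every N virtual-machine instructions; a
        # non-zero return value aborts the running query, which surfaces in
        # Python as sqlite3.OperationalError: interrupted
        def handler():
            return 1 if time.time() >= deadline else 0
        conn.set_progress_handler(handler, 1000)
        try:
            yield
        finally:
            conn.set_progress_handler(None, 1000)

This is why the call sites in this diff catch sqlite3.OperationalError and treat it as "probably hit the timelimit".
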
@@ -192,6 +147,7 @@ class BaseView(RenderMixin):
as_json = kwargs.pop('as_json')
except KeyError:
as_json = False
table = kwargs.get('table', None)
extra_template_data = {}
start = time.time()
status_code = 200
@@ -229,6 +185,23 @@ class BaseView(RenderMixin):
dict(zip(columns, row))
for row in rows
]
elif '_shape' in request.args:
# Re-shape it
shape = request.raw_args['_shape']
if shape in ('objects', 'object'):
columns = data.get('columns')
rows = data.get('rows')
if rows and columns:
data['rows'] = [
dict(zip(columns, row))
for row in rows
]
if shape == 'object' and table:
pks = await self.pks_for_table(name, table)
data['rows'] = {
path_from_row_pks(row, pks, not pks): row
for row in data['rows']
}
headers = {}
if self.ds.cors:
headers['Access-Control-Allow-Origin'] = '*'
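
The new _shape branch above converts the default columns-plus-rows JSON into a list of objects, and with _shape=object into a dict keyed by primary key path. A standalone sketch of the same transformation, using a simplified stand-in for path_from_row_pks (the real helper in datasette's utils also URL-encodes the key values):

    def reshape(columns, rows, shape, pks=()):
        if shape not in ('objects', 'object'):
            return rows
        objects = [dict(zip(columns, row)) for row in rows]
        if shape == 'objects':
            return objects
        # Simplified path_from_row_pks: join primary key values with commas
        return {
            ','.join(str(obj[pk]) for pk in pks): obj
            for obj in objects
        }

    columns = ['id', 'name']
    rows = [[1, 'Cleo'], [2, 'Pancakes']]
    print(reshape(columns, rows, 'object', pks=('id',)))
    # {'1': {'id': 1, 'name': 'Cleo'}, '2': {'id': 2, 'name': 'Pancakes'}}
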
@@ -292,7 +265,7 @@ class BaseView(RenderMixin):
extra_args = {}
if params.get('_sql_time_limit_ms'):
extra_args['custom_time_limit'] = int(params['_sql_time_limit_ms'])
rows, truncated, description = await self.execute(
rows, truncated, description = await self.ds.execute(
name, sql, params, truncate=True, **extra_args
)
columns = [r[0] for r in description]
@@ -326,7 +299,6 @@ class IndexView(RenderMixin):
self.ds = datasette
self.files = datasette.files
self.jinja_env = datasette.jinja_env
self.executor = datasette.executor
async def get(self, request, as_json):
databases = []
@@ -441,7 +413,7 @@ class RowTableShared(BaseView):
placeholders=', '.join(['?'] * len(ids_to_lookup)),
)
try:
results = await self.execute(database, sql, list(set(ids_to_lookup)))
results = await self.ds.execute(database, sql, list(set(ids_to_lookup)))
except sqlite3.OperationalError:
# Probably hit the timelimit
pass
@@ -504,17 +476,17 @@ class TableView(RowTableShared):
if canned_query is not None:
return await self.custom_sql(request, name, hash, canned_query['sql'], editable=False, canned_query=table)
pks = await self.pks_for_table(name, table)
is_view = bool(list(await self.execute(name, "SELECT count(*) from sqlite_master WHERE type = 'view' and name=:n", {
is_view = bool(list(await self.ds.execute(name, "SELECT count(*) from sqlite_master WHERE type = 'view' and name=:n", {
'n': table,
}))[0][0])
view_definition = None
table_definition = None
if is_view:
view_definition = list(await self.execute(name, 'select sql from sqlite_master where name = :n and type="view"', {
view_definition = list(await self.ds.execute(name, 'select sql from sqlite_master where name = :n and type="view"', {
'n': table,
}))[0][0]
else:
table_definition = list(await self.execute(name, 'select sql from sqlite_master where name = :n and type="table"', {
table_definition = list(await self.ds.execute(name, 'select sql from sqlite_master where name = :n and type="table"', {
'n': table,
}))[0][0]
use_rowid = not pks and not is_view
@@ -562,7 +534,7 @@ class TableView(RowTableShared):
# _search support:
fts_table = None
fts_sql = detect_fts_sql(table)
fts_rows = list(await self.execute(name, fts_sql))
fts_rows = list(await self.ds.execute(name, fts_sql))
if fts_rows:
fts_table = fts_rows[0][0]
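
detect_fts_sql is another utils helper not shown in this diff; it returns SQL that looks for a full-text search table associated with the given table. A hedged sketch of the general idea, querying sqlite_master for FTS virtual tables that declare this table as their content source (the real query handles more naming patterns than this):

    def detect_fts_sql(table):
        # Virtual tables have rootpage = 0 in sqlite_master
        return '''
            select name from sqlite_master
            where rootpage = 0
            and sql like '%VIRTUAL TABLE%USING FTS%content="{table}"%'
        '''.format(table=table)
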
@@ -638,7 +610,7 @@ class TableView(RowTableShared):
if request.raw_args.get('_sql_time_limit_ms'):
extra_args['custom_time_limit'] = int(request.raw_args['_sql_time_limit_ms'])
rows, truncated, description = await self.execute(
rows, truncated, description = await self.ds.execute(
name, sql, params, truncate=True, **extra_args
)
@@ -678,7 +650,7 @@ class TableView(RowTableShared):
# Attempt a full count, if we can do it in < X ms
if count_sql:
try:
count_rows = list(await self.execute(name, count_sql, params))
count_rows = list(await self.ds.execute(name, count_sql, params))
filtered_table_rows = count_rows[0][0]
except sqlite3.OperationalError:
# Almost certainly hit the timeout
@@ -689,7 +661,7 @@ class TableView(RowTableShared):
async def extra_template():
display_columns, display_rows = await self.display_columns_and_rows(
name, table, description, rows, link_column=not is_view, expand_foreign_keys=True
name, table, description, rows, link_column=not is_view, expand_foreign_keys=not is_view
)
return {
'database_hash': hash,
@@ -755,8 +727,8 @@ class RowView(RowTableShared):
params = {}
for i, pk_value in enumerate(pk_values):
params['p{}'.format(i)] = pk_value
# rows, truncated, description = await self.execute(name, sql, params, truncate=True)
rows, truncated, description = await self.execute(name, sql, params, truncate=True)
# rows, truncated, description = await self.ds.execute(name, sql, params, truncate=True)
rows, truncated, description = await self.ds.execute(name, sql, params, truncate=True)
columns = [r[0] for r in description]
rows = list(rows)
if not rows:
@@ -813,7 +785,7 @@ class RowView(RowTableShared):
for fk in foreign_keys
])
try:
rows = list(await self.execute(name, sql, {'id': pk_values[0]}))
rows = list(await self.ds.execute(name, sql, {'id': pk_values[0]}))
except sqlite3.OperationalError:
# Almost certainly hit the timeout
return []
@@ -826,7 +798,8 @@ class RowView(RowTableShared):
foreign_key_tables = []
for fk in foreign_keys:
count = foreign_table_counts.get((fk['other_table'], fk['other_column'])) or 0
foreign_key_tables.append({**fk, **{'count': count}})
if count:
foreign_key_tables.append({**fk, **{'count': count}})
return foreign_key_tables
@@ -835,8 +808,8 @@ class Datasette:
self, files, num_threads=3, cache_headers=True, page_size=100,
max_returned_rows=1000, sql_time_limit_ms=1000, cors=False,
inspect_data=None, metadata=None, sqlite_extensions=None,
template_dir=None, static_mounts=None):
self.files = files
template_dir=None, static_mounts=None, memory=None):
self.files = list(files)
self.num_threads = num_threads
self.executor = futures.ThreadPoolExecutor(
max_workers=num_threads
@@ -852,6 +825,90 @@ class Datasette:
self.sqlite_extensions = sqlite_extensions or []
self.template_dir = template_dir
self.static_mounts = static_mounts or []
self.memory = memory or []
for filepath in self.memory:
if filepath not in self.files:
self.files.append(filepath)
async def initialize_memory_connections(self):
print('initialize_memory_connections')
for name, info in self.inspect().items():
print(name, info['memory'])
if info['memory']:
await self.execute(name, 'select 1')
async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None):
"""Executes sql against db_name in a thread"""
def sql_operation_in_thread():
conn = getattr(connections, db_name, None)
if not conn:
info = self.inspect()[db_name]
if info['file'] in self.memory:
# Copy file into an in-memory database
memory_conn = sqlite3.connect(
'file:{}?mode=memory&cache=shared'.format(db_name),
uri=True,
check_same_thread=False,
)
self.prepare_connection(memory_conn)
# Do we need to copy data across?
if not memory_conn.execute("select count(*) from sqlite_master where type='table'").fetchone()[0]:
conn = sqlite3.connect(
'file:{}?immutable=1'.format(info['file']),
uri=True,
check_same_thread=False,
)
self.prepare_connection(conn)
pages_todo = conn.execute('PRAGMA page_count').fetchone()[0]
print('Here we go...')
i = 0
for line in conn.iterdump():
memory_conn.execute(line)
i += 1
if i % 10000 == 0:
pages_done = memory_conn.execute('PRAGMA page_count').fetchone()[0]
print('Done {}/{} {:.2f}%'.format(
pages_done, pages_todo, (pages_done / pages_todo * 100)
))
conn.close()
conn = memory_conn
else:
conn = sqlite3.connect(
'file:{}?immutable=1'.format(info['file']),
uri=True,
check_same_thread=False,
)
self.prepare_connection(conn)
setattr(connections, db_name, conn)
time_limit_ms = self.sql_time_limit_ms
if custom_time_limit and custom_time_limit < self.sql_time_limit_ms:
time_limit_ms = custom_time_limit
with sqlite_timelimit(conn, time_limit_ms):
try:
cursor = conn.cursor()
cursor.execute(sql, params or {})
if self.max_returned_rows and truncate:
rows = cursor.fetchmany(self.max_returned_rows + 1)
truncated = len(rows) > self.max_returned_rows
rows = rows[:self.max_returned_rows]
else:
rows = cursor.fetchall()
truncated = False
except Exception:
print('ERROR: conn={}, sql = {}, params = {}'.format(
conn, repr(sql), params
))
raise
if truncate:
return rows, truncated, cursor.description
else:
return rows
return await asyncio.get_event_loop().run_in_executor(
self.executor, sql_operation_in_thread
)
def app_css_hash(self):
if not hasattr(self, '_app_css_hash'):
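
Two details carry the new feature in the block above. First, 'file:<name>?mode=memory&cache=shared' creates a named in-memory database: every connection in the same process that opens the same URI sees the same data, so the disk-to-memory copy happens once (hence the "select count(*) from sqlite_master" check) and each worker thread's connection reuses it. Second, 'immutable=1' tells SQLite the source file will never change, letting it skip locking. The copy itself replays conn.iterdump(), which yields the whole database as SQL text statements. A condensed sketch of just the copy step, under those assumptions:

    import sqlite3

    def load_into_memory(path, name):
        # Named shared-cache in-memory database, visible to any other
        # connection in this process that opens the same URI
        mem = sqlite3.connect(
            'file:{}?mode=memory&cache=shared'.format(name),
            uri=True, check_same_thread=False,
        )
        disk = sqlite3.connect('file:{}?immutable=1'.format(path), uri=True)
        for statement in disk.iterdump():  # BEGIN, CREATE/INSERT ..., COMMIT
            mem.execute(statement)
        disk.close()
        return mem

One caveat worth noting for a WIP: a shared-cache in-memory database is destroyed as soon as its last open connection closes, so something must keep a connection alive.
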
@@ -975,6 +1032,10 @@ class Datasette:
'views': views,
}
# Annotate in-memory databases
memory_names = {Path(filename).stem for filename in self.memory}
for name, info in self._inspect.items():
info['memory'] = name in memory_names
return self._inspect
def app(self):
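
Path(filename).stem strips the directory and extension, matching the names inspect() uses for databases. For example (illustrative path):

    from pathlib import Path
    Path('data/fixtures.db').stem  # 'fixtures'
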
@@ -1020,4 +1081,6 @@ class Datasette:
RowView.as_view(self),
'/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_json:(\.jsono?)?$>'
)
if self.memory:
app.add_task(self.initialize_memory_connections)
return app
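
app.add_task is Sanic's hook for scheduling a coroutine once the server's event loop starts, so the in-memory copies are built at startup rather than on the first request (initialize_memory_connections runs 'select 1' against each memory database to trigger the copy). A minimal illustration of the pattern, separate from the datasette code:

    from sanic import Sanic

    app = Sanic(__name__)

    async def warm_up():
        # one-time startup work goes here, e.g. priming caches
        print('warmed up')

    app.add_task(warm_up)  # invoked once the event loop is running
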

@@ -233,7 +233,8 @@ def package(files, tag, metadata, extra_options, branch, template_dir, static, *
@click.option('-m', '--metadata', type=click.File(mode='r'), help='Path to JSON file containing license/source metadata')
@click.option('--template-dir', type=click.Path(exists=True, file_okay=False, dir_okay=True), help='Path to directory containing custom templates')
@click.option('--static', type=StaticMount(), help='mountpoint:path-to-directory for serving static files', multiple=True)
def serve(files, host, port, debug, reload, cors, page_size, max_returned_rows, sql_time_limit_ms, sqlite_extensions, inspect_file, metadata, template_dir, static):
@click.option('--memory', type=click.Path(exists=True), help='database files to load into memory', multiple=True)
def serve(files, host, port, debug, reload, cors, page_size, max_returned_rows, sql_time_limit_ms, sqlite_extensions, inspect_file, metadata, template_dir, static, memory):
"""Serve up specified SQLite database files with a web UI"""
if reload:
import hupper
@@ -262,6 +263,7 @@ def serve(files, host, port, debug, reload, cors, page_size, max_returned_rows,
sqlite_extensions=sqlite_extensions,
template_dir=template_dir,
static_mounts=static,
memory=memory,
)
# Force initial hashing/table counting
ds.inspect()
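
A hypothetical invocation of the new option (fixtures.db is a placeholder path):

    datasette serve --memory fixtures.db

Because the constructor change above appends each --memory path to self.files, the file does not need to be repeated as a positional argument; it is inspected and served like any other database, just from the in-memory copy.
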