Mirror of https://github.com/simonw/datasette.git
WIP: --memory option for loading entire database into :memory:

commit b053fa4a5d (parent 306e1c6ac4)
2 changed files with 128 additions and 63 deletions
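
The core idea: a database file named in --memory is served from a named shared-cache in-memory SQLite database instead of being opened with immutable=1 directly, with the contents copied across via iterdump() the first time a worker thread needs a connection. A minimal standalone sketch of that technique, assuming only the standard library (the load_into_memory helper and the fixtures.db filename are illustrative, not part of the commit):

import sqlite3

def load_into_memory(path, name):
    # A named shared-cache in-memory database: every connection opened
    # with this same URI sees the same data, so a pool of worker threads
    # can share a single in-memory copy. The database lives only while
    # at least one connection to it remains open.
    memory_conn = sqlite3.connect(
        'file:{}?mode=memory&cache=shared'.format(name),
        uri=True,
        check_same_thread=False,
    )
    # Open the source file read-only, the same way Datasette opens
    # its database files elsewhere
    file_conn = sqlite3.connect(
        'file:{}?immutable=1'.format(path), uri=True
    )
    # iterdump() yields the entire schema and contents as SQL statements
    for statement in file_conn.iterdump():
        memory_conn.execute(statement)
    file_conn.close()
    return memory_conn

# conn = load_into_memory('fixtures.db', 'fixtures')
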
datasette/app.py (187 changes)
@@ -68,9 +68,7 @@ class BaseView(RenderMixin):
         self.ds = datasette
         self.files = datasette.files
         self.jinja_env = datasette.jinja_env
         self.executor = datasette.executor
-        self.page_size = datasette.page_size
-        self.max_returned_rows = datasette.max_returned_rows
 
     def options(self, request, *args, **kwargs):
         r = response.text('ok')
@@ -91,7 +89,7 @@ class BaseView(RenderMixin):
 
     async def pks_for_table(self, name, table):
         rows = [
-            row for row in await self.execute(
+            row for row in await self.ds.execute(
                 name,
                 'PRAGMA table_info("{}")'.format(table)
             )
@@ -135,49 +133,6 @@ class BaseView(RenderMixin):
             return name, expected, should_redirect
         return name, expected, None
 
-    async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None):
-        """Executes sql against db_name in a thread"""
-        def sql_operation_in_thread():
-            conn = getattr(connections, db_name, None)
-            if not conn:
-                info = self.ds.inspect()[db_name]
-                conn = sqlite3.connect(
-                    'file:{}?immutable=1'.format(info['file']),
-                    uri=True,
-                    check_same_thread=False,
-                )
-                self.ds.prepare_connection(conn)
-                setattr(connections, db_name, conn)
-
-            time_limit_ms = self.ds.sql_time_limit_ms
-            if custom_time_limit and custom_time_limit < self.ds.sql_time_limit_ms:
-                time_limit_ms = custom_time_limit
-
-            with sqlite_timelimit(conn, time_limit_ms):
-                try:
-                    cursor = conn.cursor()
-                    cursor.execute(sql, params or {})
-                    if self.max_returned_rows and truncate:
-                        rows = cursor.fetchmany(self.max_returned_rows + 1)
-                        truncated = len(rows) > self.max_returned_rows
-                        rows = rows[:self.max_returned_rows]
-                    else:
-                        rows = cursor.fetchall()
-                        truncated = False
-                except Exception:
-                    print('ERROR: conn={}, sql = {}, params = {}'.format(
-                        conn, repr(sql), params
-                    ))
-                    raise
-            if truncate:
-                return rows, truncated, cursor.description
-            else:
-                return rows
-
-        return await asyncio.get_event_loop().run_in_executor(
-            self.executor, sql_operation_in_thread
-        )
-
     def get_templates(self, database, table=None):
         assert NotImplemented
 
@@ -192,6 +147,7 @@ class BaseView(RenderMixin):
             as_json = kwargs.pop('as_json')
         except KeyError:
             as_json = False
+        table = kwargs.get('table', None)
         extra_template_data = {}
         start = time.time()
         status_code = 200
@@ -229,6 +185,23 @@ class BaseView(RenderMixin):
                     dict(zip(columns, row))
                     for row in rows
                 ]
+        elif '_shape' in request.args:
+            # Re-shape it
+            shape = request.raw_args['_shape']
+            if shape in ('objects', 'object'):
+                columns = data.get('columns')
+                rows = data.get('rows')
+                if rows and columns:
+                    data['rows'] = [
+                        dict(zip(columns, row))
+                        for row in rows
+                    ]
+            if shape == 'object' and table:
+                pks = await self.pks_for_table(name, table)
+                data['rows'] = {
+                    path_from_row_pks(row, pks, not pks): row
+                    for row in data['rows']
+                }
         headers = {}
         if self.ds.cors:
             headers['Access-Control-Allow-Origin'] = '*'
@@ -292,7 +265,7 @@ class BaseView(RenderMixin):
         extra_args = {}
         if params.get('_sql_time_limit_ms'):
             extra_args['custom_time_limit'] = int(params['_sql_time_limit_ms'])
-        rows, truncated, description = await self.execute(
+        rows, truncated, description = await self.ds.execute(
             name, sql, params, truncate=True, **extra_args
         )
         columns = [r[0] for r in description]
@@ -326,7 +299,6 @@ class IndexView(RenderMixin):
         self.ds = datasette
         self.files = datasette.files
         self.jinja_env = datasette.jinja_env
-        self.executor = datasette.executor
 
     async def get(self, request, as_json):
         databases = []
@@ -441,7 +413,7 @@ class RowTableShared(BaseView):
             placeholders=', '.join(['?'] * len(ids_to_lookup)),
         )
         try:
-            results = await self.execute(database, sql, list(set(ids_to_lookup)))
+            results = await self.ds.execute(database, sql, list(set(ids_to_lookup)))
         except sqlite3.OperationalError:
             # Probably hit the timelimit
             pass
@@ -504,17 +476,17 @@ class TableView(RowTableShared):
         if canned_query is not None:
             return await self.custom_sql(request, name, hash, canned_query['sql'], editable=False, canned_query=table)
         pks = await self.pks_for_table(name, table)
-        is_view = bool(list(await self.execute(name, "SELECT count(*) from sqlite_master WHERE type = 'view' and name=:n", {
+        is_view = bool(list(await self.ds.execute(name, "SELECT count(*) from sqlite_master WHERE type = 'view' and name=:n", {
             'n': table,
         }))[0][0])
         view_definition = None
         table_definition = None
         if is_view:
-            view_definition = list(await self.execute(name, 'select sql from sqlite_master where name = :n and type="view"', {
+            view_definition = list(await self.ds.execute(name, 'select sql from sqlite_master where name = :n and type="view"', {
                 'n': table,
             }))[0][0]
         else:
-            table_definition = list(await self.execute(name, 'select sql from sqlite_master where name = :n and type="table"', {
+            table_definition = list(await self.ds.execute(name, 'select sql from sqlite_master where name = :n and type="table"', {
                 'n': table,
             }))[0][0]
         use_rowid = not pks and not is_view
@@ -562,7 +534,7 @@ class TableView(RowTableShared):
         # _search support:
         fts_table = None
         fts_sql = detect_fts_sql(table)
-        fts_rows = list(await self.execute(name, fts_sql))
+        fts_rows = list(await self.ds.execute(name, fts_sql))
         if fts_rows:
             fts_table = fts_rows[0][0]
 
@@ -638,7 +610,7 @@ class TableView(RowTableShared):
         if request.raw_args.get('_sql_time_limit_ms'):
             extra_args['custom_time_limit'] = int(request.raw_args['_sql_time_limit_ms'])
 
-        rows, truncated, description = await self.execute(
+        rows, truncated, description = await self.ds.execute(
             name, sql, params, truncate=True, **extra_args
         )
 
@@ -678,7 +650,7 @@ class TableView(RowTableShared):
         # Attempt a full count, if we can do it in < X ms
         if count_sql:
             try:
-                count_rows = list(await self.execute(name, count_sql, params))
+                count_rows = list(await self.ds.execute(name, count_sql, params))
                 filtered_table_rows = count_rows[0][0]
             except sqlite3.OperationalError:
                 # Almost certainly hit the timeout
@@ -689,7 +661,7 @@ class TableView(RowTableShared):
 
         async def extra_template():
             display_columns, display_rows = await self.display_columns_and_rows(
-                name, table, description, rows, link_column=not is_view, expand_foreign_keys=True
+                name, table, description, rows, link_column=not is_view, expand_foreign_keys=not is_view
             )
             return {
                 'database_hash': hash,
@@ -755,8 +727,8 @@ class RowView(RowTableShared):
         params = {}
         for i, pk_value in enumerate(pk_values):
             params['p{}'.format(i)] = pk_value
-        # rows, truncated, description = await self.execute(name, sql, params, truncate=True)
-        rows, truncated, description = await self.execute(name, sql, params, truncate=True)
+        # rows, truncated, description = await self.ds.execute(name, sql, params, truncate=True)
+        rows, truncated, description = await self.ds.execute(name, sql, params, truncate=True)
         columns = [r[0] for r in description]
         rows = list(rows)
         if not rows:
@@ -813,7 +785,7 @@ class RowView(RowTableShared):
                 for fk in foreign_keys
             ])
             try:
-                rows = list(await self.execute(name, sql, {'id': pk_values[0]}))
+                rows = list(await self.ds.execute(name, sql, {'id': pk_values[0]}))
             except sqlite3.OperationalError:
                 # Almost certainly hit the timeout
                 return []
@@ -826,7 +798,8 @@ class RowView(RowTableShared):
         foreign_key_tables = []
         for fk in foreign_keys:
             count = foreign_table_counts.get((fk['other_table'], fk['other_column'])) or 0
-            foreign_key_tables.append({**fk, **{'count': count}})
+            if count:
+                foreign_key_tables.append({**fk, **{'count': count}})
         return foreign_key_tables
 
 
@@ -835,8 +808,8 @@ class Datasette:
             self, files, num_threads=3, cache_headers=True, page_size=100,
             max_returned_rows=1000, sql_time_limit_ms=1000, cors=False,
             inspect_data=None, metadata=None, sqlite_extensions=None,
-            template_dir=None, static_mounts=None):
-        self.files = files
+            template_dir=None, static_mounts=None, memory=None):
+        self.files = list(files)
         self.num_threads = num_threads
         self.executor = futures.ThreadPoolExecutor(
             max_workers=num_threads
@@ -852,6 +825,90 @@ class Datasette:
         self.sqlite_extensions = sqlite_extensions or []
         self.template_dir = template_dir
         self.static_mounts = static_mounts or []
+        self.memory = memory or []
+        for filepath in self.memory:
+            if filepath not in self.files:
+                self.files.append(filepath)
+
+    async def initialize_memory_connections(self):
+        print('initialize_memory_connections')
+        for name, info in self.inspect().items():
+            print(name, info['memory'])
+            if info['memory']:
+                await self.execute(name, 'select 1')
+
+    async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None):
+        """Executes sql against db_name in a thread"""
+        def sql_operation_in_thread():
+            conn = getattr(connections, db_name, None)
+            if not conn:
+                info = self.inspect()[db_name]
+                if info['file'] in self.memory:
+                    # Copy file into an in-memory database
+                    memory_conn = sqlite3.connect(
+                        'file:{}?mode=memory&cache=shared'.format(db_name),
+                        uri=True,
+                        check_same_thread=False,
+                    )
+                    self.prepare_connection(memory_conn)
+                    # Do we need to copy data across?
+                    if not memory_conn.execute("select count(*) from sqlite_master where type='table'").fetchone()[0]:
+                        conn = sqlite3.connect(
+                            'file:{}?immutable=1'.format(info['file']),
+                            uri=True,
+                            check_same_thread=False,
+                        )
+                        self.prepare_connection(conn)
+                        pages_todo = conn.execute('PRAGMA page_count').fetchone()[0]
+                        print('Here we go...')
+                        i = 0
+                        for line in conn.iterdump():
+                            memory_conn.execute(line)
+                            i += 1
+                            if i % 10000 == 0:
+                                pages_done = memory_conn.execute('PRAGMA page_count').fetchone()[0]
+                                print('Done {}/{} {:.2f}%'.format(
+                                    pages_done, pages_todo, (pages_done / pages_todo * 100)
+                                ))
+                        conn.close()
+                    conn = memory_conn
+                else:
+                    conn = sqlite3.connect(
+                        'file:{}?immutable=1'.format(info['file']),
+                        uri=True,
+                        check_same_thread=False,
+                    )
+                    self.prepare_connection(conn)
+                setattr(connections, db_name, conn)
+
+            time_limit_ms = self.sql_time_limit_ms
+            if custom_time_limit and custom_time_limit < self.sql_time_limit_ms:
+                time_limit_ms = custom_time_limit
+
+            with sqlite_timelimit(conn, time_limit_ms):
+                try:
+                    cursor = conn.cursor()
+                    cursor.execute(sql, params or {})
+                    if self.max_returned_rows and truncate:
+                        rows = cursor.fetchmany(self.max_returned_rows + 1)
+                        truncated = len(rows) > self.max_returned_rows
+                        rows = rows[:self.max_returned_rows]
+                    else:
+                        rows = cursor.fetchall()
+                        truncated = False
+                except Exception:
+                    print('ERROR: conn={}, sql = {}, params = {}'.format(
+                        conn, repr(sql), params
+                    ))
+                    raise
+            if truncate:
+                return rows, truncated, cursor.description
+            else:
+                return rows
+
+        return await asyncio.get_event_loop().run_in_executor(
+            self.executor, sql_operation_in_thread
+        )
+
     def app_css_hash(self):
         if not hasattr(self, '_app_css_hash'):
@@ -975,6 +1032,10 @@ class Datasette:
                 'views': views,
 
             }
+        # Annotate in-memory databases
+        memory_names = {Path(filename).stem for filename in self.memory}
+        for name, info in self._inspect.items():
+            info['memory'] = name in memory_names
         return self._inspect
 
     def app(self):
@@ -1020,4 +1081,6 @@ class Datasette:
             RowView.as_view(self),
             '/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_json:(\.jsono?)?$>'
         )
+        if self.memory:
+            app.add_task(self.initialize_memory_connections)
         return app
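
Two details in the diff above are worth calling out. The in-memory database is addressed purely by name via mode=memory&cache=shared, so the separate connections opened by each thread in the executor pool all attach to the same copy; and initialize_memory_connections, scheduled with app.add_task at startup, issues a throwaway select 1 per memory database so the expensive iterdump() copy happens once at boot rather than on the first incoming request. A small demonstration of the shared-cache behaviour (the demo name is arbitrary):

import sqlite3

URI = 'file:demo?mode=memory&cache=shared'

# The first connection creates the named in-memory database and,
# by staying open, keeps it alive
first = sqlite3.connect(URI, uri=True)
first.execute('CREATE TABLE t (id INTEGER)')
first.execute('INSERT INTO t VALUES (1)')
first.commit()  # shared-cache readers block on uncommitted writes

# A second, independent connection (another worker thread, say)
# reaches the same database purely by name
second = sqlite3.connect(URI, uri=True)
print(second.execute('SELECT count(*) FROM t').fetchone())  # (1,)
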
datasette/cli.py (4 changes)
@@ -233,7 +233,8 @@ def package(files, tag, metadata, extra_options, branch, template_dir, static, *
 @click.option('-m', '--metadata', type=click.File(mode='r'), help='Path to JSON file containing license/source metadata')
 @click.option('--template-dir', type=click.Path(exists=True, file_okay=False, dir_okay=True), help='Path to directory containing custom templates')
 @click.option('--static', type=StaticMount(), help='mountpoint:path-to-directory for serving static files', multiple=True)
-def serve(files, host, port, debug, reload, cors, page_size, max_returned_rows, sql_time_limit_ms, sqlite_extensions, inspect_file, metadata, template_dir, static):
+@click.option('--memory', type=click.Path(exists=True), help='database files to load into memory', multiple=True)
+def serve(files, host, port, debug, reload, cors, page_size, max_returned_rows, sql_time_limit_ms, sqlite_extensions, inspect_file, metadata, template_dir, static, memory):
     """Serve up specified SQLite database files with a web UI"""
     if reload:
         import hupper
@@ -262,6 +263,7 @@ def serve(files, host, port, debug, reload, cors, page_size, max_returned_rows,
         sqlite_extensions=sqlite_extensions,
         template_dir=template_dir,
         static_mounts=static,
+        memory=memory,
     )
     # Force initial hashing/table counting
     ds.inspect()
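
With the option wired up, usage would presumably look like datasette serve fixtures.db --memory fixtures.db (fixtures.db being an example filename), and since Datasette.__init__ appends any --memory path missing from files, listing the file once under --memory alone may be enough. The debug print() calls left in the copy loop mark this as the work in progress the commit message declares.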