Mirror of https://github.com/simonw/datasette.git

Compare commits: main...in-memory- (1 commit)
Commit: b053fa4a5d
2 changed files with 128 additions and 63 deletions
datasette/app.py (187 changed lines)
@@ -68,9 +68,7 @@ class BaseView(RenderMixin):
         self.ds = datasette
         self.files = datasette.files
         self.jinja_env = datasette.jinja_env
-        self.executor = datasette.executor
         self.page_size = datasette.page_size
-        self.max_returned_rows = datasette.max_returned_rows

     def options(self, request, *args, **kwargs):
         r = response.text('ok')
@@ -91,7 +89,7 @@ class BaseView(RenderMixin):

     async def pks_for_table(self, name, table):
         rows = [
-            row for row in await self.execute(
+            row for row in await self.ds.execute(
                 name,
                 'PRAGMA table_info("{}")'.format(table)
             )
@@ -135,49 +133,6 @@ class BaseView(RenderMixin):
             return name, expected, should_redirect
         return name, expected, None

-    async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None):
-        """Executes sql against db_name in a thread"""
-        def sql_operation_in_thread():
-            conn = getattr(connections, db_name, None)
-            if not conn:
-                info = self.ds.inspect()[db_name]
-                conn = sqlite3.connect(
-                    'file:{}?immutable=1'.format(info['file']),
-                    uri=True,
-                    check_same_thread=False,
-                )
-                self.ds.prepare_connection(conn)
-                setattr(connections, db_name, conn)
-
-            time_limit_ms = self.ds.sql_time_limit_ms
-            if custom_time_limit and custom_time_limit < self.ds.sql_time_limit_ms:
-                time_limit_ms = custom_time_limit
-
-            with sqlite_timelimit(conn, time_limit_ms):
-                try:
-                    cursor = conn.cursor()
-                    cursor.execute(sql, params or {})
-                    if self.max_returned_rows and truncate:
-                        rows = cursor.fetchmany(self.max_returned_rows + 1)
-                        truncated = len(rows) > self.max_returned_rows
-                        rows = rows[:self.max_returned_rows]
-                    else:
-                        rows = cursor.fetchall()
-                        truncated = False
-                except Exception:
-                    print('ERROR: conn={}, sql = {}, params = {}'.format(
-                        conn, repr(sql), params
-                    ))
-                    raise
-            if truncate:
-                return rows, truncated, cursor.description
-            else:
-                return rows
-
-        return await asyncio.get_event_loop().run_in_executor(
-            self.executor, sql_operation_in_thread
-        )
-
     def get_templates(self, database, table=None):
         assert NotImplemented

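This hunk deletes `BaseView.execute()`; the method reappears, extended, on the `Datasette` class further down, and every call site in this diff switches from `self.execute(...)` to `self.ds.execute(...)`. For reference, a minimal standalone sketch of the underlying pattern being centralised, with illustrative names rather than Datasette's API: run the blocking sqlite3 work on a thread pool so the async event loop is never blocked.

```python
import asyncio
import sqlite3
from concurrent import futures

executor = futures.ThreadPoolExecutor(max_workers=3)

async def run_query(path, sql):
    # run the blocking sqlite3 call on the pool so the event loop stays free
    def in_thread():
        conn = sqlite3.connect(path, check_same_thread=False)
        try:
            return conn.execute(sql).fetchall()
        finally:
            conn.close()
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, in_thread)

print(asyncio.run(run_query(':memory:', 'select 1')))  # -> [(1,)]
```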
@@ -192,6 +147,7 @@ class BaseView(RenderMixin):
             as_json = kwargs.pop('as_json')
         except KeyError:
             as_json = False
+        table = kwargs.get('table', None)
         extra_template_data = {}
         start = time.time()
         status_code = 200
@@ -229,6 +185,23 @@ class BaseView(RenderMixin):
                     dict(zip(columns, row))
                     for row in rows
                 ]
+            elif '_shape' in request.args:
+                # Re-shape it
+                shape = request.raw_args['_shape']
+                if shape in ('objects', 'object'):
+                    columns = data.get('columns')
+                    rows = data.get('rows')
+                    if rows and columns:
+                        data['rows'] = [
+                            dict(zip(columns, row))
+                            for row in rows
+                        ]
+                    if shape == 'object' and table:
+                        pks = await self.pks_for_table(name, table)
+                        data['rows'] = {
+                            path_from_row_pks(row, pks, not pks): row
+                            for row in data['rows']
+                        }
             headers = {}
             if self.ds.cors:
                 headers['Access-Control-Allow-Origin'] = '*'
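The new `_shape` branch reshapes the default columns-plus-rows JSON. A self-contained sketch of the two shapes, using made-up data (`str(row['id'])` below stands in for what `path_from_row_pks(row, pks, use_rowid)` computes):

```python
columns = ['id', 'name']
rows = [(1, 'Paris'), (2, 'London')]

# ?_shape=objects -> a list of dicts
objects = [dict(zip(columns, row)) for row in rows]
assert objects == [{'id': 1, 'name': 'Paris'}, {'id': 2, 'name': 'London'}]

# ?_shape=object -> a dict keyed by primary-key path
by_pk = {str(row['id']): row for row in objects}
assert by_pk['2'] == {'id': 2, 'name': 'London'}
```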
@@ -292,7 +265,7 @@ class BaseView(RenderMixin):
         extra_args = {}
         if params.get('_sql_time_limit_ms'):
             extra_args['custom_time_limit'] = int(params['_sql_time_limit_ms'])
-        rows, truncated, description = await self.execute(
+        rows, truncated, description = await self.ds.execute(
             name, sql, params, truncate=True, **extra_args
         )
         columns = [r[0] for r in description]
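`custom_time_limit` is enforced through the `sqlite_timelimit()` helper imported from Datasette's utils. A plausible minimal version of that helper, offered as an assumption for illustration rather than the actual implementation, uses SQLite's progress handler to abort long-running queries:

```python
import contextlib
import time

@contextlib.contextmanager
def sqlite_timelimit(conn, ms):
    # hypothetical sketch: abort the running statement once the deadline passes
    deadline = time.time() + (ms / 1000)
    # the handler fires every N virtual-machine ops; a truthy return aborts
    conn.set_progress_handler(lambda: time.time() > deadline, 10000)
    try:
        yield
    finally:
        conn.set_progress_handler(None, 10000)

# An aborted query surfaces as sqlite3.OperationalError ('interrupted'),
# which matches the `except sqlite3.OperationalError` handlers in this diff.
```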
@@ -326,7 +299,6 @@ class IndexView(RenderMixin):
         self.ds = datasette
         self.files = datasette.files
         self.jinja_env = datasette.jinja_env
-        self.executor = datasette.executor

     async def get(self, request, as_json):
         databases = []
@@ -441,7 +413,7 @@ class RowTableShared(BaseView):
             placeholders=', '.join(['?'] * len(ids_to_lookup)),
         )
         try:
-            results = await self.execute(database, sql, list(set(ids_to_lookup)))
+            results = await self.ds.execute(database, sql, list(set(ids_to_lookup)))
         except sqlite3.OperationalError:
             # Probably hit the timelimit
             pass
@@ -504,17 +476,17 @@ class TableView(RowTableShared):
         if canned_query is not None:
             return await self.custom_sql(request, name, hash, canned_query['sql'], editable=False, canned_query=table)
         pks = await self.pks_for_table(name, table)
-        is_view = bool(list(await self.execute(name, "SELECT count(*) from sqlite_master WHERE type = 'view' and name=:n", {
+        is_view = bool(list(await self.ds.execute(name, "SELECT count(*) from sqlite_master WHERE type = 'view' and name=:n", {
             'n': table,
         }))[0][0])
         view_definition = None
         table_definition = None
         if is_view:
-            view_definition = list(await self.execute(name, 'select sql from sqlite_master where name = :n and type="view"', {
+            view_definition = list(await self.ds.execute(name, 'select sql from sqlite_master where name = :n and type="view"', {
                 'n': table,
             }))[0][0]
         else:
-            table_definition = list(await self.execute(name, 'select sql from sqlite_master where name = :n and type="table"', {
+            table_definition = list(await self.ds.execute(name, 'select sql from sqlite_master where name = :n and type="table"', {
                 'n': table,
             }))[0][0]
         use_rowid = not pks and not is_view
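These `sqlite_master` lookups are standard SQLite introspection. A quick illustration against a throwaway in-memory database (table and view names here are invented):

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('create table cities (id integer primary key, name text)')
conn.execute('create view city_names as select name from cities')

# same check as is_view above: does a view with this name exist?
is_view = conn.execute(
    "select count(*) from sqlite_master where type = 'view' and name = :n",
    {'n': 'city_names'},
).fetchone()[0]
assert is_view == 1

# same lookup as view_definition: fetch the original CREATE statement
definition = conn.execute(
    'select sql from sqlite_master where name = :n and type = "view"',
    {'n': 'city_names'},
).fetchone()[0]
print(definition)  # -> create view city_names as select name from cities
```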
@@ -562,7 +534,7 @@ class TableView(RowTableShared):
         # _search support:
         fts_table = None
         fts_sql = detect_fts_sql(table)
-        fts_rows = list(await self.execute(name, fts_sql))
+        fts_rows = list(await self.ds.execute(name, fts_sql))
         if fts_rows:
             fts_table = fts_rows[0][0]

@@ -638,7 +610,7 @@ class TableView(RowTableShared):
         if request.raw_args.get('_sql_time_limit_ms'):
             extra_args['custom_time_limit'] = int(request.raw_args['_sql_time_limit_ms'])

-        rows, truncated, description = await self.execute(
+        rows, truncated, description = await self.ds.execute(
             name, sql, params, truncate=True, **extra_args
         )

@@ -678,7 +650,7 @@ class TableView(RowTableShared):
         # Attempt a full count, if we can do it in < X ms
         if count_sql:
             try:
-                count_rows = list(await self.execute(name, count_sql, params))
+                count_rows = list(await self.ds.execute(name, count_sql, params))
                 filtered_table_rows = count_rows[0][0]
             except sqlite3.OperationalError:
                 # Almost certainly hit the timeout
@@ -689,7 +661,7 @@ class TableView(RowTableShared):

         async def extra_template():
             display_columns, display_rows = await self.display_columns_and_rows(
-                name, table, description, rows, link_column=not is_view, expand_foreign_keys=True
+                name, table, description, rows, link_column=not is_view, expand_foreign_keys=not is_view
             )
             return {
                 'database_hash': hash,
@@ -755,8 +727,8 @@ class RowView(RowTableShared):
         params = {}
         for i, pk_value in enumerate(pk_values):
             params['p{}'.format(i)] = pk_value
-        # rows, truncated, description = await self.execute(name, sql, params, truncate=True)
-        rows, truncated, description = await self.execute(name, sql, params, truncate=True)
+        # rows, truncated, description = await self.ds.execute(name, sql, params, truncate=True)
+        rows, truncated, description = await self.ds.execute(name, sql, params, truncate=True)
         columns = [r[0] for r in description]
         rows = list(rows)
         if not rows:
@@ -813,7 +785,7 @@ class RowView(RowTableShared):
             for fk in foreign_keys
         ])
         try:
-            rows = list(await self.execute(name, sql, {'id': pk_values[0]}))
+            rows = list(await self.ds.execute(name, sql, {'id': pk_values[0]}))
         except sqlite3.OperationalError:
             # Almost certainly hit the timeout
             return []
@@ -826,7 +798,8 @@ class RowView(RowTableShared):
         foreign_key_tables = []
         for fk in foreign_keys:
             count = foreign_table_counts.get((fk['other_table'], fk['other_column'])) or 0
-            foreign_key_tables.append({**fk, **{'count': count}})
+            if count:
+                foreign_key_tables.append({**fk, **{'count': count}})
         return foreign_key_tables


@@ -835,8 +808,8 @@ class Datasette:
             self, files, num_threads=3, cache_headers=True, page_size=100,
             max_returned_rows=1000, sql_time_limit_ms=1000, cors=False,
             inspect_data=None, metadata=None, sqlite_extensions=None,
-            template_dir=None, static_mounts=None):
-        self.files = files
+            template_dir=None, static_mounts=None, memory=None):
+        self.files = list(files)
         self.num_threads = num_threads
         self.executor = futures.ThreadPoolExecutor(
             max_workers=num_threads
@@ -852,6 +825,90 @@ class Datasette:
         self.sqlite_extensions = sqlite_extensions or []
         self.template_dir = template_dir
         self.static_mounts = static_mounts or []
+        self.memory = memory or []
+        for filepath in self.memory:
+            if filepath not in self.files:
+                self.files.append(filepath)
+
+    async def initialize_memory_connections(self):
+        print('initialize_memory_connections')
+        for name, info in self.inspect().items():
+            print(name, info['memory'])
+            if info['memory']:
+                await self.execute(name, 'select 1')
+
+    async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None):
+        """Executes sql against db_name in a thread"""
+        def sql_operation_in_thread():
+            conn = getattr(connections, db_name, None)
+            if not conn:
+                info = self.inspect()[db_name]
+                if info['file'] in self.memory:
+                    # Copy file into an in-memory database
+                    memory_conn = sqlite3.connect(
+                        'file:{}?mode=memory&cache=shared'.format(db_name),
+                        uri=True,
+                        check_same_thread=False,
+                    )
+                    self.prepare_connection(memory_conn)
+                    # Do we need to copy data across?
+                    if not memory_conn.execute("select count(*) from sqlite_master where type='table'").fetchone()[0]:
+                        conn = sqlite3.connect(
+                            'file:{}?immutable=1'.format(info['file']),
+                            uri=True,
+                            check_same_thread=False,
+                        )
+                        self.prepare_connection(conn)
+                        pages_todo = conn.execute('PRAGMA page_count').fetchone()[0]
+                        print('Here we go...')
+                        i = 0
+                        for line in conn.iterdump():
+                            memory_conn.execute(line)
+                            i += 1
+                            if i % 10000 == 0:
+                                pages_done = memory_conn.execute('PRAGMA page_count').fetchone()[0]
+                                print('Done {}/{} {:.2f}%'.format(
+                                    pages_done, pages_todo, (pages_done / pages_todo * 100)
+                                ))
+                        conn.close()
+                    conn = memory_conn
+                else:
+                    conn = sqlite3.connect(
+                        'file:{}?immutable=1'.format(info['file']),
+                        uri=True,
+                        check_same_thread=False,
+                    )
+                self.prepare_connection(conn)
+                setattr(connections, db_name, conn)
+
+            time_limit_ms = self.sql_time_limit_ms
+            if custom_time_limit and custom_time_limit < self.sql_time_limit_ms:
+                time_limit_ms = custom_time_limit
+
+            with sqlite_timelimit(conn, time_limit_ms):
+                try:
+                    cursor = conn.cursor()
+                    cursor.execute(sql, params or {})
+                    if self.max_returned_rows and truncate:
+                        rows = cursor.fetchmany(self.max_returned_rows + 1)
+                        truncated = len(rows) > self.max_returned_rows
+                        rows = rows[:self.max_returned_rows]
+                    else:
+                        rows = cursor.fetchall()
+                        truncated = False
+                except Exception:
+                    print('ERROR: conn={}, sql = {}, params = {}'.format(
+                        conn, repr(sql), params
+                    ))
+                    raise
+            if truncate:
+                return rows, truncated, cursor.description
+            else:
+                return rows
+
+        return await asyncio.get_event_loop().run_in_executor(
+            self.executor, sql_operation_in_thread
+        )
+
     def app_css_hash(self):
         if not hasattr(self, '_app_css_hash'):
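The interesting trick in the new `execute()`: a named in-memory database opened with `'file:<name>?mode=memory&cache=shared'` is shared between connections in the same process for as long as at least one connection keeps it open, and `conn.iterdump()` yields SQL statements that recreate the source database. A self-contained sketch of just that mechanism (names are illustrative):

```python
import sqlite3

source = sqlite3.connect(':memory:')
source.execute('create table t (id integer primary key)')
source.execute('insert into t values (1), (2)')

# keep_alive holds the shared in-memory database open
keep_alive = sqlite3.connect('file:demo?mode=memory&cache=shared', uri=True)
for line in source.iterdump():  # SQL statements that recreate the database
    keep_alive.execute(line)

# a second connection to the same URI sees the copied data
other = sqlite3.connect('file:demo?mode=memory&cache=shared', uri=True)
assert other.execute('select count(*) from t').fetchone()[0] == 2
```

The progress printing against `PRAGMA page_count` gives a rough percentage because the in-memory copy should end up with approximately as many pages as the source file.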
@@ -975,6 +1032,10 @@ class Datasette:
                 'views': views,

             }
+        # Annotate in-memory databases
+        memory_names = {Path(filename).stem for filename in self.memory}
+        for name, info in self._inspect.items():
+            info['memory'] = name in memory_names
         return self._inspect

     def app(self):
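`inspect()` keys databases by filename stem, so the annotation loop compares each database name against `Path(filename).stem` of the `--memory` paths. For example (paths invented):

```python
from pathlib import Path

memory = ['fixtures.db', '/tmp/other.db']  # hypothetical --memory values
memory_names = {Path(filename).stem for filename in memory}
assert memory_names == {'fixtures', 'other'}
```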
@@ -1020,4 +1081,6 @@ class Datasette:
             RowView.as_view(self),
             '/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_json:(\.jsono?)?$>'
         )
+        if self.memory:
+            app.add_task(self.initialize_memory_connections)
         return app
datasette/cli.py (4 changed lines)

@@ -233,7 +233,8 @@ def package(files, tag, metadata, extra_options, branch, template_dir, static, *
 @click.option('-m', '--metadata', type=click.File(mode='r'), help='Path to JSON file containing license/source metadata')
 @click.option('--template-dir', type=click.Path(exists=True, file_okay=False, dir_okay=True), help='Path to directory containing custom templates')
 @click.option('--static', type=StaticMount(), help='mountpoint:path-to-directory for serving static files', multiple=True)
-def serve(files, host, port, debug, reload, cors, page_size, max_returned_rows, sql_time_limit_ms, sqlite_extensions, inspect_file, metadata, template_dir, static):
+@click.option('--memory', type=click.Path(exists=True), help='database files to load into memory', multiple=True)
+def serve(files, host, port, debug, reload, cors, page_size, max_returned_rows, sql_time_limit_ms, sqlite_extensions, inspect_file, metadata, template_dir, static, memory):
     """Serve up specified SQLite database files with a web UI"""
     if reload:
         import hupper
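Because `--memory` is declared with `multiple=True`, repeating the flag accumulates a tuple, which is why the constructor normalises with `memory or []`. A toy standalone command showing the behaviour (not the real cli module):

```python
import click

@click.command()
@click.argument('files', nargs=-1)
@click.option('--memory', multiple=True, help='repeatable; collected into a tuple')
def toy(files, memory):
    click.echo('files={!r} memory={!r}'.format(files, memory))

if __name__ == '__main__':
    toy()
# $ python toy.py a.db --memory a.db --memory b.db
# files=('a.db',) memory=('a.db', 'b.db')
```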
@@ -262,6 +263,7 @@ def serve(files, host, port, debug, reload, cors, page_size, max_returned_rows,
         sqlite_extensions=sqlite_extensions,
         template_dir=template_dir,
         static_mounts=static,
+        memory=memory,
     )
     # Force initial hashing/table counting
     ds.inspect()
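Taken together, the new keyword argument could be exercised directly; a hypothetical snippet against this branch, assuming a fixtures.db file exists:

```python
from datasette.app import Datasette

ds = Datasette(['fixtures.db'], memory=['fixtures.db'])
ds.inspect()    # force initial hashing/table counting; annotates info['memory']
app = ds.app()  # the Sanic app; add_task() schedules initialize_memory_connections
# app.run(host='127.0.0.1', port=8001)
```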