diff --git a/datasite/app.py b/datasite/app.py index ac74c534..44ae9c51 100644 --- a/datasite/app.py +++ b/datasite/app.py @@ -8,6 +8,9 @@ import sqlite3 from contextlib import contextmanager from pathlib import Path from functools import wraps +from concurrent import futures +import asyncio +import threading import urllib.parse import json import base64 @@ -22,19 +25,7 @@ DB_GLOBS = ('*.db', '*.sqlite', '*.sqlite3') HASH_BLOCK_SIZE = 1024 * 1024 SQL_TIME_LIMIT_MS = 1000 -conns = {} - - -def get_conn(name): - if name not in conns: - info = ensure_build_metadata()[name] - conns[name] = sqlite3.connect( - 'file:{}?immutable=1'.format(info['file']), - uri=True - ) - conns[name].row_factory = sqlite3.Row - conns[name].text_factory = lambda x: str(x, 'utf-8', 'replace') - return conns[name] +connections = threading.local() def ensure_build_metadata(regenerate=False): @@ -80,8 +71,9 @@ def ensure_build_metadata(regenerate=False): class BaseView(HTTPMethodView): template = None - def __init__(self, jinja): + def __init__(self, jinja, executor): self.jinja = jinja + self.executor = executor def redirect(self, request, path): if request.query_string: @@ -92,6 +84,40 @@ class BaseView(HTTPMethodView): r.headers['Link'] = '<{}>; rel=preload'.format(path) return r + async def pks_for_table(self, name, table): + rows = [ + row for row in await self.execute( + name, + 'PRAGMA table_info("{}")'.format(table) + ) + if row[-1] + ] + rows.sort(key=lambda row: row[-1]) + return [str(r[1]) for r in rows] + + async def execute(self, db_name, sql): + """Executes sql against db_name in a thread""" + def sql_operation_in_thread(): + conn = getattr(connections, db_name, None) + if not conn: + info = ensure_build_metadata()[db_name] + conn = sqlite3.connect( + 'file:{}?immutable=1'.format(info['file']), + uri=True, + check_same_thread=False, + ) + conn.row_factory = sqlite3.Row + conn.text_factory = lambda x: str(x, 'utf-8', 'replace') + setattr(connections, db_name, conn) + + with sqlite_timelimit(conn, SQL_TIME_LIMIT_MS): + rows = conn.execute(sql) + return rows + + return await asyncio.get_event_loop().run_in_executor( + self.executor, sql_operation_in_thread + ) + async def get(self, request, db_name, **kwargs): name, hash, should_redirect = resolve_db_name(db_name, **kwargs) if should_redirect: @@ -106,7 +132,7 @@ class BaseView(HTTPMethodView): extra_template_data = {} start = time.time() try: - data, extra_template_data = self.data( + data, extra_template_data = await self.data( request, name, hash, **kwargs ) except sqlite3.OperationalError as e: @@ -154,8 +180,9 @@ class BaseView(HTTPMethodView): class IndexView(HTTPMethodView): - def __init__(self, jinja): + def __init__(self, jinja, executor): self.jinja = jinja + self.executor = executor async def get(self, request): databases = [] @@ -188,11 +215,9 @@ async def favicon(request): class DatabaseView(BaseView): template = 'database.html' - def data(self, request, name, hash): - conn = get_conn(name) + async def data(self, request, name, hash): sql = request.args.get('sql') or 'select * from sqlite_master' - with sqlite_timelimit(conn, SQL_TIME_LIMIT_MS): - rows = conn.execute(sql) + rows = await self.execute(name, sql) columns = [r[0] for r in rows.description] return { 'database': name, @@ -216,8 +241,7 @@ class DatabaseDownload(BaseView): class TableView(BaseView): template = 'table.html' - def data(self, request, name, hash, table): - conn = get_conn(name) + async def data(self, request, name, hash, table): table = urllib.parse.unquote_plus(table) if request.args: where_clause, params = build_where_clause(request.args) @@ -228,12 +252,11 @@ class TableView(BaseView): sql = 'select * from "{}" limit 50'.format(table) params = [] - with sqlite_timelimit(conn, SQL_TIME_LIMIT_MS): - rows = conn.execute(sql, params) + rows = await self.execute(name, sql) columns = [r[0] for r in rows.description] rows = list(rows) - pks = pks_for_table(conn, table) + pks = await self.pks_for_table(name, table) info = ensure_build_metadata() total_rows = info[name]['tables'].get(table) return { @@ -252,11 +275,10 @@ class TableView(BaseView): class RowView(BaseView): template = 'row.html' - def data(self, request, name, hash, table, pk_path): - conn = get_conn(name) + async def data(self, request, name, hash, table, pk_path): table = urllib.parse.unquote_plus(table) pk_values = compound_pks_from_path(pk_path) - pks = pks_for_table(conn, table) + pks = await self.pks_for_table(name, table) wheres = [ '"{}"=?'.format(pk) for pk in pks @@ -264,9 +286,8 @@ class RowView(BaseView): sql = 'select * from "{}" where {}'.format( table, ' AND '.join(wheres) ) - rows = conn.execute(sql, pk_values) + rows = await self.execute(name, sql) columns = [r[0] for r in rows.description] - pks = pks_for_table(conn, table) rows = list(rows) if not rows: raise NotFound('Record not found: {}'.format(pk_values)) @@ -322,17 +343,6 @@ def compound_pks_from_path(path): ] -def pks_for_table(conn, table): - rows = [ - row for row in conn.execute( - 'PRAGMA table_info("{}")'.format(table) - ).fetchall() - if row[-1] - ] - rows.sort(key=lambda row: row[-1]) - return [str(r[1]) for r in rows] - - def path_from_row_pks(row, pks): if not pks: return '' @@ -410,7 +420,7 @@ def sqlite_timelimit(conn, ms): conn.set_progress_handler(None, 10000) -def app_factory(files): +def app_factory(files, num_threads=3): app = Sanic(__name__) jinja = SanicJinja2( app, @@ -418,22 +428,23 @@ def app_factory(files): str(app_root / 'datasite' / 'templates') ]) ) - app.add_route(IndexView.as_view(jinja), '/') + executor = futures.ThreadPoolExecutor(max_workers=num_threads) + app.add_route(IndexView.as_view(jinja, executor), '/') app.add_route(favicon, '/favicon.ico') app.add_route( - DatabaseView.as_view(jinja), + DatabaseView.as_view(jinja, executor), '/' ) app.add_route( - DatabaseDownload.as_view(jinja), + DatabaseDownload.as_view(jinja, executor), '/' ) app.add_route( - TableView.as_view(jinja), + TableView.as_view(jinja, executor), '//' ) app.add_route( - RowView.as_view(jinja), + RowView.as_view(jinja, executor), '///' ) return app