Moved all SQLite queries to threads

SQLite operations are blocking, but we're running everything in Sanic, an
asyncio web framework, so blocking operations are bad - a long-running DB
operation could hold up the entire server.

Instead, I've moved all SQLite operations into threads. These are managed by a
concurrent.futures ThreadPoolExecutor. This means I can run up to X queries in
parallel, and I can continue to queue up additional incoming HTTP traffic
while the threadpool is busy.

Each thread is responsible for managing its own SQLite connections - one per
database. These are cached in a threadlocal.

Since we are working with immutable, read-only SQLite databases it should be
safe to share SQLite objects across threads. On this assumption I'm using the
check_same_thread=False option. Opening a database connection looks like this:

    conn = sqlite3.connect(
        'file:filename.db?immutable=1',
        uri=True,
        check_same_thread=False,
    )

The following articles were helpful in figuring this out:

* https://pymotw.com/3/asyncio/executors.html
* https://marlinux.wordpress.com/2017/05/19/python-3-6-asyncio-sqlalchemy/

Closes #45. Refs #38.
This commit is contained in:
Simon Willison 2017-11-04 19:21:44 -07:00
commit 31b21f5c5e

View file

@ -8,6 +8,9 @@ import sqlite3
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from functools import wraps from functools import wraps
from concurrent import futures
import asyncio
import threading
import urllib.parse import urllib.parse
import json import json
import base64 import base64
@ -22,19 +25,7 @@ DB_GLOBS = ('*.db', '*.sqlite', '*.sqlite3')
HASH_BLOCK_SIZE = 1024 * 1024 HASH_BLOCK_SIZE = 1024 * 1024
SQL_TIME_LIMIT_MS = 1000 SQL_TIME_LIMIT_MS = 1000
conns = {} connections = threading.local()
def get_conn(name):
if name not in conns:
info = ensure_build_metadata()[name]
conns[name] = sqlite3.connect(
'file:{}?immutable=1'.format(info['file']),
uri=True
)
conns[name].row_factory = sqlite3.Row
conns[name].text_factory = lambda x: str(x, 'utf-8', 'replace')
return conns[name]
def ensure_build_metadata(regenerate=False): def ensure_build_metadata(regenerate=False):
@ -80,8 +71,9 @@ def ensure_build_metadata(regenerate=False):
class BaseView(HTTPMethodView): class BaseView(HTTPMethodView):
template = None template = None
def __init__(self, jinja): def __init__(self, jinja, executor):
self.jinja = jinja self.jinja = jinja
self.executor = executor
def redirect(self, request, path): def redirect(self, request, path):
if request.query_string: if request.query_string:
@ -92,6 +84,40 @@ class BaseView(HTTPMethodView):
r.headers['Link'] = '<{}>; rel=preload'.format(path) r.headers['Link'] = '<{}>; rel=preload'.format(path)
return r return r
async def pks_for_table(self, name, table):
rows = [
row for row in await self.execute(
name,
'PRAGMA table_info("{}")'.format(table)
)
if row[-1]
]
rows.sort(key=lambda row: row[-1])
return [str(r[1]) for r in rows]
async def execute(self, db_name, sql):
"""Executes sql against db_name in a thread"""
def sql_operation_in_thread():
conn = getattr(connections, db_name, None)
if not conn:
info = ensure_build_metadata()[db_name]
conn = sqlite3.connect(
'file:{}?immutable=1'.format(info['file']),
uri=True,
check_same_thread=False,
)
conn.row_factory = sqlite3.Row
conn.text_factory = lambda x: str(x, 'utf-8', 'replace')
setattr(connections, db_name, conn)
with sqlite_timelimit(conn, SQL_TIME_LIMIT_MS):
rows = conn.execute(sql)
return rows
return await asyncio.get_event_loop().run_in_executor(
self.executor, sql_operation_in_thread
)
async def get(self, request, db_name, **kwargs): async def get(self, request, db_name, **kwargs):
name, hash, should_redirect = resolve_db_name(db_name, **kwargs) name, hash, should_redirect = resolve_db_name(db_name, **kwargs)
if should_redirect: if should_redirect:
@ -106,7 +132,7 @@ class BaseView(HTTPMethodView):
extra_template_data = {} extra_template_data = {}
start = time.time() start = time.time()
try: try:
data, extra_template_data = self.data( data, extra_template_data = await self.data(
request, name, hash, **kwargs request, name, hash, **kwargs
) )
except sqlite3.OperationalError as e: except sqlite3.OperationalError as e:
@ -154,8 +180,9 @@ class BaseView(HTTPMethodView):
class IndexView(HTTPMethodView): class IndexView(HTTPMethodView):
def __init__(self, jinja): def __init__(self, jinja, executor):
self.jinja = jinja self.jinja = jinja
self.executor = executor
async def get(self, request): async def get(self, request):
databases = [] databases = []
@ -188,11 +215,9 @@ async def favicon(request):
class DatabaseView(BaseView): class DatabaseView(BaseView):
template = 'database.html' template = 'database.html'
def data(self, request, name, hash): async def data(self, request, name, hash):
conn = get_conn(name)
sql = request.args.get('sql') or 'select * from sqlite_master' sql = request.args.get('sql') or 'select * from sqlite_master'
with sqlite_timelimit(conn, SQL_TIME_LIMIT_MS): rows = await self.execute(name, sql)
rows = conn.execute(sql)
columns = [r[0] for r in rows.description] columns = [r[0] for r in rows.description]
return { return {
'database': name, 'database': name,
@ -216,8 +241,7 @@ class DatabaseDownload(BaseView):
class TableView(BaseView): class TableView(BaseView):
template = 'table.html' template = 'table.html'
def data(self, request, name, hash, table): async def data(self, request, name, hash, table):
conn = get_conn(name)
table = urllib.parse.unquote_plus(table) table = urllib.parse.unquote_plus(table)
if request.args: if request.args:
where_clause, params = build_where_clause(request.args) where_clause, params = build_where_clause(request.args)
@ -228,12 +252,11 @@ class TableView(BaseView):
sql = 'select * from "{}" limit 50'.format(table) sql = 'select * from "{}" limit 50'.format(table)
params = [] params = []
with sqlite_timelimit(conn, SQL_TIME_LIMIT_MS): rows = await self.execute(name, sql)
rows = conn.execute(sql, params)
columns = [r[0] for r in rows.description] columns = [r[0] for r in rows.description]
rows = list(rows) rows = list(rows)
pks = pks_for_table(conn, table) pks = await self.pks_for_table(name, table)
info = ensure_build_metadata() info = ensure_build_metadata()
total_rows = info[name]['tables'].get(table) total_rows = info[name]['tables'].get(table)
return { return {
@ -252,11 +275,10 @@ class TableView(BaseView):
class RowView(BaseView): class RowView(BaseView):
template = 'row.html' template = 'row.html'
def data(self, request, name, hash, table, pk_path): async def data(self, request, name, hash, table, pk_path):
conn = get_conn(name)
table = urllib.parse.unquote_plus(table) table = urllib.parse.unquote_plus(table)
pk_values = compound_pks_from_path(pk_path) pk_values = compound_pks_from_path(pk_path)
pks = pks_for_table(conn, table) pks = await self.pks_for_table(name, table)
wheres = [ wheres = [
'"{}"=?'.format(pk) '"{}"=?'.format(pk)
for pk in pks for pk in pks
@ -264,9 +286,8 @@ class RowView(BaseView):
sql = 'select * from "{}" where {}'.format( sql = 'select * from "{}" where {}'.format(
table, ' AND '.join(wheres) table, ' AND '.join(wheres)
) )
rows = conn.execute(sql, pk_values) rows = await self.execute(name, sql)
columns = [r[0] for r in rows.description] columns = [r[0] for r in rows.description]
pks = pks_for_table(conn, table)
rows = list(rows) rows = list(rows)
if not rows: if not rows:
raise NotFound('Record not found: {}'.format(pk_values)) raise NotFound('Record not found: {}'.format(pk_values))
@ -322,17 +343,6 @@ def compound_pks_from_path(path):
] ]
def pks_for_table(conn, table):
rows = [
row for row in conn.execute(
'PRAGMA table_info("{}")'.format(table)
).fetchall()
if row[-1]
]
rows.sort(key=lambda row: row[-1])
return [str(r[1]) for r in rows]
def path_from_row_pks(row, pks): def path_from_row_pks(row, pks):
if not pks: if not pks:
return '' return ''
@ -410,7 +420,7 @@ def sqlite_timelimit(conn, ms):
conn.set_progress_handler(None, 10000) conn.set_progress_handler(None, 10000)
def app_factory(files): def app_factory(files, num_threads=3):
app = Sanic(__name__) app = Sanic(__name__)
jinja = SanicJinja2( jinja = SanicJinja2(
app, app,
@ -418,22 +428,23 @@ def app_factory(files):
str(app_root / 'datasite' / 'templates') str(app_root / 'datasite' / 'templates')
]) ])
) )
app.add_route(IndexView.as_view(jinja), '/') executor = futures.ThreadPoolExecutor(max_workers=num_threads)
app.add_route(IndexView.as_view(jinja, executor), '/')
app.add_route(favicon, '/favicon.ico') app.add_route(favicon, '/favicon.ico')
app.add_route( app.add_route(
DatabaseView.as_view(jinja), DatabaseView.as_view(jinja, executor),
'/<db_name:[^/\.]+?><as_json:(.jsono?)?$>' '/<db_name:[^/\.]+?><as_json:(.jsono?)?$>'
) )
app.add_route( app.add_route(
DatabaseDownload.as_view(jinja), DatabaseDownload.as_view(jinja, executor),
'/<db_name:[^/]+?><as_db:(\.db)$>' '/<db_name:[^/]+?><as_db:(\.db)$>'
) )
app.add_route( app.add_route(
TableView.as_view(jinja), TableView.as_view(jinja, executor),
'/<db_name:[^/]+>/<table:[^/]+?><as_json:(.jsono?)?$>' '/<db_name:[^/]+>/<table:[^/]+?><as_json:(.jsono?)?$>'
) )
app.add_route( app.add_route(
RowView.as_view(jinja), RowView.as_view(jinja, executor),
'/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_json:(.jsono?)?$>' '/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_json:(.jsono?)?$>'
) )
return app return app