diff --git a/datasette/app.py b/datasette/app.py index cb0f462a..853f24a0 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -108,6 +108,12 @@ class BaseView(HTTPMethodView): return name, expected, should_redirect return name, expected, None + def prepare_connection(self, conn): + conn.row_factory = sqlite3.Row + conn.text_factory = lambda x: str(x, 'utf-8', 'replace') + for name, num_args, func in self.ds.sqlite_functions: + conn.create_function(name, num_args, func) + async def execute(self, db_name, sql, params=None, truncate=False): """Executes sql against db_name in a thread""" def sql_operation_in_thread(): @@ -119,8 +125,7 @@ class BaseView(HTTPMethodView): uri=True, check_same_thread=False, ) - conn.row_factory = sqlite3.Row - conn.text_factory = lambda x: str(x, 'utf-8', 'replace') + self.prepare_connection(conn) setattr(connections, db_name, conn) with sqlite_timelimit(conn, self.ds.sql_time_limit_ms): @@ -525,6 +530,7 @@ class Datasette: self.cors = cors self._inspect = inspect_data self.metadata = metadata or {} + self.sqlite_functions = [] def inspect(self): if not self._inspect: diff --git a/datasette/utils.py b/datasette/utils.py index 87f95811..c41d9f42 100644 --- a/datasette/utils.py +++ b/datasette/utils.py @@ -85,13 +85,21 @@ class CustomJSONEncoder(json.JSONEncoder): @contextmanager def sqlite_timelimit(conn, ms): deadline = time.time() + (ms / 1000) + # n is the number of SQLite virtual machine instructions that will be + # executed between each check. It's hard to know what to pick here. + # After some experimentation, I've decided to go with 1000 by default and + # 1 for time limits that are less than 50ms + n = 1000 + if ms < 50: + n = 1 def handler(): if time.time() >= deadline: return 1 - conn.set_progress_handler(handler, 10000) + + conn.set_progress_handler(handler, n) yield - conn.set_progress_handler(None, 10000) + conn.set_progress_handler(None, n) class InvalidSql(Exception): diff --git a/tests/test_app.py b/tests/test_app.py index a4405c8e..ff910c3c 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -3,6 +3,7 @@ import os import pytest import sqlite3 import tempfile +import time @pytest.fixture(scope='module') @@ -12,7 +13,16 @@ def app_client(): conn = sqlite3.connect(filepath) conn.executescript(TABLES) os.chdir(os.path.dirname(filepath)) - yield Datasette([filepath], page_size=50, max_returned_rows=100).app().test_client + ds = Datasette( + [filepath], + page_size=50, + max_returned_rows=100, + sql_time_limit_ms=20, + ) + ds.sqlite_functions.append( + ('sleep', 1, lambda n: time.sleep(float(n))), + ) + yield ds.app().test_client def test_homepage(app_client): @@ -83,6 +93,14 @@ def test_custom_sql(app_client): assert not data['truncated'] +def test_sql_time_limit(app_client): + _, response = app_client.get( + '/test_tables.jsono?sql=select+sleep(0.5)' + ) + assert 400 == response.status + assert 'interrupted' == response.json['error'] + + def test_invalid_custom_sql(app_client): _, response = app_client.get( '/test_tables?sql=.schema'