diff --git a/datasette/app.py b/datasette/app.py index 86a466d7..7e11f3c1 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -5,18 +5,23 @@ from sanic.views import HTTPMethodView from sanic_jinja2 import SanicJinja2 from jinja2 import FileSystemLoader import sqlite3 -from contextlib import contextmanager from pathlib import Path -from functools import wraps from concurrent import futures import asyncio import threading import urllib.parse import json -import base64 import hashlib -import sys import time +from .utils import ( + build_where_clause, + CustomJSONEncoder, + InvalidSql, + path_from_row_pks, + compound_pks_from_path, + sqlite_timelimit, + validate_sql_select, +) app_root = Path(__file__).parent.parent @@ -373,93 +378,6 @@ def resolve_db_name(files, db_name, **kwargs): return name, expected, None -def compound_pks_from_path(path): - return [ - urllib.parse.unquote_plus(b) for b in path.split(',') - ] - - -def path_from_row_pks(row, pks, use_rowid): - if use_rowid: - return urllib.parse.quote_plus(str(row['rowid'])) - bits = [] - for pk in pks: - bits.append( - urllib.parse.quote_plus(str(row[pk])) - ) - return ','.join(bits) - - -def build_where_clause(args): - sql_bits = [] - params = {} - for i, (key, values) in enumerate(sorted(args.items())): - if '__' in key: - column, lookup = key.rsplit('__', 1) - else: - column = key - lookup = 'exact' - template = { - 'exact': '"{}" = :{}', - 'contains': '"{}" like :{}', - 'endswith': '"{}" like :{}', - 'startswith': '"{}" like :{}', - 'gt': '"{}" > :{}', - 'gte': '"{}" >= :{}', - 'lt': '"{}" < :{}', - 'lte': '"{}" <= :{}', - 'glob': '"{}" glob :{}', - 'like': '"{}" like :{}', - }[lookup] - numeric_operators = {'gt', 'gte', 'lt', 'lte'} - value = values[0] - value_convert = { - 'contains': lambda s: '%{}%'.format(s), - 'endswith': lambda s: '%{}'.format(s), - 'startswith': lambda s: '{}%'.format(s), - }.get(lookup, lambda s: s) - converted = value_convert(value) - if lookup in numeric_operators and converted.isdigit(): - converted = int(converted) - param_id = 'p{}'.format(i) - sql_bits.append( - template.format(column, param_id) - ) - params[param_id] = converted - where_clause = ' and '.join(sql_bits) - return where_clause, params - - -class CustomJSONEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, sqlite3.Row): - return tuple(obj) - if isinstance(obj, sqlite3.Cursor): - return list(obj) - if isinstance(obj, bytes): - # Does it encode to utf8? - try: - return obj.decode('utf8') - except UnicodeDecodeError: - return { - '$base64': True, - 'encoded': base64.b64encode(obj).decode('latin1'), - } - return json.JSONEncoder.default(self, obj) - - -@contextmanager -def sqlite_timelimit(conn, ms): - deadline = time.time() + (ms / 1000) - - def handler(): - if time.time() >= deadline: - return 1 - conn.set_progress_handler(handler, 10000) - yield - conn.set_progress_handler(None, 10000) - - class Datasette: def __init__(self, files, num_threads=3): self.files = files @@ -497,15 +415,3 @@ class Datasette: '///' ) return app - - -class InvalidSql(Exception): - pass - - -def validate_sql_select(sql): - sql = sql.strip().lower() - if not sql.startswith('select '): - raise InvalidSql('Statement must begin with SELECT') - if 'pragma' in sql: - raise InvalidSql('Statement may not contain PRAGMA') diff --git a/datasette/utils.py b/datasette/utils.py new file mode 100644 index 00000000..000f86d2 --- /dev/null +++ b/datasette/utils.py @@ -0,0 +1,105 @@ +from contextlib import contextmanager +import base64 +import json +import sqlite3 +import time +import urllib + + +def compound_pks_from_path(path): + return [ + urllib.parse.unquote_plus(b) for b in path.split(',') + ] + + +def path_from_row_pks(row, pks, use_rowid): + if use_rowid: + return urllib.parse.quote_plus(str(row['rowid'])) + bits = [] + for pk in pks: + bits.append( + urllib.parse.quote_plus(str(row[pk])) + ) + return ','.join(bits) + + +def build_where_clause(args): + sql_bits = [] + params = {} + for i, (key, values) in enumerate(sorted(args.items())): + if '__' in key: + column, lookup = key.rsplit('__', 1) + else: + column = key + lookup = 'exact' + template = { + 'exact': '"{}" = :{}', + 'contains': '"{}" like :{}', + 'endswith': '"{}" like :{}', + 'startswith': '"{}" like :{}', + 'gt': '"{}" > :{}', + 'gte': '"{}" >= :{}', + 'lt': '"{}" < :{}', + 'lte': '"{}" <= :{}', + 'glob': '"{}" glob :{}', + 'like': '"{}" like :{}', + }[lookup] + numeric_operators = {'gt', 'gte', 'lt', 'lte'} + value = values[0] + value_convert = { + 'contains': lambda s: '%{}%'.format(s), + 'endswith': lambda s: '%{}'.format(s), + 'startswith': lambda s: '{}%'.format(s), + }.get(lookup, lambda s: s) + converted = value_convert(value) + if lookup in numeric_operators and converted.isdigit(): + converted = int(converted) + param_id = 'p{}'.format(i) + sql_bits.append( + template.format(column, param_id) + ) + params[param_id] = converted + where_clause = ' and '.join(sql_bits) + return where_clause, params + + +class CustomJSONEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, sqlite3.Row): + return tuple(obj) + if isinstance(obj, sqlite3.Cursor): + return list(obj) + if isinstance(obj, bytes): + # Does it encode to utf8? + try: + return obj.decode('utf8') + except UnicodeDecodeError: + return { + '$base64': True, + 'encoded': base64.b64encode(obj).decode('latin1'), + } + return json.JSONEncoder.default(self, obj) + + +@contextmanager +def sqlite_timelimit(conn, ms): + deadline = time.time() + (ms / 1000) + + def handler(): + if time.time() >= deadline: + return 1 + conn.set_progress_handler(handler, 10000) + yield + conn.set_progress_handler(None, 10000) + + +class InvalidSql(Exception): + pass + + +def validate_sql_select(sql): + sql = sql.strip().lower() + if not sql.startswith('select '): + raise InvalidSql('Statement must begin with SELECT') + if 'pragma' in sql: + raise InvalidSql('Statement may not contain PRAGMA') diff --git a/tests/test_helpers.py b/tests/test_utils.py similarity index 87% rename from tests/test_helpers.py rename to tests/test_utils.py index ee798921..5a3f26a5 100644 --- a/tests/test_helpers.py +++ b/tests/test_utils.py @@ -2,7 +2,7 @@ Tests for various datasette helper functions. """ -from datasette import app +from datasette import utils import pytest import json @@ -15,7 +15,7 @@ import json ('123%2F433%2F112', ['123/433/112']), ]) def test_compound_pks_from_path(path, expected): - assert expected == app.compound_pks_from_path(path) + assert expected == utils.compound_pks_from_path(path) @pytest.mark.parametrize('row,pks,expected_path', [ @@ -24,7 +24,7 @@ def test_compound_pks_from_path(path, expected): ({'A': 123}, ['A'], '123'), ]) def test_path_from_row_pks(row, pks, expected_path): - actual_path = app.path_from_row_pks(row, pks, False) + actual_path = utils.path_from_row_pks(row, pks, False) assert expected_path == actual_path @@ -40,7 +40,7 @@ def test_path_from_row_pks(row, pks, expected_path): def test_custom_json_encoder(obj, expected): actual = json.dumps( obj, - cls=app.CustomJSONEncoder, + cls=utils.CustomJSONEncoder, sort_keys=True ) assert expected == actual @@ -90,7 +90,7 @@ def test_custom_json_encoder(obj, expected): ), ]) def test_build_where(args, expected_where, expected_params): - actual_where, actual_params = app.build_where_clause(args) + actual_where, actual_params = utils.build_where_clause(args) assert expected_where == actual_where assert { 'p{}'.format(i): param @@ -104,8 +104,8 @@ def test_build_where(args, expected_where, expected_params): "SELECT * FROM pragma_index_info('idx52')", ]) def test_validate_sql_select_bad(bad_sql): - with pytest.raises(app.InvalidSql): - app.validate_sql_select(bad_sql) + with pytest.raises(utils.InvalidSql): + utils.validate_sql_select(bad_sql) @pytest.mark.parametrize('good_sql', [ @@ -114,4 +114,4 @@ def test_validate_sql_select_bad(bad_sql): 'select 1 + 1', ]) def test_validate_sql_select_good(good_sql): - app.validate_sql_select(good_sql) + utils.validate_sql_select(good_sql)