Refactored util functions into new utils module

This commit is contained in:
Simon Willison 2017-11-10 11:25:54 -08:00
commit a8a293cd71
4 changed files with 231 additions and 220 deletions

View file

@ -5,18 +5,23 @@ from sanic.views import HTTPMethodView
from sanic_jinja2 import SanicJinja2 from sanic_jinja2 import SanicJinja2
from jinja2 import FileSystemLoader from jinja2 import FileSystemLoader
import sqlite3 import sqlite3
from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from functools import wraps
from concurrent import futures from concurrent import futures
import asyncio import asyncio
import threading import threading
import urllib.parse import urllib.parse
import json import json
import base64
import hashlib import hashlib
import sys
import time import time
from .utils import (
build_where_clause,
CustomJSONEncoder,
InvalidSql,
path_from_row_pks,
compound_pks_from_path,
sqlite_timelimit,
validate_sql_select,
)
app_root = Path(__file__).parent.parent app_root = Path(__file__).parent.parent
@ -373,93 +378,6 @@ def resolve_db_name(files, db_name, **kwargs):
return name, expected, None return name, expected, None
def compound_pks_from_path(path):
return [
urllib.parse.unquote_plus(b) for b in path.split(',')
]
def path_from_row_pks(row, pks, use_rowid):
if use_rowid:
return urllib.parse.quote_plus(str(row['rowid']))
bits = []
for pk in pks:
bits.append(
urllib.parse.quote_plus(str(row[pk]))
)
return ','.join(bits)
def build_where_clause(args):
sql_bits = []
params = {}
for i, (key, values) in enumerate(sorted(args.items())):
if '__' in key:
column, lookup = key.rsplit('__', 1)
else:
column = key
lookup = 'exact'
template = {
'exact': '"{}" = :{}',
'contains': '"{}" like :{}',
'endswith': '"{}" like :{}',
'startswith': '"{}" like :{}',
'gt': '"{}" > :{}',
'gte': '"{}" >= :{}',
'lt': '"{}" < :{}',
'lte': '"{}" <= :{}',
'glob': '"{}" glob :{}',
'like': '"{}" like :{}',
}[lookup]
numeric_operators = {'gt', 'gte', 'lt', 'lte'}
value = values[0]
value_convert = {
'contains': lambda s: '%{}%'.format(s),
'endswith': lambda s: '%{}'.format(s),
'startswith': lambda s: '{}%'.format(s),
}.get(lookup, lambda s: s)
converted = value_convert(value)
if lookup in numeric_operators and converted.isdigit():
converted = int(converted)
param_id = 'p{}'.format(i)
sql_bits.append(
template.format(column, param_id)
)
params[param_id] = converted
where_clause = ' and '.join(sql_bits)
return where_clause, params
class CustomJSONEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, sqlite3.Row):
return tuple(obj)
if isinstance(obj, sqlite3.Cursor):
return list(obj)
if isinstance(obj, bytes):
# Does it encode to utf8?
try:
return obj.decode('utf8')
except UnicodeDecodeError:
return {
'$base64': True,
'encoded': base64.b64encode(obj).decode('latin1'),
}
return json.JSONEncoder.default(self, obj)
@contextmanager
def sqlite_timelimit(conn, ms):
deadline = time.time() + (ms / 1000)
def handler():
if time.time() >= deadline:
return 1
conn.set_progress_handler(handler, 10000)
yield
conn.set_progress_handler(None, 10000)
class Datasette: class Datasette:
def __init__(self, files, num_threads=3): def __init__(self, files, num_threads=3):
self.files = files self.files = files
@ -497,15 +415,3 @@ class Datasette:
'/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_json:(.jsono?)?$>' '/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_json:(.jsono?)?$>'
) )
return app return app
class InvalidSql(Exception):
pass
def validate_sql_select(sql):
sql = sql.strip().lower()
if not sql.startswith('select '):
raise InvalidSql('Statement must begin with SELECT')
if 'pragma' in sql:
raise InvalidSql('Statement may not contain PRAGMA')

105
datasette/utils.py Normal file
View file

@ -0,0 +1,105 @@
from contextlib import contextmanager
import base64
import json
import sqlite3
import time
import urllib
def compound_pks_from_path(path):
return [
urllib.parse.unquote_plus(b) for b in path.split(',')
]
def path_from_row_pks(row, pks, use_rowid):
if use_rowid:
return urllib.parse.quote_plus(str(row['rowid']))
bits = []
for pk in pks:
bits.append(
urllib.parse.quote_plus(str(row[pk]))
)
return ','.join(bits)
def build_where_clause(args):
sql_bits = []
params = {}
for i, (key, values) in enumerate(sorted(args.items())):
if '__' in key:
column, lookup = key.rsplit('__', 1)
else:
column = key
lookup = 'exact'
template = {
'exact': '"{}" = :{}',
'contains': '"{}" like :{}',
'endswith': '"{}" like :{}',
'startswith': '"{}" like :{}',
'gt': '"{}" > :{}',
'gte': '"{}" >= :{}',
'lt': '"{}" < :{}',
'lte': '"{}" <= :{}',
'glob': '"{}" glob :{}',
'like': '"{}" like :{}',
}[lookup]
numeric_operators = {'gt', 'gte', 'lt', 'lte'}
value = values[0]
value_convert = {
'contains': lambda s: '%{}%'.format(s),
'endswith': lambda s: '%{}'.format(s),
'startswith': lambda s: '{}%'.format(s),
}.get(lookup, lambda s: s)
converted = value_convert(value)
if lookup in numeric_operators and converted.isdigit():
converted = int(converted)
param_id = 'p{}'.format(i)
sql_bits.append(
template.format(column, param_id)
)
params[param_id] = converted
where_clause = ' and '.join(sql_bits)
return where_clause, params
class CustomJSONEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, sqlite3.Row):
return tuple(obj)
if isinstance(obj, sqlite3.Cursor):
return list(obj)
if isinstance(obj, bytes):
# Does it encode to utf8?
try:
return obj.decode('utf8')
except UnicodeDecodeError:
return {
'$base64': True,
'encoded': base64.b64encode(obj).decode('latin1'),
}
return json.JSONEncoder.default(self, obj)
@contextmanager
def sqlite_timelimit(conn, ms):
deadline = time.time() + (ms / 1000)
def handler():
if time.time() >= deadline:
return 1
conn.set_progress_handler(handler, 10000)
yield
conn.set_progress_handler(None, 10000)
class InvalidSql(Exception):
pass
def validate_sql_select(sql):
sql = sql.strip().lower()
if not sql.startswith('select '):
raise InvalidSql('Statement must begin with SELECT')
if 'pragma' in sql:
raise InvalidSql('Statement may not contain PRAGMA')

View file

@ -2,7 +2,7 @@
Tests for various datasette helper functions. Tests for various datasette helper functions.
""" """
from datasette import app from datasette import utils
import pytest import pytest
import json import json
@ -15,7 +15,7 @@ import json
('123%2F433%2F112', ['123/433/112']), ('123%2F433%2F112', ['123/433/112']),
]) ])
def test_compound_pks_from_path(path, expected): def test_compound_pks_from_path(path, expected):
assert expected == app.compound_pks_from_path(path) assert expected == utils.compound_pks_from_path(path)
@pytest.mark.parametrize('row,pks,expected_path', [ @pytest.mark.parametrize('row,pks,expected_path', [
@ -24,7 +24,7 @@ def test_compound_pks_from_path(path, expected):
({'A': 123}, ['A'], '123'), ({'A': 123}, ['A'], '123'),
]) ])
def test_path_from_row_pks(row, pks, expected_path): def test_path_from_row_pks(row, pks, expected_path):
actual_path = app.path_from_row_pks(row, pks, False) actual_path = utils.path_from_row_pks(row, pks, False)
assert expected_path == actual_path assert expected_path == actual_path
@ -40,7 +40,7 @@ def test_path_from_row_pks(row, pks, expected_path):
def test_custom_json_encoder(obj, expected): def test_custom_json_encoder(obj, expected):
actual = json.dumps( actual = json.dumps(
obj, obj,
cls=app.CustomJSONEncoder, cls=utils.CustomJSONEncoder,
sort_keys=True sort_keys=True
) )
assert expected == actual assert expected == actual
@ -90,7 +90,7 @@ def test_custom_json_encoder(obj, expected):
), ),
]) ])
def test_build_where(args, expected_where, expected_params): def test_build_where(args, expected_where, expected_params):
actual_where, actual_params = app.build_where_clause(args) actual_where, actual_params = utils.build_where_clause(args)
assert expected_where == actual_where assert expected_where == actual_where
assert { assert {
'p{}'.format(i): param 'p{}'.format(i): param
@ -104,8 +104,8 @@ def test_build_where(args, expected_where, expected_params):
"SELECT * FROM pragma_index_info('idx52')", "SELECT * FROM pragma_index_info('idx52')",
]) ])
def test_validate_sql_select_bad(bad_sql): def test_validate_sql_select_bad(bad_sql):
with pytest.raises(app.InvalidSql): with pytest.raises(utils.InvalidSql):
app.validate_sql_select(bad_sql) utils.validate_sql_select(bad_sql)
@pytest.mark.parametrize('good_sql', [ @pytest.mark.parametrize('good_sql', [
@ -114,4 +114,4 @@ def test_validate_sql_select_bad(bad_sql):
'select 1 + 1', 'select 1 + 1',
]) ])
def test_validate_sql_select_good(good_sql): def test_validate_sql_select_good(good_sql):
app.validate_sql_select(good_sql) utils.validate_sql_select(good_sql)