mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
Refactored util functions into new utils module
This commit is contained in:
parent
1c57bd202f
commit
a8a293cd71
4 changed files with 231 additions and 220 deletions
112
datasette/app.py
112
datasette/app.py
|
|
@@ -5,18 +5,23 @@ from sanic.views import HTTPMethodView
|
||||||
from sanic_jinja2 import SanicJinja2
|
from sanic_jinja2 import SanicJinja2
|
||||||
from jinja2 import FileSystemLoader
|
from jinja2 import FileSystemLoader
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from contextlib import contextmanager
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from functools import wraps
|
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
import json
|
import json
|
||||||
import base64
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
|
from .utils import (
|
||||||
|
build_where_clause,
|
||||||
|
CustomJSONEncoder,
|
||||||
|
InvalidSql,
|
||||||
|
path_from_row_pks,
|
||||||
|
compound_pks_from_path,
|
||||||
|
sqlite_timelimit,
|
||||||
|
validate_sql_select,
|
||||||
|
)
|
||||||
|
|
||||||
app_root = Path(__file__).parent.parent
|
app_root = Path(__file__).parent.parent
|
||||||
|
|
||||||
|
|
@@ -373,93 +378,6 @@ def resolve_db_name(files, db_name, **kwargs):
|
||||||
return name, expected, None
|
return name, expected, None
|
||||||
|
|
||||||
|
|
||||||
def compound_pks_from_path(path):
    """Decode a comma-separated, URL-encoded primary key path into its parts."""
    return list(map(urllib.parse.unquote_plus, path.split(',')))
|
|
||||||
|
|
||||||
|
|
||||||
def path_from_row_pks(row, pks, use_rowid):
    """Encode a row's primary key values as a comma-joined, URL-quoted path.

    When use_rowid is true the row's 'rowid' value is used on its own;
    otherwise each pk column value is quoted and joined with commas.
    """
    if use_rowid:
        return urllib.parse.quote_plus(str(row['rowid']))
    return ','.join(
        urllib.parse.quote_plus(str(row[pk])) for pk in pks
    )
|
|
||||||
|
|
||||||
|
|
||||||
def build_where_clause(args):
    """Convert querystring-style filter args into a SQL where clause + params.

    args maps "column" or "column__lookup" keys to lists of values (as
    produced by querystring parsing); only the first value of each list is
    used.  Returns (where_clause, params) where params maps p0, p1, ...
    placeholder names to their (possibly converted) values.  Raises
    KeyError for an unsupported lookup suffix.
    """
    # These tables are constants: build them once instead of rebuilding
    # them on every loop iteration as the original did.
    templates = {
        'exact': '"{}" = :{}',
        'contains': '"{}" like :{}',
        'endswith': '"{}" like :{}',
        'startswith': '"{}" like :{}',
        'gt': '"{}" > :{}',
        'gte': '"{}" >= :{}',
        'lt': '"{}" < :{}',
        'lte': '"{}" <= :{}',
        'glob': '"{}" glob :{}',
        'like': '"{}" like :{}',
    }
    value_converters = {
        'contains': lambda s: '%{}%'.format(s),
        'endswith': lambda s: '%{}'.format(s),
        'startswith': lambda s: '{}%'.format(s),
    }
    numeric_operators = {'gt', 'gte', 'lt', 'lte'}
    sql_bits = []
    params = {}
    # Sort keys so clause text and param numbering are deterministic
    # regardless of input dict order.
    for i, (key, values) in enumerate(sorted(args.items())):
        if '__' in key:
            column, lookup = key.rsplit('__', 1)
        else:
            column = key
            lookup = 'exact'
        template = templates[lookup]
        value = values[0]
        converted = value_converters.get(lookup, lambda s: s)(value)
        # Digit-only strings compared with a numeric operator compare as ints
        if lookup in numeric_operators and converted.isdigit():
            converted = int(converted)
        param_id = 'p{}'.format(i)
        sql_bits.append(template.format(column, param_id))
        params[param_id] = converted
    return ' and '.join(sql_bits), params
|
|
||||||
|
|
||||||
|
|
||||||
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that understands sqlite3 rows/cursors and raw bytes."""

    def default(self, obj):
        # sqlite3.Row behaves like a tuple - serialize it as one
        if isinstance(obj, sqlite3.Row):
            return tuple(obj)
        # A cursor serializes as the list of its remaining rows
        if isinstance(obj, sqlite3.Cursor):
            return list(obj)
        if isinstance(obj, bytes):
            # utf8-decodable bytes become a plain string; anything else is
            # wrapped in a $base64 envelope
            try:
                return obj.decode('utf8')
            except UnicodeDecodeError:
                encoded = base64.b64encode(obj).decode('latin1')
                return {'$base64': True, 'encoded': encoded}
        return super().default(obj)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
def sqlite_timelimit(conn, ms):
    """Limit queries on conn to ms milliseconds while the block runs.

    Installs a sqlite progress handler that aborts the running statement
    (sqlite raises OperationalError) once the deadline passes.  The
    handler is always removed on exit.
    """
    deadline = time.time() + (ms / 1000)

    def handler():
        # A non-zero return value tells sqlite to abort the current statement
        if time.time() >= deadline:
            return 1

    # Invoke the handler roughly every 10000 VM instructions
    conn.set_progress_handler(handler, 10000)
    try:
        yield
    finally:
        # Bug fix: previously an exception in the block (including the
        # timeout's own OperationalError) left the handler installed,
        # so later queries on this connection kept hitting the old deadline.
        conn.set_progress_handler(None, 10000)
|
|
||||||
|
|
||||||
|
|
||||||
class Datasette:
|
class Datasette:
|
||||||
def __init__(self, files, num_threads=3):
|
def __init__(self, files, num_threads=3):
|
||||||
self.files = files
|
self.files = files
|
||||||
|
|
@@ -497,15 +415,3 @@ class Datasette:
|
||||||
'/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_json:(.jsono?)?$>'
|
'/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_json:(.jsono?)?$>'
|
||||||
)
|
)
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
class InvalidSql(Exception):
    """Raised when a SQL statement fails validation as a plain SELECT."""
|
|
||||||
|
|
||||||
|
|
||||||
def validate_sql_select(sql):
    """Reject SQL that is not a plain SELECT statement.

    Raises InvalidSql unless the statement (case-insensitively) starts
    with "select " and contains no occurrence of "pragma" anywhere.
    """
    normalized = sql.strip().lower()
    if not normalized.startswith('select '):
        raise InvalidSql('Statement must begin with SELECT')
    if 'pragma' in normalized:
        raise InvalidSql('Statement may not contain PRAGMA')
|
|
||||||
|
|
|
||||||
105
datasette/utils.py
Normal file
105
datasette/utils.py
Normal file
|
|
@@ -0,0 +1,105 @@
|
||||||
|
from contextlib import contextmanager
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
import urllib
|
||||||
|
|
||||||
|
|
||||||
|
def compound_pks_from_path(path):
    """Decode a comma-separated, URL-encoded primary key path into its parts."""
    return list(map(urllib.parse.unquote_plus, path.split(',')))
|
||||||
|
|
||||||
|
|
||||||
|
def path_from_row_pks(row, pks, use_rowid):
    """Encode a row's primary key values as a comma-joined, URL-quoted path.

    When use_rowid is true the row's 'rowid' value is used on its own;
    otherwise each pk column value is quoted and joined with commas.
    """
    if use_rowid:
        return urllib.parse.quote_plus(str(row['rowid']))
    return ','.join(
        urllib.parse.quote_plus(str(row[pk])) for pk in pks
    )
|
||||||
|
|
||||||
|
|
||||||
|
def build_where_clause(args):
    """Convert querystring-style filter args into a SQL where clause + params.

    args maps "column" or "column__lookup" keys to lists of values (as
    produced by querystring parsing); only the first value of each list is
    used.  Returns (where_clause, params) where params maps p0, p1, ...
    placeholder names to their (possibly converted) values.  Raises
    KeyError for an unsupported lookup suffix.
    """
    # These tables are constants: build them once instead of rebuilding
    # them on every loop iteration as the original did.
    templates = {
        'exact': '"{}" = :{}',
        'contains': '"{}" like :{}',
        'endswith': '"{}" like :{}',
        'startswith': '"{}" like :{}',
        'gt': '"{}" > :{}',
        'gte': '"{}" >= :{}',
        'lt': '"{}" < :{}',
        'lte': '"{}" <= :{}',
        'glob': '"{}" glob :{}',
        'like': '"{}" like :{}',
    }
    value_converters = {
        'contains': lambda s: '%{}%'.format(s),
        'endswith': lambda s: '%{}'.format(s),
        'startswith': lambda s: '{}%'.format(s),
    }
    numeric_operators = {'gt', 'gte', 'lt', 'lte'}
    sql_bits = []
    params = {}
    # Sort keys so clause text and param numbering are deterministic
    # regardless of input dict order.
    for i, (key, values) in enumerate(sorted(args.items())):
        if '__' in key:
            column, lookup = key.rsplit('__', 1)
        else:
            column = key
            lookup = 'exact'
        template = templates[lookup]
        value = values[0]
        converted = value_converters.get(lookup, lambda s: s)(value)
        # Digit-only strings compared with a numeric operator compare as ints
        if lookup in numeric_operators and converted.isdigit():
            converted = int(converted)
        param_id = 'p{}'.format(i)
        sql_bits.append(template.format(column, param_id))
        params[param_id] = converted
    return ' and '.join(sql_bits), params
|
||||||
|
|
||||||
|
|
||||||
|
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that understands sqlite3 rows/cursors and raw bytes."""

    def default(self, obj):
        # sqlite3.Row behaves like a tuple - serialize it as one
        if isinstance(obj, sqlite3.Row):
            return tuple(obj)
        # A cursor serializes as the list of its remaining rows
        if isinstance(obj, sqlite3.Cursor):
            return list(obj)
        if isinstance(obj, bytes):
            # utf8-decodable bytes become a plain string; anything else is
            # wrapped in a $base64 envelope
            try:
                return obj.decode('utf8')
            except UnicodeDecodeError:
                encoded = base64.b64encode(obj).decode('latin1')
                return {'$base64': True, 'encoded': encoded}
        return super().default(obj)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def sqlite_timelimit(conn, ms):
    """Limit queries on conn to ms milliseconds while the block runs.

    Installs a sqlite progress handler that aborts the running statement
    (sqlite raises OperationalError) once the deadline passes.  The
    handler is always removed on exit.
    """
    deadline = time.time() + (ms / 1000)

    def handler():
        # A non-zero return value tells sqlite to abort the current statement
        if time.time() >= deadline:
            return 1

    # Invoke the handler roughly every 10000 VM instructions
    conn.set_progress_handler(handler, 10000)
    try:
        yield
    finally:
        # Bug fix: previously an exception in the block (including the
        # timeout's own OperationalError) left the handler installed,
        # so later queries on this connection kept hitting the old deadline.
        conn.set_progress_handler(None, 10000)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSql(Exception):
    """Raised when a SQL statement fails validation as a plain SELECT."""
|
||||||
|
|
||||||
|
|
||||||
|
def validate_sql_select(sql):
    """Reject SQL that is not a plain SELECT statement.

    Raises InvalidSql unless the statement (case-insensitively) starts
    with "select " and contains no occurrence of "pragma" anywhere.
    """
    normalized = sql.strip().lower()
    if not normalized.startswith('select '):
        raise InvalidSql('Statement must begin with SELECT')
    if 'pragma' in normalized:
        raise InvalidSql('Statement may not contain PRAGMA')
|
||||||
|
|
@@ -2,7 +2,7 @@
|
||||||
Tests for various datasette helper functions.
|
Tests for various datasette helper functions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datasette import app
|
from datasette import utils
|
||||||
import pytest
|
import pytest
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
@@ -15,7 +15,7 @@ import json
|
||||||
('123%2F433%2F112', ['123/433/112']),
|
('123%2F433%2F112', ['123/433/112']),
|
||||||
])
|
])
|
||||||
def test_compound_pks_from_path(path, expected):
|
def test_compound_pks_from_path(path, expected):
|
||||||
assert expected == app.compound_pks_from_path(path)
|
assert expected == utils.compound_pks_from_path(path)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('row,pks,expected_path', [
|
@pytest.mark.parametrize('row,pks,expected_path', [
|
||||||
|
|
@@ -24,7 +24,7 @@ def test_compound_pks_from_path(path, expected):
|
||||||
({'A': 123}, ['A'], '123'),
|
({'A': 123}, ['A'], '123'),
|
||||||
])
|
])
|
||||||
def test_path_from_row_pks(row, pks, expected_path):
|
def test_path_from_row_pks(row, pks, expected_path):
|
||||||
actual_path = app.path_from_row_pks(row, pks, False)
|
actual_path = utils.path_from_row_pks(row, pks, False)
|
||||||
assert expected_path == actual_path
|
assert expected_path == actual_path
|
||||||
|
|
||||||
|
|
||||||
|
|
@@ -40,7 +40,7 @@ def test_path_from_row_pks(row, pks, expected_path):
|
||||||
def test_custom_json_encoder(obj, expected):
|
def test_custom_json_encoder(obj, expected):
|
||||||
actual = json.dumps(
|
actual = json.dumps(
|
||||||
obj,
|
obj,
|
||||||
cls=app.CustomJSONEncoder,
|
cls=utils.CustomJSONEncoder,
|
||||||
sort_keys=True
|
sort_keys=True
|
||||||
)
|
)
|
||||||
assert expected == actual
|
assert expected == actual
|
||||||
|
|
@@ -90,7 +90,7 @@ def test_custom_json_encoder(obj, expected):
|
||||||
),
|
),
|
||||||
])
|
])
|
||||||
def test_build_where(args, expected_where, expected_params):
|
def test_build_where(args, expected_where, expected_params):
|
||||||
actual_where, actual_params = app.build_where_clause(args)
|
actual_where, actual_params = utils.build_where_clause(args)
|
||||||
assert expected_where == actual_where
|
assert expected_where == actual_where
|
||||||
assert {
|
assert {
|
||||||
'p{}'.format(i): param
|
'p{}'.format(i): param
|
||||||
|
|
@@ -104,8 +104,8 @@ def test_build_where(args, expected_where, expected_params):
|
||||||
"SELECT * FROM pragma_index_info('idx52')",
|
"SELECT * FROM pragma_index_info('idx52')",
|
||||||
])
|
])
|
||||||
def test_validate_sql_select_bad(bad_sql):
|
def test_validate_sql_select_bad(bad_sql):
|
||||||
with pytest.raises(app.InvalidSql):
|
with pytest.raises(utils.InvalidSql):
|
||||||
app.validate_sql_select(bad_sql)
|
utils.validate_sql_select(bad_sql)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('good_sql', [
|
@pytest.mark.parametrize('good_sql', [
|
||||||
|
|
@@ -114,4 +114,4 @@ def test_validate_sql_select_bad(bad_sql):
|
||||||
'select 1 + 1',
|
'select 1 + 1',
|
||||||
])
|
])
|
||||||
def test_validate_sql_select_good(good_sql):
|
def test_validate_sql_select_good(good_sql):
|
||||||
app.validate_sql_select(good_sql)
|
utils.validate_sql_select(good_sql)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue