Extract and refactor filters into filters.py

This will help in implementing __in as a filter, refs #433
This commit is contained in:
Simon Willison 2019-04-15 14:51:20 -07:00
commit 65e913fbbc
5 changed files with 222 additions and 200 deletions

156
datasette/filters.py Normal file
View file

@ -0,0 +1,156 @@
import numbers
from .utils import detect_json1
class Filter:
key = None
display = None
no_argument = False
def where_clause(self, table, column, value, param_counter):
raise NotImplementedError
def human_clause(self, column, value):
raise NotImplementedError
class TemplatedFilter(Filter):
def __init__(self, key, display, sql_template, human_template, format='{}', numeric=False, no_argument=False):
self.key = key
self.display = display
self.sql_template = sql_template
self.human_template = human_template
self.format = format
self.numeric = numeric
self.no_argument = no_argument
def where_clause(self, table, column, value, param_counter):
converted = self.format.format(value)
if self.numeric and converted.isdigit():
converted = int(converted)
if self.no_argument:
kwargs = {
'c': column,
}
converted = None
else:
kwargs = {
'c': column,
'p': 'p{}'.format(param_counter),
't': table,
}
return self.sql_template.format(**kwargs), converted
def human_clause(self, column, value):
if callable(self.human_template):
template = self.human_template(column, value)
else:
template = self.human_template
if self.no_argument:
return template.format(c=column)
else:
return template.format(c=column, v=value)
class Filters:
_filters = [
# key, display, sql_template, human_template, format=, numeric=, no_argument=
TemplatedFilter('exact', '=', '"{c}" = :{p}', lambda c, v: '{c} = {v}' if v.isdigit() else '{c} = "{v}"'),
TemplatedFilter('not', '!=', '"{c}" != :{p}', lambda c, v: '{c} != {v}' if v.isdigit() else '{c} != "{v}"'),
TemplatedFilter('contains', 'contains', '"{c}" like :{p}', '{c} contains "{v}"', format='%{}%'),
TemplatedFilter('endswith', 'ends with', '"{c}" like :{p}', '{c} ends with "{v}"', format='%{}'),
TemplatedFilter('startswith', 'starts with', '"{c}" like :{p}', '{c} starts with "{v}"', format='{}%'),
TemplatedFilter('gt', '>', '"{c}" > :{p}', '{c} > {v}', numeric=True),
TemplatedFilter('gte', '\u2265', '"{c}" >= :{p}', '{c} \u2265 {v}', numeric=True),
TemplatedFilter('lt', '<', '"{c}" < :{p}', '{c} < {v}', numeric=True),
TemplatedFilter('lte', '\u2264', '"{c}" <= :{p}', '{c} \u2264 {v}', numeric=True),
TemplatedFilter('glob', 'glob', '"{c}" glob :{p}', '{c} glob "{v}"'),
TemplatedFilter('like', 'like', '"{c}" like :{p}', '{c} like "{v}"'),
] + ([TemplatedFilter('arraycontains', 'array contains', """rowid in (
select {t}.rowid from {t}, json_each({t}.{c}) j
where j.value = :{p}
)""", '{c} contains "{v}"')
] if detect_json1() else []) + [
TemplatedFilter('isnull', 'is null', '"{c}" is null', '{c} is null', no_argument=True),
TemplatedFilter('notnull', 'is not null', '"{c}" is not null', '{c} is not null', no_argument=True),
TemplatedFilter('isblank', 'is blank', '("{c}" is null or "{c}" = "")', '{c} is blank', no_argument=True),
TemplatedFilter('notblank', 'is not blank', '("{c}" is not null and "{c}" != "")', '{c} is not blank', no_argument=True),
]
_filters_by_key = {
f.key: f for f in _filters
}
def __init__(self, pairs, units={}, ureg=None):
self.pairs = pairs
self.units = units
self.ureg = ureg
def lookups(self):
"Yields (lookup, display, no_argument) pairs"
for filter in self._filters:
yield filter.key, filter.display, filter.no_argument
def human_description_en(self, extra=None):
bits = []
if extra:
bits.extend(extra)
for column, lookup, value in self.selections():
filter = self._filters_by_key.get(lookup, None)
if filter:
bits.append(filter.human_clause(column, value))
# Comma separated, with an ' and ' at the end
and_bits = []
commas, tail = bits[:-1], bits[-1:]
if commas:
and_bits.append(', '.join(commas))
if tail:
and_bits.append(tail[0])
s = ' and '.join(and_bits)
if not s:
return ''
return 'where {}'.format(s)
def selections(self):
"Yields (column, lookup, value) tuples"
for key, value in self.pairs:
if '__' in key:
column, lookup = key.rsplit('__', 1)
else:
column = key
lookup = 'exact'
yield column, lookup, value
def has_selections(self):
return bool(self.pairs)
def convert_unit(self, column, value):
"If the user has provided a unit in the query, convert it into the column unit, if present."
if column not in self.units:
return value
# Try to interpret the value as a unit
value = self.ureg(value)
if isinstance(value, numbers.Number):
# It's just a bare number, assume it's the column unit
return value
column_unit = self.ureg(self.units[column])
return value.to(column_unit).magnitude
def build_where_clauses(self, table):
sql_bits = []
params = {}
i = 0
for column, lookup, value in self.selections():
filter = self._filters_by_key.get(lookup, None)
if filter:
sql_bit, param = filter.where_clause(table, column, self.convert_unit(column, value), i)
sql_bits.append(sql_bit)
if param is not None:
if not isinstance(param, list):
param = [param]
for individual_param in param:
param_id = 'p{}'.format(i)
params[param_id] = individual_param
i += 1
return sql_bits, params

View file

@ -584,143 +584,6 @@ def table_columns(conn, table):
] ]
class Filter:
def __init__(self, key, display, sql_template, human_template, format='{}', numeric=False, no_argument=False):
self.key = key
self.display = display
self.sql_template = sql_template
self.human_template = human_template
self.format = format
self.numeric = numeric
self.no_argument = no_argument
def where_clause(self, table, column, value, param_counter):
converted = self.format.format(value)
if self.numeric and converted.isdigit():
converted = int(converted)
if self.no_argument:
kwargs = {
'c': column,
}
converted = None
else:
kwargs = {
'c': column,
'p': 'p{}'.format(param_counter),
't': table,
}
return self.sql_template.format(**kwargs), converted
def human_clause(self, column, value):
if callable(self.human_template):
template = self.human_template(column, value)
else:
template = self.human_template
if self.no_argument:
return template.format(c=column)
else:
return template.format(c=column, v=value)
class Filters:
_filters = [
# key, display, sql_template, human_template, format=, numeric=, no_argument=
Filter('exact', '=', '"{c}" = :{p}', lambda c, v: '{c} = {v}' if v.isdigit() else '{c} = "{v}"'),
Filter('not', '!=', '"{c}" != :{p}', lambda c, v: '{c} != {v}' if v.isdigit() else '{c} != "{v}"'),
Filter('contains', 'contains', '"{c}" like :{p}', '{c} contains "{v}"', format='%{}%'),
Filter('endswith', 'ends with', '"{c}" like :{p}', '{c} ends with "{v}"', format='%{}'),
Filter('startswith', 'starts with', '"{c}" like :{p}', '{c} starts with "{v}"', format='{}%'),
Filter('gt', '>', '"{c}" > :{p}', '{c} > {v}', numeric=True),
Filter('gte', '\u2265', '"{c}" >= :{p}', '{c} \u2265 {v}', numeric=True),
Filter('lt', '<', '"{c}" < :{p}', '{c} < {v}', numeric=True),
Filter('lte', '\u2264', '"{c}" <= :{p}', '{c} \u2264 {v}', numeric=True),
Filter('glob', 'glob', '"{c}" glob :{p}', '{c} glob "{v}"'),
Filter('like', 'like', '"{c}" like :{p}', '{c} like "{v}"'),
] + ([Filter('arraycontains', 'array contains', """rowid in (
select {t}.rowid from {t}, json_each({t}.{c}) j
where j.value = :{p}
)""", '{c} contains "{v}"')
] if detect_json1() else []) + [
Filter('isnull', 'is null', '"{c}" is null', '{c} is null', no_argument=True),
Filter('notnull', 'is not null', '"{c}" is not null', '{c} is not null', no_argument=True),
Filter('isblank', 'is blank', '("{c}" is null or "{c}" = "")', '{c} is blank', no_argument=True),
Filter('notblank', 'is not blank', '("{c}" is not null and "{c}" != "")', '{c} is not blank', no_argument=True),
]
_filters_by_key = {
f.key: f for f in _filters
}
def __init__(self, pairs, units={}, ureg=None):
self.pairs = pairs
self.units = units
self.ureg = ureg
def lookups(self):
"Yields (lookup, display, no_argument) pairs"
for filter in self._filters:
yield filter.key, filter.display, filter.no_argument
def human_description_en(self, extra=None):
bits = []
if extra:
bits.extend(extra)
for column, lookup, value in self.selections():
filter = self._filters_by_key.get(lookup, None)
if filter:
bits.append(filter.human_clause(column, value))
# Comma separated, with an ' and ' at the end
and_bits = []
commas, tail = bits[:-1], bits[-1:]
if commas:
and_bits.append(', '.join(commas))
if tail:
and_bits.append(tail[0])
s = ' and '.join(and_bits)
if not s:
return ''
return 'where {}'.format(s)
def selections(self):
"Yields (column, lookup, value) tuples"
for key, value in self.pairs:
if '__' in key:
column, lookup = key.rsplit('__', 1)
else:
column = key
lookup = 'exact'
yield column, lookup, value
def has_selections(self):
return bool(self.pairs)
def convert_unit(self, column, value):
"If the user has provided a unit in the query, convert it into the column unit, if present."
if column not in self.units:
return value
# Try to interpret the value as a unit
value = self.ureg(value)
if isinstance(value, numbers.Number):
# It's just a bare number, assume it's the column unit
return value
column_unit = self.ureg(self.units[column])
return value.to(column_unit).magnitude
def build_where_clauses(self, table):
sql_bits = []
params = {}
for i, (column, lookup, value) in enumerate(self.selections()):
filter = self._filters_by_key.get(lookup, None)
if filter:
sql_bit, param = filter.where_clause(table, column, self.convert_unit(column, value), i)
sql_bits.append(sql_bit)
if param is not None:
param_id = 'p{}'.format(i)
params[param_id] = param
return sql_bits, params
filter_column_re = re.compile(r'^_filter_column_\d+$') filter_column_re = re.compile(r'^_filter_column_\d+$')

View file

@ -9,7 +9,6 @@ from datasette.facets import load_facet_configs
from datasette.plugins import pm from datasette.plugins import pm
from datasette.utils import ( from datasette.utils import (
CustomRow, CustomRow,
Filters,
InterruptedError, InterruptedError,
append_querystring, append_querystring,
compound_keys_after_sql, compound_keys_after_sql,
@ -29,6 +28,7 @@ from datasette.utils import (
urlsafe_components, urlsafe_components,
value_as_boolean, value_as_boolean,
) )
from datasette.filters import Filters
from .base import BaseView, DatasetteError, ureg from .base import BaseView, DatasetteError, ureg
LINK_WITH_LABEL = '<a href="/{database}/{table}/{link_id}">{label}</a>&nbsp;<em>{id}</em>' LINK_WITH_LABEL = '<a href="/{database}/{table}/{link_id}">{label}</a>&nbsp;<em>{id}</em>'

64
tests/test_filters.py Normal file
View file

@ -0,0 +1,64 @@
from datasette.filters import Filters
import pytest
@pytest.mark.parametrize('args,expected_where,expected_params', [
(
{
'name_english__contains': 'foo',
},
['"name_english" like :p0'],
['%foo%']
),
(
{
'foo': 'bar',
'bar__contains': 'baz',
},
['"bar" like :p0', '"foo" = :p1'],
['%baz%', 'bar']
),
(
{
'foo__startswith': 'bar',
'bar__endswith': 'baz',
},
['"bar" like :p0', '"foo" like :p1'],
['%baz', 'bar%']
),
(
{
'foo__lt': '1',
'bar__gt': '2',
'baz__gte': '3',
'bax__lte': '4',
},
['"bar" > :p0', '"bax" <= :p1', '"baz" >= :p2', '"foo" < :p3'],
[2, 4, 3, 1]
),
(
{
'foo__like': '2%2',
'zax__glob': '3*',
},
['"foo" like :p0', '"zax" glob :p1'],
['2%2', '3*']
),
(
{
'foo__isnull': '1',
'baz__isnull': '1',
'bar__gt': '10'
},
['"bar" > :p0', '"baz" is null', '"foo" is null'],
[10]
),
])
def test_build_where(args, expected_where, expected_params):
f = Filters(sorted(args.items()))
sql_bits, actual_params = f.build_where_clauses("table")
assert expected_where == sql_bits
assert {
'p{}'.format(i): param
for i, param in enumerate(expected_params)
} == actual_params

View file

@ -3,6 +3,7 @@ Tests for various datasette helper functions.
""" """
from datasette import utils from datasette import utils
from datasette.filters import Filters
import json import json
import os import os
import pytest import pytest
@ -133,68 +134,6 @@ def test_custom_json_encoder(obj, expected):
assert expected == actual assert expected == actual
@pytest.mark.parametrize('args,expected_where,expected_params', [
(
{
'name_english__contains': 'foo',
},
['"name_english" like :p0'],
['%foo%']
),
(
{
'foo': 'bar',
'bar__contains': 'baz',
},
['"bar" like :p0', '"foo" = :p1'],
['%baz%', 'bar']
),
(
{
'foo__startswith': 'bar',
'bar__endswith': 'baz',
},
['"bar" like :p0', '"foo" like :p1'],
['%baz', 'bar%']
),
(
{
'foo__lt': '1',
'bar__gt': '2',
'baz__gte': '3',
'bax__lte': '4',
},
['"bar" > :p0', '"bax" <= :p1', '"baz" >= :p2', '"foo" < :p3'],
[2, 4, 3, 1]
),
(
{
'foo__like': '2%2',
'zax__glob': '3*',
},
['"foo" like :p0', '"zax" glob :p1'],
['2%2', '3*']
),
(
{
'foo__isnull': '1',
'baz__isnull': '1',
'bar__gt': '10'
},
['"bar" > :p0', '"baz" is null', '"foo" is null'],
[10]
),
])
def test_build_where(args, expected_where, expected_params):
f = utils.Filters(sorted(args.items()))
sql_bits, actual_params = f.build_where_clauses("table")
assert expected_where == sql_bits
assert {
'p{}'.format(i): param
for i, param in enumerate(expected_params)
} == actual_params
@pytest.mark.parametrize('bad_sql', [ @pytest.mark.parametrize('bad_sql', [
'update blah;', 'update blah;',
'PRAGMA case_sensitive_like = true' 'PRAGMA case_sensitive_like = true'