diff --git a/datasette/filters.py b/datasette/filters.py
new file mode 100644
index 00000000..5fd722f3
--- /dev/null
+++ b/datasette/filters.py
@@ -0,0 +1,156 @@
+import numbers
+from .utils import detect_json1
+
+
+class Filter:
+ key = None
+ display = None
+ no_argument = False
+
+ def where_clause(self, table, column, value, param_counter):
+ raise NotImplementedError
+
+ def human_clause(self, column, value):
+ raise NotImplementedError
+
+
+class TemplatedFilter(Filter):
+ def __init__(self, key, display, sql_template, human_template, format='{}', numeric=False, no_argument=False):
+ self.key = key
+ self.display = display
+ self.sql_template = sql_template
+ self.human_template = human_template
+ self.format = format
+ self.numeric = numeric
+ self.no_argument = no_argument
+
+ def where_clause(self, table, column, value, param_counter):
+ converted = self.format.format(value)
+ if self.numeric and converted.isdigit():
+ converted = int(converted)
+ if self.no_argument:
+ kwargs = {
+ 'c': column,
+ }
+ converted = None
+ else:
+ kwargs = {
+ 'c': column,
+ 'p': 'p{}'.format(param_counter),
+ 't': table,
+ }
+ return self.sql_template.format(**kwargs), converted
+
+ def human_clause(self, column, value):
+ if callable(self.human_template):
+ template = self.human_template(column, value)
+ else:
+ template = self.human_template
+ if self.no_argument:
+ return template.format(c=column)
+ else:
+ return template.format(c=column, v=value)
+
+
+class Filters:
+ _filters = [
+ # key, display, sql_template, human_template, format=, numeric=, no_argument=
+ TemplatedFilter('exact', '=', '"{c}" = :{p}', lambda c, v: '{c} = {v}' if v.isdigit() else '{c} = "{v}"'),
+ TemplatedFilter('not', '!=', '"{c}" != :{p}', lambda c, v: '{c} != {v}' if v.isdigit() else '{c} != "{v}"'),
+ TemplatedFilter('contains', 'contains', '"{c}" like :{p}', '{c} contains "{v}"', format='%{}%'),
+ TemplatedFilter('endswith', 'ends with', '"{c}" like :{p}', '{c} ends with "{v}"', format='%{}'),
+ TemplatedFilter('startswith', 'starts with', '"{c}" like :{p}', '{c} starts with "{v}"', format='{}%'),
+ TemplatedFilter('gt', '>', '"{c}" > :{p}', '{c} > {v}', numeric=True),
+ TemplatedFilter('gte', '\u2265', '"{c}" >= :{p}', '{c} \u2265 {v}', numeric=True),
+ TemplatedFilter('lt', '<', '"{c}" < :{p}', '{c} < {v}', numeric=True),
+ TemplatedFilter('lte', '\u2264', '"{c}" <= :{p}', '{c} \u2264 {v}', numeric=True),
+ TemplatedFilter('glob', 'glob', '"{c}" glob :{p}', '{c} glob "{v}"'),
+ TemplatedFilter('like', 'like', '"{c}" like :{p}', '{c} like "{v}"'),
+ ] + ([TemplatedFilter('arraycontains', 'array contains', """rowid in (
+ select {t}.rowid from {t}, json_each({t}.{c}) j
+ where j.value = :{p}
+ )""", '{c} contains "{v}"')
+ ] if detect_json1() else []) + [
+ TemplatedFilter('isnull', 'is null', '"{c}" is null', '{c} is null', no_argument=True),
+ TemplatedFilter('notnull', 'is not null', '"{c}" is not null', '{c} is not null', no_argument=True),
+ TemplatedFilter('isblank', 'is blank', '("{c}" is null or "{c}" = "")', '{c} is blank', no_argument=True),
+ TemplatedFilter('notblank', 'is not blank', '("{c}" is not null and "{c}" != "")', '{c} is not blank', no_argument=True),
+ ]
+ _filters_by_key = {
+ f.key: f for f in _filters
+ }
+
+ def __init__(self, pairs, units={}, ureg=None):
+ self.pairs = pairs
+ self.units = units
+ self.ureg = ureg
+
+ def lookups(self):
+ "Yields (lookup, display, no_argument) pairs"
+ for filter in self._filters:
+ yield filter.key, filter.display, filter.no_argument
+
+ def human_description_en(self, extra=None):
+ bits = []
+ if extra:
+ bits.extend(extra)
+ for column, lookup, value in self.selections():
+ filter = self._filters_by_key.get(lookup, None)
+ if filter:
+ bits.append(filter.human_clause(column, value))
+ # Comma separated, with an ' and ' at the end
+ and_bits = []
+ commas, tail = bits[:-1], bits[-1:]
+ if commas:
+ and_bits.append(', '.join(commas))
+ if tail:
+ and_bits.append(tail[0])
+ s = ' and '.join(and_bits)
+ if not s:
+ return ''
+ return 'where {}'.format(s)
+
+ def selections(self):
+ "Yields (column, lookup, value) tuples"
+ for key, value in self.pairs:
+ if '__' in key:
+ column, lookup = key.rsplit('__', 1)
+ else:
+ column = key
+ lookup = 'exact'
+ yield column, lookup, value
+
+ def has_selections(self):
+ return bool(self.pairs)
+
+ def convert_unit(self, column, value):
+ "If the user has provided a unit in the query, convert it into the column unit, if present."
+ if column not in self.units:
+ return value
+
+ # Try to interpret the value as a unit
+ value = self.ureg(value)
+ if isinstance(value, numbers.Number):
+ # It's just a bare number, assume it's the column unit
+ return value
+
+ column_unit = self.ureg(self.units[column])
+ return value.to(column_unit).magnitude
+
+ def build_where_clauses(self, table):
+ sql_bits = []
+ params = {}
+ i = 0
+ for column, lookup, value in self.selections():
+ filter = self._filters_by_key.get(lookup, None)
+ if filter:
+ sql_bit, param = filter.where_clause(table, column, self.convert_unit(column, value), i)
+ sql_bits.append(sql_bit)
+ if param is not None:
+ if not isinstance(param, list):
+ param = [param]
+ for individual_param in param:
+ param_id = 'p{}'.format(i)
+ params[param_id] = individual_param
+ i += 1
+ return sql_bits, params
diff --git a/datasette/utils.py b/datasette/utils.py
index bb5c17d6..0c161ac6 100644
--- a/datasette/utils.py
+++ b/datasette/utils.py
@@ -584,143 +584,6 @@ def table_columns(conn, table):
]
-class Filter:
- def __init__(self, key, display, sql_template, human_template, format='{}', numeric=False, no_argument=False):
- self.key = key
- self.display = display
- self.sql_template = sql_template
- self.human_template = human_template
- self.format = format
- self.numeric = numeric
- self.no_argument = no_argument
-
- def where_clause(self, table, column, value, param_counter):
- converted = self.format.format(value)
- if self.numeric and converted.isdigit():
- converted = int(converted)
- if self.no_argument:
- kwargs = {
- 'c': column,
- }
- converted = None
- else:
- kwargs = {
- 'c': column,
- 'p': 'p{}'.format(param_counter),
- 't': table,
- }
- return self.sql_template.format(**kwargs), converted
-
- def human_clause(self, column, value):
- if callable(self.human_template):
- template = self.human_template(column, value)
- else:
- template = self.human_template
- if self.no_argument:
- return template.format(c=column)
- else:
- return template.format(c=column, v=value)
-
-
-class Filters:
- _filters = [
- # key, display, sql_template, human_template, format=, numeric=, no_argument=
- Filter('exact', '=', '"{c}" = :{p}', lambda c, v: '{c} = {v}' if v.isdigit() else '{c} = "{v}"'),
- Filter('not', '!=', '"{c}" != :{p}', lambda c, v: '{c} != {v}' if v.isdigit() else '{c} != "{v}"'),
- Filter('contains', 'contains', '"{c}" like :{p}', '{c} contains "{v}"', format='%{}%'),
- Filter('endswith', 'ends with', '"{c}" like :{p}', '{c} ends with "{v}"', format='%{}'),
- Filter('startswith', 'starts with', '"{c}" like :{p}', '{c} starts with "{v}"', format='{}%'),
- Filter('gt', '>', '"{c}" > :{p}', '{c} > {v}', numeric=True),
- Filter('gte', '\u2265', '"{c}" >= :{p}', '{c} \u2265 {v}', numeric=True),
- Filter('lt', '<', '"{c}" < :{p}', '{c} < {v}', numeric=True),
- Filter('lte', '\u2264', '"{c}" <= :{p}', '{c} \u2264 {v}', numeric=True),
- Filter('glob', 'glob', '"{c}" glob :{p}', '{c} glob "{v}"'),
- Filter('like', 'like', '"{c}" like :{p}', '{c} like "{v}"'),
- ] + ([Filter('arraycontains', 'array contains', """rowid in (
- select {t}.rowid from {t}, json_each({t}.{c}) j
- where j.value = :{p}
- )""", '{c} contains "{v}"')
- ] if detect_json1() else []) + [
- Filter('isnull', 'is null', '"{c}" is null', '{c} is null', no_argument=True),
- Filter('notnull', 'is not null', '"{c}" is not null', '{c} is not null', no_argument=True),
- Filter('isblank', 'is blank', '("{c}" is null or "{c}" = "")', '{c} is blank', no_argument=True),
- Filter('notblank', 'is not blank', '("{c}" is not null and "{c}" != "")', '{c} is not blank', no_argument=True),
- ]
- _filters_by_key = {
- f.key: f for f in _filters
- }
-
- def __init__(self, pairs, units={}, ureg=None):
- self.pairs = pairs
- self.units = units
- self.ureg = ureg
-
- def lookups(self):
- "Yields (lookup, display, no_argument) pairs"
- for filter in self._filters:
- yield filter.key, filter.display, filter.no_argument
-
- def human_description_en(self, extra=None):
- bits = []
- if extra:
- bits.extend(extra)
- for column, lookup, value in self.selections():
- filter = self._filters_by_key.get(lookup, None)
- if filter:
- bits.append(filter.human_clause(column, value))
- # Comma separated, with an ' and ' at the end
- and_bits = []
- commas, tail = bits[:-1], bits[-1:]
- if commas:
- and_bits.append(', '.join(commas))
- if tail:
- and_bits.append(tail[0])
- s = ' and '.join(and_bits)
- if not s:
- return ''
- return 'where {}'.format(s)
-
- def selections(self):
- "Yields (column, lookup, value) tuples"
- for key, value in self.pairs:
- if '__' in key:
- column, lookup = key.rsplit('__', 1)
- else:
- column = key
- lookup = 'exact'
- yield column, lookup, value
-
- def has_selections(self):
- return bool(self.pairs)
-
- def convert_unit(self, column, value):
- "If the user has provided a unit in the query, convert it into the column unit, if present."
- if column not in self.units:
- return value
-
- # Try to interpret the value as a unit
- value = self.ureg(value)
- if isinstance(value, numbers.Number):
- # It's just a bare number, assume it's the column unit
- return value
-
- column_unit = self.ureg(self.units[column])
- return value.to(column_unit).magnitude
-
- def build_where_clauses(self, table):
- sql_bits = []
- params = {}
- for i, (column, lookup, value) in enumerate(self.selections()):
- filter = self._filters_by_key.get(lookup, None)
- if filter:
- sql_bit, param = filter.where_clause(table, column, self.convert_unit(column, value), i)
- sql_bits.append(sql_bit)
- if param is not None:
- param_id = 'p{}'.format(i)
- params[param_id] = param
- return sql_bits, params
-
-
filter_column_re = re.compile(r'^_filter_column_\d+$')
diff --git a/datasette/views/table.py b/datasette/views/table.py
index 5923ac92..2c356bda 100644
--- a/datasette/views/table.py
+++ b/datasette/views/table.py
@@ -7,7 +7,6 @@ from sanic.request import RequestParameters
from datasette.plugins import pm
from datasette.utils import (
CustomRow,
- Filters,
InterruptedError,
append_querystring,
compound_keys_after_sql,
@@ -27,6 +26,7 @@ from datasette.utils import (
urlsafe_components,
value_as_boolean,
)
+from datasette.filters import Filters
from .base import BaseView, DatasetteError, ureg
LINK_WITH_LABEL = '{label} {id}'
diff --git a/tests/test_filters.py b/tests/test_filters.py
new file mode 100644
index 00000000..b0cb3f34
--- /dev/null
+++ b/tests/test_filters.py
@@ -0,0 +1,64 @@
+from datasette.filters import Filters
+import pytest
+
+
+@pytest.mark.parametrize('args,expected_where,expected_params', [
+ (
+ {
+ 'name_english__contains': 'foo',
+ },
+ ['"name_english" like :p0'],
+ ['%foo%']
+ ),
+ (
+ {
+ 'foo': 'bar',
+ 'bar__contains': 'baz',
+ },
+ ['"bar" like :p0', '"foo" = :p1'],
+ ['%baz%', 'bar']
+ ),
+ (
+ {
+ 'foo__startswith': 'bar',
+ 'bar__endswith': 'baz',
+ },
+ ['"bar" like :p0', '"foo" like :p1'],
+ ['%baz', 'bar%']
+ ),
+ (
+ {
+ 'foo__lt': '1',
+ 'bar__gt': '2',
+ 'baz__gte': '3',
+ 'bax__lte': '4',
+ },
+ ['"bar" > :p0', '"bax" <= :p1', '"baz" >= :p2', '"foo" < :p3'],
+ [2, 4, 3, 1]
+ ),
+ (
+ {
+ 'foo__like': '2%2',
+ 'zax__glob': '3*',
+ },
+ ['"foo" like :p0', '"zax" glob :p1'],
+ ['2%2', '3*']
+ ),
+ (
+ {
+ 'foo__isnull': '1',
+ 'baz__isnull': '1',
+ 'bar__gt': '10'
+ },
+ ['"bar" > :p0', '"baz" is null', '"foo" is null'],
+ [10]
+ ),
+])
+def test_build_where(args, expected_where, expected_params):
+ f = Filters(sorted(args.items()))
+ sql_bits, actual_params = f.build_where_clauses("table")
+ assert expected_where == sql_bits
+ assert {
+ 'p{}'.format(i): param
+ for i, param in enumerate(expected_params)
+ } == actual_params
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 07074e72..1ca202f4 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -3,6 +3,7 @@ Tests for various datasette helper functions.
"""
from datasette import utils
+from datasette.filters import Filters
import json
import os
import pytest
@@ -133,68 +134,6 @@ def test_custom_json_encoder(obj, expected):
assert expected == actual
-@pytest.mark.parametrize('args,expected_where,expected_params', [
- (
- {
- 'name_english__contains': 'foo',
- },
- ['"name_english" like :p0'],
- ['%foo%']
- ),
- (
- {
- 'foo': 'bar',
- 'bar__contains': 'baz',
- },
- ['"bar" like :p0', '"foo" = :p1'],
- ['%baz%', 'bar']
- ),
- (
- {
- 'foo__startswith': 'bar',
- 'bar__endswith': 'baz',
- },
- ['"bar" like :p0', '"foo" like :p1'],
- ['%baz', 'bar%']
- ),
- (
- {
- 'foo__lt': '1',
- 'bar__gt': '2',
- 'baz__gte': '3',
- 'bax__lte': '4',
- },
- ['"bar" > :p0', '"bax" <= :p1', '"baz" >= :p2', '"foo" < :p3'],
- [2, 4, 3, 1]
- ),
- (
- {
- 'foo__like': '2%2',
- 'zax__glob': '3*',
- },
- ['"foo" like :p0', '"zax" glob :p1'],
- ['2%2', '3*']
- ),
- (
- {
- 'foo__isnull': '1',
- 'baz__isnull': '1',
- 'bar__gt': '10'
- },
- ['"bar" > :p0', '"baz" is null', '"foo" is null'],
- [10]
- ),
-])
-def test_build_where(args, expected_where, expected_params):
- f = utils.Filters(sorted(args.items()))
- sql_bits, actual_params = f.build_where_clauses("table")
- assert expected_where == sql_bits
- assert {
- 'p{}'.format(i): param
- for i, param in enumerate(expected_params)
- } == actual_params
-
-
@pytest.mark.parametrize('bad_sql', [
'update blah;',
'PRAGMA case_sensitive_like = true'