diff --git a/datasette/filters.py b/datasette/filters.py new file mode 100644 index 00000000..5fd722f3 --- /dev/null +++ b/datasette/filters.py @@ -0,0 +1,156 @@ +import numbers +from .utils import detect_json1 + + +class Filter: + key = None + display = None + no_argument = False + + def where_clause(self, table, column, value, param_counter): + raise NotImplementedError + + def human_clause(self, column, value): + raise NotImplementedError + + +class TemplatedFilter(Filter): + def __init__(self, key, display, sql_template, human_template, format='{}', numeric=False, no_argument=False): + self.key = key + self.display = display + self.sql_template = sql_template + self.human_template = human_template + self.format = format + self.numeric = numeric + self.no_argument = no_argument + + def where_clause(self, table, column, value, param_counter): + converted = self.format.format(value) + if self.numeric and converted.isdigit(): + converted = int(converted) + if self.no_argument: + kwargs = { + 'c': column, + } + converted = None + else: + kwargs = { + 'c': column, + 'p': 'p{}'.format(param_counter), + 't': table, + } + return self.sql_template.format(**kwargs), converted + + def human_clause(self, column, value): + if callable(self.human_template): + template = self.human_template(column, value) + else: + template = self.human_template + if self.no_argument: + return template.format(c=column) + else: + return template.format(c=column, v=value) + + +class Filters: + _filters = [ + # key, display, sql_template, human_template, format=, numeric=, no_argument= + TemplatedFilter('exact', '=', '"{c}" = :{p}', lambda c, v: '{c} = {v}' if v.isdigit() else '{c} = "{v}"'), + TemplatedFilter('not', '!=', '"{c}" != :{p}', lambda c, v: '{c} != {v}' if v.isdigit() else '{c} != "{v}"'), + TemplatedFilter('contains', 'contains', '"{c}" like :{p}', '{c} contains "{v}"', format='%{}%'), + TemplatedFilter('endswith', 'ends with', '"{c}" like :{p}', '{c} ends with "{v}"', format='%{}'), + TemplatedFilter('startswith', 'starts with', '"{c}" like :{p}', '{c} starts with "{v}"', format='{}%'), + TemplatedFilter('gt', '>', '"{c}" > :{p}', '{c} > {v}', numeric=True), + TemplatedFilter('gte', '\u2265', '"{c}" >= :{p}', '{c} \u2265 {v}', numeric=True), + TemplatedFilter('lt', '<', '"{c}" < :{p}', '{c} < {v}', numeric=True), + TemplatedFilter('lte', '\u2264', '"{c}" <= :{p}', '{c} \u2264 {v}', numeric=True), + TemplatedFilter('glob', 'glob', '"{c}" glob :{p}', '{c} glob "{v}"'), + TemplatedFilter('like', 'like', '"{c}" like :{p}', '{c} like "{v}"'), + ] + ([TemplatedFilter('arraycontains', 'array contains', """rowid in ( + select {t}.rowid from {t}, json_each({t}.{c}) j + where j.value = :{p} + )""", '{c} contains "{v}"') + ] if detect_json1() else []) + [ + TemplatedFilter('isnull', 'is null', '"{c}" is null', '{c} is null', no_argument=True), + TemplatedFilter('notnull', 'is not null', '"{c}" is not null', '{c} is not null', no_argument=True), + TemplatedFilter('isblank', 'is blank', '("{c}" is null or "{c}" = "")', '{c} is blank', no_argument=True), + TemplatedFilter('notblank', 'is not blank', '("{c}" is not null and "{c}" != "")', '{c} is not blank', no_argument=True), + ] + _filters_by_key = { + f.key: f for f in _filters + } + + def __init__(self, pairs, units={}, ureg=None): + self.pairs = pairs + self.units = units + self.ureg = ureg + + def lookups(self): + "Yields (lookup, display, no_argument) pairs" + for filter in self._filters: + yield filter.key, filter.display, filter.no_argument + + def human_description_en(self, extra=None): + bits = [] + if extra: + bits.extend(extra) + for column, lookup, value in self.selections(): + filter = self._filters_by_key.get(lookup, None) + if filter: + bits.append(filter.human_clause(column, value)) + # Comma separated, with an ' and ' at the end + and_bits = [] + commas, tail = bits[:-1], bits[-1:] + if commas: + and_bits.append(', '.join(commas)) + if tail: + and_bits.append(tail[0]) + s = ' and '.join(and_bits) + if not s: + return '' + return 'where {}'.format(s) + + def selections(self): + "Yields (column, lookup, value) tuples" + for key, value in self.pairs: + if '__' in key: + column, lookup = key.rsplit('__', 1) + else: + column = key + lookup = 'exact' + yield column, lookup, value + + def has_selections(self): + return bool(self.pairs) + + def convert_unit(self, column, value): + "If the user has provided a unit in the query, convert it into the column unit, if present." + if column not in self.units: + return value + + # Try to interpret the value as a unit + value = self.ureg(value) + if isinstance(value, numbers.Number): + # It's just a bare number, assume it's the column unit + return value + + column_unit = self.ureg(self.units[column]) + return value.to(column_unit).magnitude + + def build_where_clauses(self, table): + sql_bits = [] + params = {} + i = 0 + for column, lookup, value in self.selections(): + filter = self._filters_by_key.get(lookup, None) + if filter: + sql_bit, param = filter.where_clause(table, column, self.convert_unit(column, value), i) + sql_bits.append(sql_bit) + if param is not None: + if not isinstance(param, list): + param = [param] + for individual_param in param: + param_id = 'p{}'.format(i) + params[param_id] = individual_param + i += 1 + return sql_bits, params diff --git a/datasette/utils.py b/datasette/utils.py index bb5c17d6..0c161ac6 100644 --- a/datasette/utils.py +++ b/datasette/utils.py @@ -584,143 +584,6 @@ def table_columns(conn, table): ] -class Filter: - def __init__(self, key, display, sql_template, human_template, format='{}', numeric=False, no_argument=False): - self.key = key - self.display = display - self.sql_template = sql_template - self.human_template = human_template - self.format = format - self.numeric = numeric - self.no_argument = no_argument - - def where_clause(self, table, column, value, param_counter): - converted = self.format.format(value) - if self.numeric and converted.isdigit(): - converted = int(converted) - if self.no_argument: - kwargs = { - 'c': column, - } - converted = None - else: - kwargs = { - 'c': column, - 'p': 'p{}'.format(param_counter), - 't': table, - } - return self.sql_template.format(**kwargs), converted - - def human_clause(self, column, value): - if callable(self.human_template): - template = self.human_template(column, value) - else: - template = self.human_template - if self.no_argument: - return template.format(c=column) - else: - return template.format(c=column, v=value) - - -class Filters: - _filters = [ - # key, display, sql_template, human_template, format=, numeric=, no_argument= - Filter('exact', '=', '"{c}" = :{p}', lambda c, v: '{c} = {v}' if v.isdigit() else '{c} = "{v}"'), - Filter('not', '!=', '"{c}" != :{p}', lambda c, v: '{c} != {v}' if v.isdigit() else '{c} != "{v}"'), - Filter('contains', 'contains', '"{c}" like :{p}', '{c} contains "{v}"', format='%{}%'), - Filter('endswith', 'ends with', '"{c}" like :{p}', '{c} ends with "{v}"', format='%{}'), - Filter('startswith', 'starts with', '"{c}" like :{p}', '{c} starts with "{v}"', format='{}%'), - Filter('gt', '>', '"{c}" > :{p}', '{c} > {v}', numeric=True), - Filter('gte', '\u2265', '"{c}" >= :{p}', '{c} \u2265 {v}', numeric=True), - Filter('lt', '<', '"{c}" < :{p}', '{c} < {v}', numeric=True), - Filter('lte', '\u2264', '"{c}" <= :{p}', '{c} \u2264 {v}', numeric=True), - Filter('glob', 'glob', '"{c}" glob :{p}', '{c} glob "{v}"'), - Filter('like', 'like', '"{c}" like :{p}', '{c} like "{v}"'), - ] + ([Filter('arraycontains', 'array contains', """rowid in ( - select {t}.rowid from {t}, json_each({t}.{c}) j - where j.value = :{p} - )""", '{c} contains "{v}"') - ] if detect_json1() else []) + [ - Filter('isnull', 'is null', '"{c}" is null', '{c} is null', no_argument=True), - Filter('notnull', 'is not null', '"{c}" is not null', '{c} is not null', no_argument=True), - Filter('isblank', 'is blank', '("{c}" is null or "{c}" = "")', '{c} is blank', no_argument=True), - Filter('notblank', 'is not blank', '("{c}" is not null and "{c}" != "")', '{c} is not blank', no_argument=True), - ] - _filters_by_key = { - f.key: f for f in _filters - } - - def __init__(self, pairs, units={}, ureg=None): - self.pairs = pairs - self.units = units - self.ureg = ureg - - def lookups(self): - "Yields (lookup, display, no_argument) pairs" - for filter in self._filters: - yield filter.key, filter.display, filter.no_argument - - def human_description_en(self, extra=None): - bits = [] - if extra: - bits.extend(extra) - for column, lookup, value in self.selections(): - filter = self._filters_by_key.get(lookup, None) - if filter: - bits.append(filter.human_clause(column, value)) - # Comma separated, with an ' and ' at the end - and_bits = [] - commas, tail = bits[:-1], bits[-1:] - if commas: - and_bits.append(', '.join(commas)) - if tail: - and_bits.append(tail[0]) - s = ' and '.join(and_bits) - if not s: - return '' - return 'where {}'.format(s) - - def selections(self): - "Yields (column, lookup, value) tuples" - for key, value in self.pairs: - if '__' in key: - column, lookup = key.rsplit('__', 1) - else: - column = key - lookup = 'exact' - yield column, lookup, value - - def has_selections(self): - return bool(self.pairs) - - def convert_unit(self, column, value): - "If the user has provided a unit in the query, convert it into the column unit, if present." - if column not in self.units: - return value - - # Try to interpret the value as a unit - value = self.ureg(value) - if isinstance(value, numbers.Number): - # It's just a bare number, assume it's the column unit - return value - - column_unit = self.ureg(self.units[column]) - return value.to(column_unit).magnitude - - def build_where_clauses(self, table): - sql_bits = [] - params = {} - for i, (column, lookup, value) in enumerate(self.selections()): - filter = self._filters_by_key.get(lookup, None) - if filter: - sql_bit, param = filter.where_clause(table, column, self.convert_unit(column, value), i) - sql_bits.append(sql_bit) - if param is not None: - param_id = 'p{}'.format(i) - params[param_id] = param - return sql_bits, params - - filter_column_re = re.compile(r'^_filter_column_\d+$') diff --git a/datasette/views/table.py b/datasette/views/table.py index 5923ac92..2c356bda 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -7,7 +7,6 @@ from sanic.request import RequestParameters from datasette.plugins import pm from datasette.utils import ( CustomRow, - Filters, InterruptedError, append_querystring, compound_keys_after_sql, @@ -27,6 +26,7 @@ from datasette.utils import ( urlsafe_components, value_as_boolean, ) +from datasette.filters import Filters from .base import BaseView, DatasetteError, ureg LINK_WITH_LABEL = '{label} {id}' diff --git a/tests/test_filters.py b/tests/test_filters.py new file mode 100644 index 00000000..b0cb3f34 --- /dev/null +++ b/tests/test_filters.py @@ -0,0 +1,64 @@ +from datasette.filters import Filters +import pytest + + +@pytest.mark.parametrize('args,expected_where,expected_params', [ + ( + { + 'name_english__contains': 'foo', + }, + ['"name_english" like :p0'], + ['%foo%'] + ), + ( + { + 'foo': 'bar', + 'bar__contains': 'baz', + }, + ['"bar" like :p0', '"foo" = :p1'], + ['%baz%', 'bar'] + ), + ( + { + 'foo__startswith': 'bar', + 'bar__endswith': 'baz', + }, + ['"bar" like :p0', '"foo" like :p1'], + ['%baz', 'bar%'] + ), + ( + { + 'foo__lt': '1', + 'bar__gt': '2', + 'baz__gte': '3', + 'bax__lte': '4', + }, + ['"bar" > :p0', '"bax" <= :p1', '"baz" >= :p2', '"foo" < :p3'], + [2, 4, 3, 1] + ), + ( + { + 'foo__like': '2%2', + 'zax__glob': '3*', + }, + ['"foo" like :p0', '"zax" glob :p1'], + ['2%2', '3*'] + ), + ( + { + 'foo__isnull': '1', + 'baz__isnull': '1', + 'bar__gt': '10' + }, + ['"bar" > :p0', '"baz" is null', '"foo" is null'], + [10] + ), +]) +def test_build_where(args, expected_where, expected_params): + f = Filters(sorted(args.items())) + sql_bits, actual_params = f.build_where_clauses("table") + assert expected_where == sql_bits + assert { + 'p{}'.format(i): param + for i, param in enumerate(expected_params) + } == actual_params diff --git a/tests/test_utils.py b/tests/test_utils.py index 07074e72..1ca202f4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,6 +3,7 @@ Tests for various datasette helper functions. """ from datasette import utils +from datasette.filters import Filters import json import os import pytest @@ -133,68 +134,6 @@ def test_custom_json_encoder(obj, expected): assert expected == actual -@pytest.mark.parametrize('args,expected_where,expected_params', [ - ( - { - 'name_english__contains': 'foo', - }, - ['"name_english" like :p0'], - ['%foo%'] - ), - ( - { - 'foo': 'bar', - 'bar__contains': 'baz', - }, - ['"bar" like :p0', '"foo" = :p1'], - ['%baz%', 'bar'] - ), - ( - { - 'foo__startswith': 'bar', - 'bar__endswith': 'baz', - }, - ['"bar" like :p0', '"foo" like :p1'], - ['%baz', 'bar%'] - ), - ( - { - 'foo__lt': '1', - 'bar__gt': '2', - 'baz__gte': '3', - 'bax__lte': '4', - }, - ['"bar" > :p0', '"bax" <= :p1', '"baz" >= :p2', '"foo" < :p3'], - [2, 4, 3, 1] - ), - ( - { - 'foo__like': '2%2', - 'zax__glob': '3*', - }, - ['"foo" like :p0', '"zax" glob :p1'], - ['2%2', '3*'] - ), - ( - { - 'foo__isnull': '1', - 'baz__isnull': '1', - 'bar__gt': '10' - }, - ['"bar" > :p0', '"baz" is null', '"foo" is null'], - [10] - ), -]) -def test_build_where(args, expected_where, expected_params): - f = utils.Filters(sorted(args.items())) - sql_bits, actual_params = f.build_where_clauses("table") - assert expected_where == sql_bits - assert { - 'p{}'.format(i): param - for i, param in enumerate(expected_params) - } == actual_params - - @pytest.mark.parametrize('bad_sql', [ 'update blah;', 'PRAGMA case_sensitive_like = true'