From 78e45ead4d771007c57b307edf8fc920101f8733 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 10 Apr 2019 08:17:19 -0700 Subject: [PATCH] New ?tags__arraycontains=tag lookup against JSON fields Part one of supporting facet-by-JSON-array, refs #359 --- datasette/utils.py | 25 +++++++++++++++++++++---- datasette/views/table.py | 2 +- tests/fixtures.py | 33 +++++++++++++++++---------------- tests/test_api.py | 21 ++++++++++++++++++--- tests/test_csv.py | 32 ++++++++++++++++---------------- tests/test_utils.py | 2 +- 6 files changed, 74 insertions(+), 41 deletions(-) diff --git a/datasette/utils.py b/datasette/utils.py index c62e0713..bb5c17d6 100644 --- a/datasette/utils.py +++ b/datasette/utils.py @@ -565,6 +565,16 @@ def detect_fts_sql(table): '''.format(table=table) +def detect_json1(conn=None): + if conn is None: + conn = sqlite3.connect(":memory:") + try: + conn.execute("SELECT json('{}')") + return True + except Exception: + return False + + def table_columns(conn, table): return [ r[1] @@ -584,7 +594,7 @@ class Filter: self.numeric = numeric self.no_argument = no_argument - def where_clause(self, column, value, param_counter): + def where_clause(self, table, column, value, param_counter): converted = self.format.format(value) if self.numeric and converted.isdigit(): converted = int(converted) @@ -597,6 +607,7 @@ class Filter: kwargs = { 'c': column, 'p': 'p{}'.format(param_counter), + 't': table, } return self.sql_template.format(**kwargs), converted @@ -613,6 +624,7 @@ class Filter: class Filters: _filters = [ + # key, display, sql_template, human_template, format=, numeric=, no_argument= Filter('exact', '=', '"{c}" = :{p}', lambda c, v: '{c} = {v}' if v.isdigit() else '{c} = "{v}"'), Filter('not', '!=', '"{c}" != :{p}', lambda c, v: '{c} != {v}' if v.isdigit() else '{c} != "{v}"'), Filter('contains', 'contains', '"{c}" like :{p}', '{c} contains "{v}"', format='%{}%'), @@ -624,6 +636,11 @@ class Filters: Filter('lte', '\u2264', '"{c}" <= :{p}', '{c} \u2264 {v}', numeric=True), Filter('glob', 'glob', '"{c}" glob :{p}', '{c} glob "{v}"'), Filter('like', 'like', '"{c}" like :{p}', '{c} like "{v}"'), + ] + ([Filter('arraycontains', 'array contains', """rowid in ( + select {t}.rowid from {t}, json_each({t}.{c}) j + where j.value = :{p} + )""", '{c} contains "{v}"') + ] if detect_json1() else []) + [ Filter('isnull', 'is null', '"{c}" is null', '{c} is null', no_argument=True), Filter('notnull', 'is not null', '"{c}" is not null', '{c} is not null', no_argument=True), Filter('isblank', 'is blank', '("{c}" is null or "{c}" = "")', '{c} is blank', no_argument=True), @@ -677,7 +694,7 @@ class Filters: return bool(self.pairs) def convert_unit(self, column, value): - "If the user has provided a unit in the quey, convert it into the column unit, if present." + "If the user has provided a unit in the query, convert it into the column unit, if present." if column not in self.units: return value @@ -690,13 +707,13 @@ class Filters: column_unit = self.ureg(self.units[column]) return value.to(column_unit).magnitude - def build_where_clauses(self): + def build_where_clauses(self, table): sql_bits = [] params = {} for i, (column, lookup, value) in enumerate(self.selections()): filter = self._filters_by_key.get(lookup, None) if filter: - sql_bit, param = filter.where_clause(column, self.convert_unit(column, value), i) + sql_bit, param = filter.where_clause(table, column, self.convert_unit(column, value), i) sql_bits.append(sql_bit) if param is not None: param_id = 'p{}'.format(i) diff --git a/datasette/views/table.py b/datasette/views/table.py index 87feda2d..2727565b 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -293,7 +293,7 @@ class TableView(RowTableShared): table_metadata = self.ds.table_metadata(database, table) units = table_metadata.get("units", {}) filters = Filters(sorted(other_args.items()), units, ureg) - where_clauses, params = filters.build_where_clauses() + where_clauses, params = filters.build_where_clauses(table) # _search support: fts_table = await self.ds.execute_against_connection_in_thread( diff --git a/tests/fixtures.py b/tests/fixtures.py index 8f5f0b68..b3b38c95 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -523,26 +523,27 @@ CREATE TABLE facetable ( state text, city_id integer, neighborhood text, + tags text, FOREIGN KEY ("city_id") REFERENCES [facet_cities](id) ); INSERT INTO facetable - (planet_int, on_earth, state, city_id, neighborhood) + (planet_int, on_earth, state, city_id, neighborhood, tags) VALUES - (1, 1, 'CA', 1, 'Mission'), - (1, 1, 'CA', 1, 'Dogpatch'), - (1, 1, 'CA', 1, 'SOMA'), - (1, 1, 'CA', 1, 'Tenderloin'), - (1, 1, 'CA', 1, 'Bernal Heights'), - (1, 1, 'CA', 1, 'Hayes Valley'), - (1, 1, 'CA', 2, 'Hollywood'), - (1, 1, 'CA', 2, 'Downtown'), - (1, 1, 'CA', 2, 'Los Feliz'), - (1, 1, 'CA', 2, 'Koreatown'), - (1, 1, 'MI', 3, 'Downtown'), - (1, 1, 'MI', 3, 'Greektown'), - (1, 1, 'MI', 3, 'Corktown'), - (1, 1, 'MI', 3, 'Mexicantown'), - (2, 0, 'MC', 4, 'Arcadia Planitia') + (1, 1, 'CA', 1, 'Mission', '["tag1", "tag2"]'), + (1, 1, 'CA', 1, 'Dogpatch', '["tag1", "tag3"]'), + (1, 1, 'CA', 1, 'SOMA', '[]'), + (1, 1, 'CA', 1, 'Tenderloin', '[]'), + (1, 1, 'CA', 1, 'Bernal Heights', '[]'), + (1, 1, 'CA', 1, 'Hayes Valley', '[]'), + (1, 1, 'CA', 2, 'Hollywood', '[]'), + (1, 1, 'CA', 2, 'Downtown', '[]'), + (1, 1, 'CA', 2, 'Los Feliz', '[]'), + (1, 1, 'CA', 2, 'Koreatown', '[]'), + (1, 1, 'MI', 3, 'Downtown', '[]'), + (1, 1, 'MI', 3, 'Greektown', '[]'), + (1, 1, 'MI', 3, 'Corktown', '[]'), + (1, 1, 'MI', 3, 'Mexicantown', '[]'), + (2, 0, 'MC', 4, 'Arcadia Planitia', '[]') ; INSERT INTO simple_primary_key VALUES (1, 'hello'); diff --git a/tests/test_api.py b/tests/test_api.py index a50148c1..188a60e8 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,3 +1,4 @@ +from datasette.utils import detect_json1 from .fixtures import ( # noqa app_client, app_client_no_files, @@ -115,7 +116,7 @@ def test_database_page(app_client): 'hidden': False, 'primary_keys': ['id'], }, { - 'columns': ['pk', 'planet_int', 'on_earth', 'state', 'city_id', 'neighborhood'], + 'columns': ['pk', 'planet_int', 'on_earth', 'state', 'city_id', 'neighborhood', 'tags'], 'name': 'facetable', 'count': 15, 'foreign_keys': { @@ -882,6 +883,18 @@ def test_table_filter_queries(app_client, path, expected_rows): assert expected_rows == response.json['rows'] +@pytest.mark.skipif( + not detect_json1(), + reason="Requires the SQLite json1 module" +) +def test_table_filter_json_arraycontains(app_client): + response = app_client.get("/fixtures/facetable.json?tags__arraycontains=tag1") + assert [ + [1, 1, 1, 'CA', 1, 'Mission', '["tag1", "tag2"]'], + [2, 1, 1, 'CA', 1, 'Dogpatch', '["tag1", "tag3"]'] + ] == response.json['rows'] + + def test_max_returned_rows(app_client): response = app_client.get( '/fixtures.json?sql=select+content+from+no_primary_key' @@ -1244,7 +1257,8 @@ def test_expand_labels(app_client): "value": 1, "label": "San Francisco" }, - "neighborhood": "Dogpatch" + "neighborhood": "Dogpatch", + "tags": '["tag1", "tag3"]' }, "13": { "pk": 13, @@ -1255,7 +1269,8 @@ def test_expand_labels(app_client): "value": 3, "label": "Detroit" }, - "neighborhood": "Corktown" + "neighborhood": "Corktown", + "tags": '[]', } } == response.json diff --git a/tests/test_csv.py b/tests/test_csv.py index 357838f6..aa78620a 100644 --- a/tests/test_csv.py +++ b/tests/test_csv.py @@ -17,22 +17,22 @@ world '''.replace('\n', '\r\n') EXPECTED_TABLE_WITH_LABELS_CSV = ''' -pk,planet_int,on_earth,state,city_id,city_id_label,neighborhood -1,1,1,CA,1,San Francisco,Mission -2,1,1,CA,1,San Francisco,Dogpatch -3,1,1,CA,1,San Francisco,SOMA -4,1,1,CA,1,San Francisco,Tenderloin -5,1,1,CA,1,San Francisco,Bernal Heights -6,1,1,CA,1,San Francisco,Hayes Valley -7,1,1,CA,2,Los Angeles,Hollywood -8,1,1,CA,2,Los Angeles,Downtown -9,1,1,CA,2,Los Angeles,Los Feliz -10,1,1,CA,2,Los Angeles,Koreatown -11,1,1,MI,3,Detroit,Downtown -12,1,1,MI,3,Detroit,Greektown -13,1,1,MI,3,Detroit,Corktown -14,1,1,MI,3,Detroit,Mexicantown -15,2,0,MC,4,Memnonia,Arcadia Planitia +pk,planet_int,on_earth,state,city_id,city_id_label,neighborhood,tags +1,1,1,CA,1,San Francisco,Mission,"[""tag1"", ""tag2""]" +2,1,1,CA,1,San Francisco,Dogpatch,"[""tag1"", ""tag3""]" +3,1,1,CA,1,San Francisco,SOMA,[] +4,1,1,CA,1,San Francisco,Tenderloin,[] +5,1,1,CA,1,San Francisco,Bernal Heights,[] +6,1,1,CA,1,San Francisco,Hayes Valley,[] +7,1,1,CA,2,Los Angeles,Hollywood,[] +8,1,1,CA,2,Los Angeles,Downtown,[] +9,1,1,CA,2,Los Angeles,Los Feliz,[] +10,1,1,CA,2,Los Angeles,Koreatown,[] +11,1,1,MI,3,Detroit,Downtown,[] +12,1,1,MI,3,Detroit,Greektown,[] +13,1,1,MI,3,Detroit,Corktown,[] +14,1,1,MI,3,Detroit,Mexicantown,[] +15,2,0,MC,4,Memnonia,Arcadia Planitia,[] '''.lstrip().replace('\n', '\r\n') diff --git a/tests/test_utils.py b/tests/test_utils.py index 9a00b4b4..07074e72 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -187,7 +187,7 @@ def test_custom_json_encoder(obj, expected): ]) def test_build_where(args, expected_where, expected_params): f = utils.Filters(sorted(args.items())) - sql_bits, actual_params = f.build_where_clauses() + sql_bits, actual_params = f.build_where_clauses("table") assert expected_where == sql_bits assert { 'p{}'.format(i): param