diff --git a/datasette/filters.py b/datasette/filters.py
index 95cc5f37..3b8e2a68 100644
--- a/datasette/filters.py
+++ b/datasette/filters.py
@@ -3,9 +3,13 @@ from datasette.resources import DatabaseResource
 from datasette.views.base import DatasetteError
 from datasette.utils.asgi import BadRequest
 import json
+import re
 from .utils import detect_json1, escape_sqlite, path_with_removed_args
 
 
+looks_like_number_re = re.compile(r"^-?(?:\d+(?:\.\d+)?|\.\d+)$")
+
+
 @hookimpl(specname="filters_from_request")
 def where_filters(request, database, datasette):
     # This one deals with ?_where=
@@ -408,14 +412,30 @@ class Filters:
     def has_selections(self):
         return bool(self.pairs)
 
-    def build_where_clauses(self, table):
+    def build_where_clauses(self, table, column_details=None):
         sql_bits = []
         params = {}
         i = 0
         for column, lookup, value in self.selections():
             filter = self._filters_by_key.get(lookup, None)
             if filter:
-                sql_bit, param = filter.where_clause(table, column, value, i)
+                param_name = f"p{i}"
+                if self.should_compare_with_numeric_value(
+                    column_details, column, lookup, value
+                ):
+                    numeric_comparison = (
+                        f'"{column}" = :{param_name} or '
+                        f"(typeof(\"{column}\") in ('integer', 'real') "
+                        f'and "{column}" = CAST(:{param_name} AS NUMERIC))'
+                    )
+                    sql_bit = (
+                        f"not ({numeric_comparison})"
+                        if lookup == "not"
+                        else f"({numeric_comparison})"
+                    )
+                    param = value
+                else:
+                    sql_bit, param = filter.where_clause(table, column, value, i)
                 sql_bits.append(sql_bit)
                 if param is not None:
                     if not isinstance(param, list):
@@ -425,3 +445,19 @@ class Filters:
                         params[param_id] = individual_param
                         i += 1
         return sql_bits, params
+
+    def should_compare_with_numeric_value(
+        self, column_details, column, lookup, value
+    ):
+        if lookup not in ("exact", "not"):
+            return False
+        if not isinstance(value, str):
+            return False
+        if column_details is None:
+            return False
+        column_detail = column_details.get(column)
+        if column_detail is None:
+            return False
+        if (column_detail.type or "").strip():
+            return False
+        return bool(looks_like_number_re.match(value))
diff --git a/datasette/views/table.py b/datasette/views/table.py
index da69c6b5..d8c95c15 100644
--- a/datasette/views/table.py
+++ b/datasette/views/table.py
@@ -1159,7 +1159,9 @@ async def table_view_data(
 
     # Introspect columns and primary keys for table
     pks = await db.primary_keys(table_name)
-    table_columns = await db.table_columns(table_name)
+    table_column_details = await db.table_column_details(table_name)
+    table_columns = [col.name for col in table_column_details]
+    column_details = {col.name: col for col in table_column_details}
 
     # Take ?_col= and ?_nocol= into account
     specified_columns = await _columns_to_select(table_columns, pks, request)
@@ -1203,7 +1205,9 @@ async def table_view_data(
 
     # Build where clauses from query string arguments
     filters = Filters(sorted(filter_args))
-    where_clauses, params = filters.build_where_clauses(table_name)
+    where_clauses, params = filters.build_where_clauses(
+        table_name, column_details=column_details
+    )
 
     # Execute filters_from_request plugin hooks - including the default
     # ones that live in datasette/filters.py
diff --git a/tests/test_table_api.py b/tests/test_table_api.py
index ceeb646d..a30ff270 100644
--- a/tests/test_table_api.py
+++ b/tests/test_table_api.py
@@ -560,6 +560,67 @@ async def test_table_filter_queries_multiple_of_same_type(ds_client):
     ] == response.json()["rows"]
 
 
+@pytest.mark.asyncio
+async def test_table_filter_view_on_numeric_computed_column(bare_ds):
+    db = bare_ds.add_memory_database("computed_column_filter")
+    await db.execute_write_script("""
+        CREATE TABLE items(
+            id INTEGER PRIMARY KEY,
+            category TEXT,
+            valid INTEGER
+        );
+
+        INSERT INTO items VALUES (1, 'a', 0);
+        INSERT INTO items VALUES (2, 'a', 1);
+        INSERT INTO items VALUES (3, 'a', 0);
+        INSERT INTO items VALUES (4, 'b', 0);
+        INSERT INTO items VALUES (5, 'b', 0);
+
+        CREATE VIEW summary AS
+        SELECT category,
+            SUM(CASE WHEN valid THEN 1 ELSE 0 END) AS valid_count,
+            SUM(CASE WHEN NOT valid THEN 1 ELSE 0 END) AS invalid_count
+        FROM items
+        GROUP BY category;
+    """)
+    response = await bare_ds.client.get(
+        "/computed_column_filter/summary.json?_shape=objects&valid_count__exact=0"
+    )
+    assert response.json()["rows"] == [
+        {"category": "b", "valid_count": 0, "invalid_count": 2}
+    ]
+
+    response = await bare_ds.client.get(
+        "/computed_column_filter/summary.json?_shape=objects&valid_count__not=0"
+    )
+    assert response.json()["rows"] == [
+        {"category": "a", "valid_count": 1, "invalid_count": 2}
+    ]
+
+
+@pytest.mark.asyncio
+async def test_table_filter_view_on_text_computed_column_preserves_exact_text(bare_ds):
+    db = bare_ds.add_memory_database("computed_text_column_filter")
+    await db.execute_write_script("""
+        CREATE TABLE items(
+            id INTEGER PRIMARY KEY,
+            code TEXT
+        );
+
+        INSERT INTO items VALUES (1, '0');
+        INSERT INTO items VALUES (2, '00');
+
+        CREATE VIEW summary AS
+        SELECT code,
+            code || '' AS code_text
+        FROM items;
+    """)
+    response = await bare_ds.client.get(
+        "/computed_text_column_filter/summary.json?_shape=objects&code_text__exact=00"
+    )
+    assert response.json()["rows"] == [{"code": "00", "code_text": "00"}]
+
+
 @pytest.mark.skipif(not detect_json1(), reason="Requires the SQLite json1 module")
 @pytest.mark.asyncio
 async def test_table_filter_json_arraycontains(ds_client):