diff --git a/datasette/app.py b/datasette/app.py index 0560193e..bd398234 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -647,7 +647,9 @@ class TableView(RowTableShared): forward_querystring=False ) - filters = Filters(sorted(other_args.items())) + units = self.table_metadata(name, table).get('units', {}) + + filters = Filters(sorted(other_args.items()), units, ureg) where_clauses, params = filters.build_where_clauses() # _search support: @@ -891,7 +893,7 @@ class TableView(RowTableShared): 'filtered_table_rows_count': filtered_table_rows_count, 'columns': columns, 'primary_keys': pks, - 'units': self.table_metadata(name, table).get('units', {}), + 'units': units, 'query': { 'sql': sql, 'params': params, diff --git a/datasette/utils.py b/datasette/utils.py index cd1f08cf..b5020be2 100644 --- a/datasette/utils.py +++ b/datasette/utils.py @@ -10,6 +10,7 @@ import tempfile import time import shutil import urllib +import numbers # From https://www.sqlite.org/lang_keywords.html @@ -459,8 +460,10 @@ class Filters: f.key: f for f in _filters } - def __init__(self, pairs): + def __init__(self, pairs, units={}, ureg=None): self.pairs = pairs + self.units = units + self.ureg = ureg def lookups(self): "Yields (lookup, display, no_argument) pairs" @@ -500,13 +503,27 @@ class Filters: def has_selections(self): return bool(self.pairs) + def convert_unit(self, column, value): + "If the user has provided a unit in the quey, convert it into the column unit, if present." + if column not in self.units: + return value + + # Try to interpret the value as a unit + value = self.ureg(value) + if isinstance(value, numbers.Number): + # It's just a bare number, assume it's the column unit + return value + + column_unit = self.ureg(self.units[column]) + return value.to(column_unit).magnitude + def build_where_clauses(self): sql_bits = [] params = {} for i, (column, lookup, value) in enumerate(self.selections()): filter = self._filters_by_key.get(lookup, None) if filter: - sql_bit, param = filter.where_clause(column, value, i) + sql_bit, param = filter.where_clause(column, self.convert_unit(column, value), i) sql_bits.append(sql_bit) if param is not None: param_id = 'p{}'.format(i)