diff --git a/datasette/app.py b/datasette/app.py index b5a612c2..ef3fde93 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -593,8 +593,8 @@ class TableView(RowTableShared): else: pk_values = compound_pks_from_path(next) if len(pk_values) == len(pks): - where_clauses.append(compound_keys_after_sql(pks)) param_len = len(params) + where_clauses.append(compound_keys_after_sql(pks, param_len)) for i, pk_value in enumerate(pk_values): params['p{}'.format(param_len + i)] = pk_value diff --git a/datasette/utils.py b/datasette/utils.py index 74bf68db..fb132f08 100644 --- a/datasette/utils.py +++ b/datasette/utils.py @@ -29,7 +29,7 @@ def path_from_row_pks(row, pks, use_rowid): return ','.join(bits) -def compound_keys_after_sql(pks): +def compound_keys_after_sql(pks, start_index=0): # Implementation of keyset pagination # See https://github.com/simonw/datasette/issues/190 # For pk1/pk2/pk3 returns: @@ -46,15 +46,15 @@ def compound_keys_after_sql(pks): last = pks_left[-1] rest = pks_left[:-1] and_clauses = ['[{}] = :p{}'.format( - pk, i + pk, (i + start_index) ) for i, pk in enumerate(rest)] and_clauses.append('[{}] > :p{}'.format( - last, len(rest) + last, (len(rest) + start_index) )) or_clauses.append('({})'.format(' and '.join(and_clauses))) pks_left.pop() or_clauses.reverse() - return '\n or\n'.join(or_clauses) + return '({})'.format('\n or\n'.join(or_clauses)) class CustomJSONEncoder(json.JSONEncoder): diff --git a/tests/fixtures.py b/tests/fixtures.py index 6dcb346b..3d22c289 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -26,6 +26,13 @@ def app_client(): yield ds.app().test_client +def generate_compound_rows(num): + for a, b, c in itertools.islice( + itertools.product(string.ascii_lowercase, repeat=3), num + ): + yield a, b, c, '{}-{}-{}'.format(a, b, c) + + METADATA = { 'title': 'Datasette Title', 'description': 'Datasette Description', @@ -124,9 +131,7 @@ CREATE VIEW simple_view AS 'INSERT INTO no_primary_key VALUES ({i}, "a{i}", "b{i}", "c{i}");'.format(i=i + 1) for i in range(201) ]) + '\n'.join([ - 'INSERT INTO compound_three_primary_keys VALUES ("{a}", "{b}", "{c}", "{a}-{b}-{c}");'.format( - a=a, b=b, c=c - ) for a, b, c in itertools.islice( - itertools.product(string.ascii_lowercase, repeat=3), 301 - ) + 'INSERT INTO compound_three_primary_keys VALUES ("{a}", "{b}", "{c}", "{content}");'.format( + a=a, b=b, c=c, content=content + ) for a, b, c, content in generate_compound_rows(1001) ]) diff --git a/tests/test_api.py b/tests/test_api.py index d686ab14..cff7dadc 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,4 +1,7 @@ -from .fixtures import app_client +from .fixtures import ( + app_client, + generate_compound_rows, +) import pytest pytest.fixture(scope='module')(app_client) @@ -63,7 +66,7 @@ def test_database_page(app_client): }, { 'columns': ['pk1', 'pk2', 'pk3', 'content'], 'name': 'compound_three_primary_keys', - 'count': 301, + 'count': 1001, 'hidden': False, 'foreign_keys': {'incoming': [], 'outgoing': []}, 'label_column': None, @@ -211,14 +214,37 @@ def test_paginate_tables_and_views(app_client, path, expected_rows, expected_pag def test_paginate_compound_keys(app_client): fetched = [] path = '/test_tables/compound_three_primary_keys.jsono' + page = 0 while path: + page += 1 response = app_client.get(path, gather_request=False) fetched.extend(response.json['rows']) path = response.json['next_url'] - assert 301 == len(fetched) + assert page < 100 + assert 1001 == len(fetched) + assert 21 == page # Should be correctly ordered contents = [f['content'] for f in fetched] - assert list(sorted(contents)) == contents + expected = [r[3] for r in generate_compound_rows(1001)] + assert expected == contents + + +def test_paginate_compound_keys_with_extra_filters(app_client): + fetched = [] + path = '/test_tables/compound_three_primary_keys.jsono?content__contains=d' + page = 0 + while path: + page += 1 + assert page < 100 + response = app_client.get(path, gather_request=False) + fetched.extend(response.json['rows']) + path = response.json['next_url'] + assert 2 == page + expected = [ + r[3] for r in generate_compound_rows(1001) + if 'd' in r[3] + ] + assert expected == [f['content'] for f in fetched] @pytest.mark.parametrize('path,expected_rows', [ diff --git a/tests/test_utils.py b/tests/test_utils.py index 5871d39b..73476d7c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -227,16 +227,16 @@ def test_temporary_docker_directory_uses_copy_if_hard_link_fails(mock_link): def test_compound_keys_after_sql(): - assert '([a] > :p0)' == utils.compound_keys_after_sql(['a']) + assert '(([a] > :p0))' == utils.compound_keys_after_sql(['a']) assert ''' -([a] > :p0) +(([a] > :p0) or -([a] = :p0 and [b] > :p1) +([a] = :p0 and [b] > :p1)) '''.strip() == utils.compound_keys_after_sql(['a', 'b']) assert ''' -([a] > :p0) +(([a] > :p0) or ([a] = :p0 and [b] > :p1) or -([a] = :p0 and [b] = :p1 and [c] > :p2) +([a] = :p0 and [b] = :p1 and [c] > :p2)) '''.strip() == utils.compound_keys_after_sql(['a', 'b', 'c'])