def compound_keys_after_sql(pks, start_index=0):
    """Return the keyset-pagination SQL condition for a compound key.

    Implementation of keyset pagination, see
    https://github.com/simonw/datasette/issues/190

    For pks = [pk1, pk2, pk3] and the default start_index the result is:

        ([pk1] > :p0)
         or
        ([pk1] = :p0 and [pk2] > :p1)
         or
        ([pk1] = :p0 and [pk2] = :p1 and [pk3] > :p2)

    Parameters:
        pks: sequence of primary key column names, in key order.
        start_index: number used for the first ``:pN`` placeholder. Pass
            the count of parameters already in use so the placeholders
            line up with the names the caller stores the values under.
            Defaults to 0, which numbers placeholders from ``:p0``.

    Returns:
        The condition as a string; an empty string when pks is empty.

    Column names are interpolated inside ``[...]`` without escaping.
    The result is a bare chain of ``or`` terms with no enclosing
    parentheses, so wrap it before and-ing it with other conditions.
    """
    or_clauses = []
    for depth in range(len(pks)):
        # Every key column before this one must equal the cursor value...
        and_clauses = [
            '[{}] = :p{}'.format(pk, start_index + i)
            for i, pk in enumerate(pks[:depth])
        ]
        # ...and this column must be strictly past it.
        and_clauses.append(
            '[{}] > :p{}'.format(pks[depth], start_index + depth)
        )
        or_clauses.append('({})'.format(' and '.join(and_clauses)))
    # NOTE(review): separator reproduced from a whitespace-collapsed
    # source; confirm the exact spacing around "or" against the tests.
    return '\n or\n'.join(or_clauses)
def test_paginate_compound_keys(app_client):
    """Page through the three-column compound-key table via next_url.

    Collects the rows of every page until next_url is falsy, then checks
    the total row count and that the rows arrived in sorted order.
    """
    next_path = '/test_tables/compound_three_primary_keys.jsono'
    collected = []
    while next_path:
        page = app_client.get(next_path, gather_request=False).json
        collected += page['rows']
        next_path = page['next_url']
    assert 301 == len(collected)
    # Should be correctly ordered
    ordered = [row['content'] for row in collected]
    assert sorted(ordered) == ordered


def test_compound_keys_after_sql():
    """Check the generated SQL for one-, two- and three-column keys."""
    # NOTE(review): expected strings reconstructed from a
    # whitespace-collapsed source; confirm the spacing around "or".
    cases = (
        (['a'], '([a] > :p0)'),
        (
            ['a', 'b'],
            '([a] > :p0)\n or\n([a] = :p0 and [b] > :p1)',
        ),
        (
            ['a', 'b', 'c'],
            '([a] > :p0)\n or\n([a] = :p0 and [b] > :p1)'
            '\n or\n([a] = :p0 and [b] = :p1 and [c] > :p2)',
        ),
    )
    for pks, expected in cases:
        assert expected == utils.compound_keys_after_sql(pks)