If max_returned_rows==page_size, increment max_returned_rows

Fixes #230, where, if the two were equal, pagination didn't work correctly.
This commit is contained in:
Simon Willison 2018-04-25 21:04:12 -07:00
commit 4504d5160b
No known key found for this signature in database
GPG key ID: 17E2DEA2588B7F52
3 changed files with 28 additions and 6 deletions

View file

@@ -176,10 +176,13 @@ class BaseView(RenderMixin):
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(sql, params or {}) cursor.execute(sql, params or {})
if self.max_returned_rows and truncate: max_returned_rows = self.max_returned_rows
rows = cursor.fetchmany(self.max_returned_rows + 1) if max_returned_rows == self.page_size:
truncated = len(rows) > self.max_returned_rows max_returned_rows += 1
rows = rows[:self.max_returned_rows] if max_returned_rows and truncate:
rows = cursor.fetchmany(max_returned_rows + 1)
truncated = len(rows) > max_returned_rows
rows = rows[:max_returned_rows]
else: else:
rows = cursor.fetchall() rows = cursor.fetchall()
truncated = False truncated = False

View file

@@ -9,7 +9,7 @@ import tempfile
import time import time
def app_client(sql_time_limit_ms=None): def app_client(sql_time_limit_ms=None, max_returned_rows=None):
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
filepath = os.path.join(tmpdir, 'test_tables.db') filepath = os.path.join(tmpdir, 'test_tables.db')
conn = sqlite3.connect(filepath) conn = sqlite3.connect(filepath)
@@ -21,7 +21,7 @@ def app_client(sql_time_limit_ms=None):
ds = Datasette( ds = Datasette(
[filepath], [filepath],
page_size=50, page_size=50,
max_returned_rows=100, max_returned_rows=max_returned_rows or 100,
sql_time_limit_ms=sql_time_limit_ms or 20, sql_time_limit_ms=sql_time_limit_ms or 20,
metadata=METADATA, metadata=METADATA,
plugins_dir=plugins_dir, plugins_dir=plugins_dir,
@@ -38,6 +38,10 @@ def app_client_longer_time_limit():
yield from app_client(200) yield from app_client(200)
def app_client_returend_rows_matches_page_size():
yield from app_client(max_returned_rows=50)
def generate_compound_rows(num): def generate_compound_rows(num):
for a, b, c in itertools.islice( for a, b, c in itertools.islice(
itertools.product(string.ascii_lowercase, repeat=3), num itertools.product(string.ascii_lowercase, repeat=3), num

View file

@@ -1,6 +1,7 @@
from .fixtures import ( from .fixtures import (
app_client, app_client,
app_client_longer_time_limit, app_client_longer_time_limit,
app_client_returend_rows_matches_page_size,
generate_compound_rows, generate_compound_rows,
generate_sortable_rows, generate_sortable_rows,
METADATA, METADATA,
@@ -9,6 +10,7 @@ import pytest
pytest.fixture(scope='module')(app_client) pytest.fixture(scope='module')(app_client)
pytest.fixture(scope='module')(app_client_longer_time_limit) pytest.fixture(scope='module')(app_client_longer_time_limit)
pytest.fixture(scope='module')(app_client_returend_rows_matches_page_size)
def test_homepage(app_client): def test_homepage(app_client):
@@ -691,3 +693,16 @@ def test_plugins_json(app_client):
'static': False, 'static': False,
'templates': False 'templates': False
} in response.json } in response.json
def test_page_size_matching_max_returned_rows(app_client_returend_rows_matches_page_size):
fetched = []
path = '/test_tables/no_primary_key.json'
while path:
response = app_client_returend_rows_matches_page_size.get(
path, gather_request=False
)
fetched.extend(response.json['rows'])
assert len(response.json['rows']) in (1, 50)
path = response.json['next_url']
assert 201 == len(fetched)