From 619a9ddb338e2b11419477a6f16c6ef5bd57d32b Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Sun, 17 Jun 2018 19:31:09 -0700
Subject: [PATCH] table.csv?_stream=1 to download all rows - refs #266

This option causes Datasette to serve ALL rows in the table, by
internally following the _next= pagination links and serving
everything out as a stream.

Also added new config option, allow_csv_stream, which can be used
to disable this feature.
---
 datasette/app.py        |  3 ++
 datasette/views/base.py | 84 ++++++++++++++++++++--------------------
 docs/config.rst         | 12 ++++++
 tests/test_api.py       |  1 +
 tests/test_csv.py       | 13 +++++++
 5 files changed, 69 insertions(+), 44 deletions(-)

diff --git a/datasette/app.py b/datasette/app.py
index 70f2a93f..92327353 100644
--- a/datasette/app.py
+++ b/datasette/app.py
@@ -94,6 +94,9 @@ CONFIG_OPTIONS = (
     ConfigOption("cache_size_kb", 0, """
         SQLite cache size in KB (0 == use SQLite default)
     """.strip()),
+    ConfigOption("allow_csv_stream", True, """
+        Allow .csv?_stream=1 to download all rows (ignoring max_returned_rows)
+    """.strip()),
 )
 DEFAULT_CONFIG = {
     option.name: option.default
diff --git a/datasette/views/base.py b/datasette/views/base.py
index 7beff72e..a572c330 100644
--- a/datasette/views/base.py
+++ b/datasette/views/base.py
@@ -149,42 +149,24 @@ class BaseView(RenderMixin):
 
         return await self.view_get(request, name, hash, **kwargs)
 
-    async def as_csv_stream(self, request, name, hash, **kwargs):
-        assert not request.args.get("_next")  # TODO: real error
-        kwargs['_size'] = 'max'
-
-        async def stream_fn(r):
-            first = True
-            next = None
-            writer = csv.writer(r)
-            while first or next:
-                if next:
-                    kwargs['_next'] = next
-                data, extra_template_data, templates = await self.data(
-                    request, name, hash, **kwargs
-                )
-                if first:
-                    writer.writerow(data["columns"])
-                    first = False
-                next = data["next"]
-                for row in data["rows"]:
-                    writer.writerow(row)
-
-        return response.stream(
-            stream_fn,
-            content_type="text/plain; charset=utf-8"
-        )
-
     async def as_csv(self, request, name, hash, **kwargs):
-        if request.args.get("_stream"):
-            return await self.as_csv_stream(request, name, hash, **kwargs)
+        stream = request.args.get("_stream")
+        if stream:
+            # Some quick sanity checks
+            if not self.ds.config["allow_csv_stream"]:
+                raise DatasetteError("CSV streaming is disabled", status=400)
+            if request.args.get("_next"):
+                raise DatasetteError(
+                    "_next not allowed for CSV streaming", status=400
+                )
+            kwargs["_size"] = "max"
+        # Fetch the first page
         try:
             response_or_template_contexts = await self.data(
                 request, name, hash, **kwargs
             )
             if isinstance(response_or_template_contexts, response.HTTPResponse):
                 return response_or_template_contexts
-
             else:
                 data, extra_template_data, templates = response_or_template_contexts
         except (sqlite3.OperationalError, InvalidSql) as e:
             raise DatasetteError(str(e), title="Invalid SQL", status=400)
@@ -195,6 +177,7 @@ class BaseView(RenderMixin):
         except DatasetteError:
             raise
 
+        # Convert rows and columns to CSV
         headings = data["columns"]
         # if there are expanded_columns we need to add additional headings
         expanded_columns = set(data.get("expanded_columns") or [])
@@ -207,22 +190,35 @@ class BaseView(RenderMixin):
                     headings.append("{}_label".format(column))
 
         async def stream_fn(r):
+            nonlocal data
             writer = csv.writer(r)
-            writer.writerow(headings)
-            for row in data["rows"]:
-                if not expanded_columns:
-                    # Simple path
-                    writer.writerow(row)
-                else:
-                    # Look for {"value": "label": } dicts and expand
-                    new_row = []
-                    for cell in row:
-                        if isinstance(cell, dict):
-                            new_row.append(cell["value"])
-                            new_row.append(cell["label"])
-                        else:
-                            new_row.append(cell)
-                    writer.writerow(new_row)
+            first = True
+            next = None
+            while first or (next and stream):
+                if next:
+                    kwargs["_next"] = next
+                if not first:
+                    data, extra_template_data, templates = await self.data(
+                        request, name, hash, **kwargs
+                    )
+                if first:
+                    writer.writerow(headings)
+                    first = False
+                next = data.get("next")
+                for row in data["rows"]:
+                    if not expanded_columns:
+                        # Simple path
+                        writer.writerow(row)
+                    else:
+                        # Look for {"value": "label": } dicts and expand
+                        new_row = []
+                        for cell in row:
+                            if isinstance(cell, dict):
+                                new_row.append(cell["value"])
+                                new_row.append(cell["label"])
+                            else:
+                                new_row.append(cell)
+                        writer.writerow(new_row)
 
         content_type = "text/plain; charset=utf-8"
         headers = {}
diff --git a/docs/config.rst b/docs/config.rst
index 8f0cd246..36a046fc 100644
--- a/docs/config.rst
+++ b/docs/config.rst
@@ -125,3 +125,15 @@
 Sets the amount of memory SQLite uses for its `per-connection cache
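
A minimal client-side sketch of the new behaviour, assuming a local Datasette
running on http://localhost:8001, a hypothetical fixtures.db with a facetable
table, and the third-party requests library (all names here are illustrative,
not part of the patch):

    import csv

    import requests

    url = "http://localhost:8001/fixtures/facetable.csv"
    # _stream=1 asks Datasette to follow its own _next= pagination links
    # internally and keep emitting rows past max_returned_rows.
    with requests.get(url, params={"_stream": "1"}, stream=True) as response:
        response.raise_for_status()
        lines = (line.decode("utf-8") for line in response.iter_lines())
        reader = csv.reader(lines)
        headings = next(reader)
        print("columns:", headings)
        print("total rows:", sum(1 for _ in reader))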
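
Because allow_csv_stream is registered through the same ConfigOption mechanism
as the existing settings, it should be switchable off at serve time with the
usual --config syntax, something like: datasette serve fixtures.db --config
allow_csv_stream:off. A sketch of a test for the disabled path, in the style of
tests/test_csv.py; the fixture name is invented, standing in for an app_client
built with allow_csv_stream turned off:

    # app_client_csv_stream_off is a hypothetical fixture: an app_client
    # configured with {"allow_csv_stream": False}.
    def test_csv_stream_can_be_disabled(app_client_csv_stream_off):
        response = app_client_csv_stream_off.get(
            "/fixtures/facetable.csv?_stream=1"
        )
        # The new sanity check in as_csv() rejects the request outright
        assert response.status == 400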