From 9d0071825077b2ebb77314c5a05bd614f3b28b54 Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Sun, 17 Jun 2018 20:01:30 -0700
Subject: [PATCH] New config option max_csv_mb limiting size of CSV export - refs #266

---
 datasette/app.py            |  3 ++
 datasette/utils.py          | 19 ++++++++++++
 datasette/views/base.py     | 61 +++++++++++++++++++++----------------
 datasette/views/database.py |  4 +--
 docs/config.rst             |  9 ++++++
 tests/fixtures.py           |  7 +++++
 tests/test_api.py           |  1 +
 tests/test_csv.py           | 14 ++++++++-
 8 files changed, 89 insertions(+), 29 deletions(-)

diff --git a/datasette/app.py b/datasette/app.py
index 92327353..fb389d73 100644
--- a/datasette/app.py
+++ b/datasette/app.py
@@ -97,6 +97,9 @@ CONFIG_OPTIONS = (
     ConfigOption("allow_csv_stream", True, """
         Allow .csv?_stream=1 to download all rows (ignoring max_returned_rows)
     """.strip()),
+    ConfigOption("max_csv_mb", 100, """
+        Maximum size allowed for CSV export in MB. Set 0 to disable this limit.
+    """.strip()),
 )
 DEFAULT_CONFIG = {
     option.name: option.default
diff --git a/datasette/utils.py b/datasette/utils.py
index a179eddf..005db87f 100644
--- a/datasette/utils.py
+++ b/datasette/utils.py
@@ -832,3 +832,22 @@ def value_as_boolean(value):
 
 class ValueAsBooleanError(ValueError):
     pass
+
+
+class WriteLimitExceeded(Exception):
+    pass
+
+
+class LimitedWriter:
+    def __init__(self, writer, limit_mb):
+        self.writer = writer
+        self.limit_bytes = limit_mb * 1024 * 1024
+        self.bytes_count = 0
+
+    def write(self, bytes):
+        self.bytes_count += len(bytes)
+        if self.limit_bytes and (self.bytes_count > self.limit_bytes):
+            raise WriteLimitExceeded("CSV contains more than {} bytes".format(
+                self.limit_bytes
+            ))
+        self.writer.write(bytes)
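A minimal sketch of how the new LimitedWriter is intended to behave, assuming the patched datasette.utils above is importable; io.StringIO stands in here for the streaming response object that the CSV view normally wraps:

    import io

    from datasette.utils import LimitedWriter, WriteLimitExceeded

    # Wrap an in-memory writer with a 1 MB limit (limit_mb=1 -> 1048576 bytes)
    limited = LimitedWriter(io.StringIO(), 1)

    try:
        # A single 2 MB write pushes bytes_count past limit_bytes, so
        # WriteLimitExceeded is raised before anything reaches the writer
        limited.write("x" * 2 * 1024 * 1024)
    except WriteLimitExceeded as e:
        print(e)  # CSV contains more than 1048576 bytes

In stream_fn this exception is caught and its message is written to the response, which is why the new test further below asserts that the last line of the streamed CSV starts with b"CSV contains more than".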
diff --git a/datasette/views/base.py b/datasette/views/base.py
index a572c330..0ca52e61 100644
--- a/datasette/views/base.py
+++ b/datasette/views/base.py
@@ -16,6 +16,7 @@ from datasette.utils import (
     CustomJSONEncoder,
     InterruptedError,
     InvalidSql,
+    LimitedWriter,
     path_from_row_pks,
     path_with_added_args,
     path_with_format,
@@ -191,34 +192,39 @@ class BaseView(RenderMixin):
 
         async def stream_fn(r):
             nonlocal data
-            writer = csv.writer(r)
+            writer = csv.writer(LimitedWriter(r, self.ds.config["max_csv_mb"]))
             first = True
             next = None
             while first or (next and stream):
-                if next:
-                    kwargs["_next"] = next
-                if not first:
-                    data, extra_template_data, templates = await self.data(
-                        request, name, hash, **kwargs
-                    )
-                if first:
-                    writer.writerow(headings)
-                    first = False
-                next = data.get("next")
-                for row in data["rows"]:
-                    if not expanded_columns:
-                        # Simple path
-                        writer.writerow(row)
-                    else:
-                        # Look for {"value": "label": } dicts and expand
-                        new_row = []
-                        for cell in row:
-                            if isinstance(cell, dict):
-                                new_row.append(cell["value"])
-                                new_row.append(cell["label"])
-                            else:
-                                new_row.append(cell)
-                        writer.writerow(new_row)
+                try:
+                    if next:
+                        kwargs["_next"] = next
+                    if not first:
+                        data, extra_template_data, templates = await self.data(
+                            request, name, hash, **kwargs
+                        )
+                    if first:
+                        writer.writerow(headings)
+                        first = False
+                    next = data.get("next")
+                    for row in data["rows"]:
+                        if not expanded_columns:
+                            # Simple path
+                            writer.writerow(row)
+                        else:
+                            # Look for {"value": "label": } dicts and expand
+                            new_row = []
+                            for cell in row:
+                                if isinstance(cell, dict):
+                                    new_row.append(cell["value"])
+                                    new_row.append(cell["label"])
+                                else:
+                                    new_row.append(cell)
+                            writer.writerow(new_row)
+                except Exception as e:
+                    print('caught this', e)
+                    r.write(str(e))
+                    return
 
         content_type = "text/plain; charset=utf-8"
         headers = {}
@@ -417,7 +423,8 @@ class BaseView(RenderMixin):
         return r
 
     async def custom_sql(
-        self, request, name, hash, sql, editable=True, canned_query=None
+        self, request, name, hash, sql, editable=True, canned_query=None,
+        _size=None
     ):
         params = request.raw_args
         if "sql" in params:
@@ -439,6 +446,8 @@ class BaseView(RenderMixin):
         extra_args = {}
         if params.get("_timelimit"):
             extra_args["custom_time_limit"] = int(params["_timelimit"])
+        if _size:
+            extra_args["page_size"] = _size
         results = await self.ds.execute(
             name, sql, params, truncate=True, **extra_args
         )
diff --git a/datasette/views/database.py b/datasette/views/database.py
index 2f3f41d3..a7df485b 100644
--- a/datasette/views/database.py
+++ b/datasette/views/database.py
@@ -9,13 +9,13 @@ from .base import BaseView, DatasetteError
 
 
 class DatabaseView(BaseView):
-    async def data(self, request, name, hash, default_labels=False):
+    async def data(self, request, name, hash, default_labels=False, _size=None):
         if request.args.get("sql"):
             if not self.ds.config["allow_sql"]:
                 raise DatasetteError("sql= is not allowed", status=400)
             sql = request.raw_args.pop("sql")
             validate_sql_select(sql)
-            return await self.custom_sql(request, name, hash, sql)
+            return await self.custom_sql(request, name, hash, sql, _size=_size)
 
         info = self.ds.inspect()[name]
         metadata = self.ds.metadata.get("databases", {}).get(name, {})
diff --git a/docs/config.rst b/docs/config.rst
index 36a046fc..e0013bf0 100644
--- a/docs/config.rst
+++ b/docs/config.rst
@@ -137,3 +137,12 @@ can turn it off like this::
 ::
 
     datasette mydatabase.db --config allow_csv_stream:off
+
+
+max_csv_mb
+----------
+
+The maximum size of CSV that can be exported, in megabytes. Defaults to 100MB.
+You can disable the limit entirely by setting this to 0::
+
+    datasette mydatabase.db --config max_csv_mb:0
diff --git a/tests/fixtures.py b/tests/fixtures.py
index 92fd5073..ea0b4e35 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -71,6 +71,13 @@ def app_client_larger_cache_size():
     })
 
 
+@pytest.fixture(scope='session')
+def app_client_csv_max_mb_one():
+    yield from app_client(config={
+        'max_csv_mb': 1,
+    })
+
+
 def generate_compound_rows(num):
     for a, b, c in itertools.islice(
         itertools.product(string.ascii_lowercase, repeat=3), num
diff --git a/tests/test_api.py b/tests/test_api.py
index 1e889963..2187deb5 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -902,6 +902,7 @@ def test_config_json(app_client):
         "num_sql_threads": 3,
         "cache_size_kb": 0,
         "allow_csv_stream": True,
+        "max_csv_mb": 100,
     } == response.json
 
 
diff --git a/tests/test_csv.py b/tests/test_csv.py
index 65bc9c1d..fcbacc76 100644
--- a/tests/test_csv.py
+++ b/tests/test_csv.py
@@ -1,4 +1,4 @@
-from .fixtures import app_client  # noqa
+from .fixtures import app_client, app_client_csv_max_mb_one  # noqa
 
 EXPECTED_TABLE_CSV = '''id,content
 1,hello
@@ -61,6 +61,18 @@ def test_table_csv_download(app_client):
     assert expected_disposition == response.headers['Content-Disposition']
 
 
+def test_max_csv_mb(app_client_csv_max_mb_one):
+    response = app_client_csv_max_mb_one.get(
+        "/fixtures.csv?sql=select+randomblob(10000)+"
+        "from+compound_three_primary_keys&_stream=1&_size=max"
+    )
+    # It's a 200 because we started streaming before we knew the error
+    assert response.status == 200
+    # Last line should be an error message
+    last_line = [line for line in response.body.split(b"\r\n") if line][-1]
+    assert last_line.startswith(b"CSV contains more than")
+
+
 def test_table_csv_stream(app_client):
     # Without _stream should return header + 100 rows:
     response = app_client.get(