New config option max_csv_mb limiting size of CSV export - refs #266

Simon Willison 2018-06-17 20:01:30 -07:00
commit 9d00718250
8 changed files with 89 additions and 29 deletions

View file

@@ -97,6 +97,9 @@ CONFIG_OPTIONS = (
     ConfigOption("allow_csv_stream", True, """
         Allow .csv?_stream=1 to download all rows (ignoring max_returned_rows)
     """.strip()),
+    ConfigOption("max_csv_mb", 100, """
+        Maximum size allowed for CSV export in MB. Set 0 to disable this limit.
+    """.strip()),
 )
 DEFAULT_CONFIG = {
     option.name: option.default

View file

@@ -832,3 +832,22 @@ def value_as_boolean(value):
 class ValueAsBooleanError(ValueError):
     pass
+
+
+class WriteLimitExceeded(Exception):
+    pass
+
+
+class LimitedWriter:
+    def __init__(self, writer, limit_mb):
+        self.writer = writer
+        self.limit_bytes = limit_mb * 1024 * 1024
+        self.bytes_count = 0
+
+    def write(self, bytes):
+        self.bytes_count += len(bytes)
+        if self.limit_bytes and (self.bytes_count > self.limit_bytes):
+            raise WriteLimitExceeded("CSV contains more than {} bytes".format(
+                self.limit_bytes
+            ))
+        self.writer.write(bytes)
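
As a quick illustration (not part of the commit), here is how the new LimitedWriter behaves on its own, assuming the datasette.utils from this change is importable; io.StringIO stands in for the streaming response object:

    import io
    from datasette.utils import LimitedWriter, WriteLimitExceeded

    # Cap any object that has a .write() method at 1 MB
    writer = LimitedWriter(io.StringIO(), 1)
    try:
        writer.write("x" * (2 * 1024 * 1024))  # 2 MB exceeds the 1 MB limit
    except WriteLimitExceeded as e:
        print(e)  # CSV contains more than 1048576 bytes

Note that the check runs before the underlying write, so the chunk that would overflow the budget is never forwarded to the wrapped writer.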

View file

@ -16,6 +16,7 @@ from datasette.utils import (
CustomJSONEncoder, CustomJSONEncoder,
InterruptedError, InterruptedError,
InvalidSql, InvalidSql,
LimitedWriter,
path_from_row_pks, path_from_row_pks,
path_with_added_args, path_with_added_args,
path_with_format, path_with_format,
@@ -191,34 +192,39 @@ class BaseView(RenderMixin):
         async def stream_fn(r):
             nonlocal data
-            writer = csv.writer(r)
+            writer = csv.writer(LimitedWriter(r, self.ds.config["max_csv_mb"]))
             first = True
             next = None
             while first or (next and stream):
-                if next:
-                    kwargs["_next"] = next
-                if not first:
-                    data, extra_template_data, templates = await self.data(
-                        request, name, hash, **kwargs
-                    )
-                if first:
-                    writer.writerow(headings)
-                    first = False
-                next = data.get("next")
-                for row in data["rows"]:
-                    if not expanded_columns:
-                        # Simple path
-                        writer.writerow(row)
-                    else:
-                        # Look for {"value": "label": } dicts and expand
-                        new_row = []
-                        for cell in row:
-                            if isinstance(cell, dict):
-                                new_row.append(cell["value"])
-                                new_row.append(cell["label"])
-                            else:
-                                new_row.append(cell)
-                        writer.writerow(new_row)
+                try:
+                    if next:
+                        kwargs["_next"] = next
+                    if not first:
+                        data, extra_template_data, templates = await self.data(
+                            request, name, hash, **kwargs
+                        )
+                    if first:
+                        writer.writerow(headings)
+                        first = False
+                    next = data.get("next")
+                    for row in data["rows"]:
+                        if not expanded_columns:
+                            # Simple path
+                            writer.writerow(row)
+                        else:
+                            # Look for {"value": "label": } dicts and expand
+                            new_row = []
+                            for cell in row:
+                                if isinstance(cell, dict):
+                                    new_row.append(cell["value"])
+                                    new_row.append(cell["label"])
+                                else:
+                                    new_row.append(cell)
+                            writer.writerow(new_row)
+                except Exception as e:
+                    print('caught this', e)
+                    r.write(str(e))
+                    return

         content_type = "text/plain; charset=utf-8"
         headers = {}
@@ -417,7 +423,8 @@ class BaseView(RenderMixin):
             return r

     async def custom_sql(
-        self, request, name, hash, sql, editable=True, canned_query=None
+        self, request, name, hash, sql, editable=True, canned_query=None,
+        _size=None
     ):
         params = request.raw_args
         if "sql" in params:
@@ -439,6 +446,8 @@
         extra_args = {}
         if params.get("_timelimit"):
             extra_args["custom_time_limit"] = int(params["_timelimit"])
+        if _size:
+            extra_args["page_size"] = _size
         results = await self.ds.execute(
             name, sql, params, truncate=True, **extra_args
         )
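
A sketch, not Datasette API, of why the streamed response still comes back as a 200 when the limit trips: by the time WriteLimitExceeded is raised the headers and earlier rows have already been sent, so stream_fn can only append the error text to the body and stop, which is what the except branch above does. FakeResponse here is a made-up stand-in for the streaming response object:

    import csv
    from datasette.utils import LimitedWriter, WriteLimitExceeded

    class FakeResponse:
        # Hypothetical stand-in for the streaming HTTP response
        def __init__(self):
            self.chunks = []
        def write(self, data):
            self.chunks.append(data)

    r = FakeResponse()
    writer = csv.writer(LimitedWriter(r, 1))  # same wrapping as stream_fn, 1 MB cap
    try:
        for i in range(200000):               # enough rows to cross 1 MB
            writer.writerow([i, "x" * 10])
    except WriteLimitExceeded as e:
        r.write(str(e))                       # error text joins the already-started body

    print(r.chunks[-1])  # CSV contains more than 1048576 bytes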

View file

@@ -9,13 +9,13 @@ from .base import BaseView, DatasetteError
 class DatabaseView(BaseView):
-    async def data(self, request, name, hash, default_labels=False):
+    async def data(self, request, name, hash, default_labels=False, _size=None):
         if request.args.get("sql"):
             if not self.ds.config["allow_sql"]:
                 raise DatasetteError("sql= is not allowed", status=400)
             sql = request.raw_args.pop("sql")
             validate_sql_select(sql)
-            return await self.custom_sql(request, name, hash, sql)
+            return await self.custom_sql(request, name, hash, sql, _size=_size)
         info = self.ds.inspect()[name]
         metadata = self.ds.metadata.get("databases", {}).get(name, {})

View file

@@ -137,3 +137,12 @@ can turn it off like this::
 ::

     datasette mydatabase.db --config allow_csv_stream:off
+
+max_csv_mb
+----------
+
+The maximum size of CSV that can be exported, in megabytes. Defaults to 100MB.
+You can disable the limit entirely by setting this to 0::
+
+    datasette mydatabase.db --config max_csv_mb:0
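
An aside, not part of the docs change in this commit: once the server is running, the effective value can be confirmed through Datasette's config introspection JSON (the same config payload that test_config_json below asserts against), for example::

    datasette mydatabase.db --config max_csv_mb:10
    curl http://127.0.0.1:8001/-/config.json
    # the JSON response should include "max_csv_mb": 10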

View file

@@ -71,6 +71,13 @@ def app_client_larger_cache_size():
     })


+@pytest.fixture(scope='session')
+def app_client_csv_max_mb_one():
+    yield from app_client(config={
+        'max_csv_mb': 1,
+    })
+
+
 def generate_compound_rows(num):
     for a, b, c in itertools.islice(
         itertools.product(string.ascii_lowercase, repeat=3), num

View file

@ -902,6 +902,7 @@ def test_config_json(app_client):
"num_sql_threads": 3, "num_sql_threads": 3,
"cache_size_kb": 0, "cache_size_kb": 0,
"allow_csv_stream": True, "allow_csv_stream": True,
"max_csv_mb": 100,
} == response.json } == response.json

View file

@@ -1,4 +1,4 @@
-from .fixtures import app_client  # noqa
+from .fixtures import app_client, app_client_csv_max_mb_one  # noqa

 EXPECTED_TABLE_CSV = '''id,content
 1,hello
@@ -61,6 +61,18 @@ def test_table_csv_download(app_client):
     assert expected_disposition == response.headers['Content-Disposition']


+def test_max_csv_mb(app_client_csv_max_mb_one):
+    response = app_client_csv_max_mb_one.get(
+        "/fixtures.csv?sql=select+randomblob(10000)+"
+        "from+compound_three_primary_keys&_stream=1&_size=max"
+    )
+    # It's a 200 because we started streaming before we knew the error
+    assert response.status == 200
+    # Last line should be an error message
+    last_line = [line for line in response.body.split(b"\r\n") if line][-1]
+    assert last_line.startswith(b"CSV contains more than")
+
+
 def test_table_csv_stream(app_client):
     # Without _stream should return header + 100 rows:
     response = app_client.get(