Mirror of https://github.com/simonw/datasette.git
Synced 2025-12-10 16:51:24 +01:00
New config option max_csv_mb limiting size of CSV export - refs #266
This commit is contained in:
parent 619a9ddb33
commit 9d00718250

8 changed files with 89 additions and 29 deletions
datasette/app.py
@@ -97,6 +97,9 @@ CONFIG_OPTIONS = (
     ConfigOption("allow_csv_stream", True, """
         Allow .csv?_stream=1 to download all rows (ignoring max_returned_rows)
     """.strip()),
+    ConfigOption("max_csv_mb", 100, """
+        Maximum size allowed for CSV export in MB. Set 0 to disable this limit.
+    """.strip()),
 )
 DEFAULT_CONFIG = {
     option.name: option.default
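Not part of the commit: a quick check, as a sketch assuming CONFIG_OPTIONS and DEFAULT_CONFIG are both importable from datasette.app at this revision, that the new option's default flows through the DEFAULT_CONFIG comprehension shown above.

    from datasette.app import CONFIG_OPTIONS, DEFAULT_CONFIG

    # max_csv_mb should be registered with its 100 MB default
    assert any(option.name == "max_csv_mb" for option in CONFIG_OPTIONS)
    assert DEFAULT_CONFIG["max_csv_mb"] == 100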
datasette/utils.py
@@ -832,3 +832,22 @@ def value_as_boolean(value):

 class ValueAsBooleanError(ValueError):
     pass
+
+
+class WriteLimitExceeded(Exception):
+    pass
+
+
+class LimitedWriter:
+    def __init__(self, writer, limit_mb):
+        self.writer = writer
+        self.limit_bytes = limit_mb * 1024 * 1024
+        self.bytes_count = 0
+
+    def write(self, bytes):
+        self.bytes_count += len(bytes)
+        if self.limit_bytes and (self.bytes_count > self.limit_bytes):
+            raise WriteLimitExceeded("CSV contains more than {} bytes".format(
+                self.limit_bytes
+            ))
+        self.writer.write(bytes)
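Not part of the commit: a minimal sketch of LimitedWriter's behaviour in isolation, assuming it and WriteLimitExceeded are imported from datasette.utils as added above.

    import io

    from datasette.utils import LimitedWriter, WriteLimitExceeded

    underlying = io.StringIO()
    writer = LimitedWriter(underlying, 1)  # 1 MB budget

    try:
        writer.write("x" * 2 * 1024 * 1024)  # 2 MB, exceeds the budget
    except WriteLimitExceeded as error:
        print(error)  # CSV contains more than 1048576 bytes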
datasette/views/base.py
@@ -16,6 +16,7 @@ from datasette.utils import (
     CustomJSONEncoder,
     InterruptedError,
     InvalidSql,
+    LimitedWriter,
     path_from_row_pks,
     path_with_added_args,
     path_with_format,
@@ -191,34 +192,39 @@ class BaseView(RenderMixin):

         async def stream_fn(r):
             nonlocal data
-            writer = csv.writer(r)
+            writer = csv.writer(LimitedWriter(r, self.ds.config["max_csv_mb"]))
             first = True
             next = None
             while first or (next and stream):
-                if next:
-                    kwargs["_next"] = next
-                if not first:
-                    data, extra_template_data, templates = await self.data(
-                        request, name, hash, **kwargs
-                    )
-                if first:
-                    writer.writerow(headings)
-                    first = False
-                next = data.get("next")
-                for row in data["rows"]:
-                    if not expanded_columns:
-                        # Simple path
-                        writer.writerow(row)
-                    else:
-                        # Look for {"value": "label": } dicts and expand
-                        new_row = []
-                        for cell in row:
-                            if isinstance(cell, dict):
-                                new_row.append(cell["value"])
-                                new_row.append(cell["label"])
-                            else:
-                                new_row.append(cell)
-                        writer.writerow(new_row)
+                try:
+                    if next:
+                        kwargs["_next"] = next
+                    if not first:
+                        data, extra_template_data, templates = await self.data(
+                            request, name, hash, **kwargs
+                        )
+                    if first:
+                        writer.writerow(headings)
+                        first = False
+                    next = data.get("next")
+                    for row in data["rows"]:
+                        if not expanded_columns:
+                            # Simple path
+                            writer.writerow(row)
+                        else:
+                            # Look for {"value": "label": } dicts and expand
+                            new_row = []
+                            for cell in row:
+                                if isinstance(cell, dict):
+                                    new_row.append(cell["value"])
+                                    new_row.append(cell["label"])
+                                else:
+                                    new_row.append(cell)
+                            writer.writerow(new_row)
+                except Exception as e:
+                    print('caught this', e)
+                    r.write(str(e))
+                    return

         content_type = "text/plain; charset=utf-8"
         headers = {}
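Not part of the commit: the change above works because csv.writer() only needs an object with a write() method, so wrapping the response stream in LimitedWriter makes every writerow() call count against the byte budget. A rough standalone sketch of that interaction, using io.StringIO in place of the streaming response object r:

    import csv
    import io

    from datasette.utils import LimitedWriter, WriteLimitExceeded

    target = io.StringIO()
    writer = csv.writer(LimitedWriter(target, 1))  # 1 MB budget

    try:
        for i in range(1000000):
            writer.writerow([i, "a fairly long cell value" * 10])
    except WriteLimitExceeded as error:
        # Mirrors the except branch above: the response has already started
        # streaming, so the error text becomes the last line of the body
        target.write(str(error))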
@@ -417,7 +423,8 @@ class BaseView(RenderMixin):
         return r

     async def custom_sql(
-        self, request, name, hash, sql, editable=True, canned_query=None
+        self, request, name, hash, sql, editable=True, canned_query=None,
+        _size=None
     ):
         params = request.raw_args
         if "sql" in params:
@@ -439,6 +446,8 @@ class BaseView(RenderMixin):
         extra_args = {}
         if params.get("_timelimit"):
             extra_args["custom_time_limit"] = int(params["_timelimit"])
+        if _size:
+            extra_args["page_size"] = _size
         results = await self.ds.execute(
             name, sql, params, truncate=True, **extra_args
         )
datasette/views/database.py
@@ -9,13 +9,13 @@ from .base import BaseView, DatasetteError

 class DatabaseView(BaseView):

-    async def data(self, request, name, hash, default_labels=False):
+    async def data(self, request, name, hash, default_labels=False, _size=None):
         if request.args.get("sql"):
             if not self.ds.config["allow_sql"]:
                 raise DatasetteError("sql= is not allowed", status=400)
             sql = request.raw_args.pop("sql")
             validate_sql_select(sql)
-            return await self.custom_sql(request, name, hash, sql)
+            return await self.custom_sql(request, name, hash, sql, _size=_size)

         info = self.ds.inspect()[name]
         metadata = self.ds.metadata.get("databases", {}).get(name, {})
docs/config.rst
@@ -137,3 +137,12 @@ can turn it off like this::
 ::

     datasette mydatabase.db --config allow_csv_stream:off
+
+
+max_csv_mb
+----------
+
+The maximum size of CSV that can be exported, in megabytes. Defaults to 100MB.
+You can disable the limit entirely by setting this to 0::
+
+    datasette mydatabase.db --config max_csv_mb:0
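Not part of the commit: the same option can be set programmatically, which is how the new test fixture below exercises it. A sketch, assuming the Datasette constructor's config= keyword argument at this revision:

    from datasette.app import Datasette

    # 0 disables the limit, matching the --config max_csv_mb:0 example above
    ds = Datasette(["mydatabase.db"], config={"max_csv_mb": 0})
    app = ds.app()  # the resulting app can be served or driven by a test client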
tests/fixtures.py
@@ -71,6 +71,13 @@ def app_client_larger_cache_size():
     })


+@pytest.fixture(scope='session')
+def app_client_csv_max_mb_one():
+    yield from app_client(config={
+        'max_csv_mb': 1,
+    })
+
+
 def generate_compound_rows(num):
     for a, b, c in itertools.islice(
         itertools.product(string.ascii_lowercase, repeat=3), num
tests/test_api.py
@@ -902,6 +902,7 @@ def test_config_json(app_client):
         "num_sql_threads": 3,
         "cache_size_kb": 0,
         "allow_csv_stream": True,
+        "max_csv_mb": 100,
     } == response.json
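Not part of the commit: a quick manual way to look at the same introspection output the test above checks, assuming a Datasette instance is running locally on port 8001 and serving the /-/config.json endpoint:

    import json
    from urllib.request import urlopen

    config = json.load(urlopen("http://localhost:8001/-/config.json"))
    print(config["max_csv_mb"])  # 100 unless overridden with --config max_csv_mb:N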
tests/test_csv.py
@@ -1,4 +1,4 @@
-from .fixtures import app_client  # noqa
+from .fixtures import app_client, app_client_csv_max_mb_one  # noqa

 EXPECTED_TABLE_CSV = '''id,content
 1,hello
@@ -61,6 +61,18 @@ def test_table_csv_download(app_client):
     assert expected_disposition == response.headers['Content-Disposition']


+def test_max_csv_mb(app_client_csv_max_mb_one):
+    response = app_client_csv_max_mb_one.get(
+        "/fixtures.csv?sql=select+randomblob(10000)+"
+        "from+compound_three_primary_keys&_stream=1&_size=max"
+    )
+    # It's a 200 because we started streaming before we knew the error
+    assert response.status == 200
+    # Last line should be an error message
+    last_line = [line for line in response.body.split(b"\r\n") if line][-1]
+    assert last_line.startswith(b"CSV contains more than")
+
+
 def test_table_csv_stream(app_client):
     # Without _stream should return header + 100 rows:
     response = app_client.get(