Implemente AsgiStream, CSV tests all now pass #272

This commit is contained in:
Simon Willison 2019-06-23 06:50:02 -07:00
commit ff9efa668e
3 changed files with 67 additions and 22 deletions

View file

@ -697,13 +697,13 @@ class LimitedWriter:
self.limit_bytes = limit_mb * 1024 * 1024
self.bytes_count = 0
def write(self, bytes):
async def write(self, bytes):
self.bytes_count += len(bytes)
if self.limit_bytes and (self.bytes_count > self.limit_bytes):
raise WriteLimitExceeded(
"CSV contains more than {} bytes".format(self.limit_bytes)
)
self.writer.write(bytes)
await self.writer.write(bytes)
_infinities = {float("inf"), float("-inf")}

View file

@ -129,17 +129,20 @@ class AsgiView(HTTPMethodView):
response = await self.dispatch_request(
request, **scope["url_route"]["kwargs"]
)
await send(
{
"type": "http.response.start",
"status": response.status,
"headers": [
[key.encode("utf-8"), value.encode("utf-8")]
for key, value in response.headers.items()
],
}
)
await send({"type": "http.response.body", "body": response.body})
if hasattr(response, "asgi_send"):
await response.asgi_send(send)
else:
await send(
{
"type": "http.response.start",
"status": response.status,
"headers": [
[key.encode("utf-8"), value.encode("utf-8")]
for key, value in response.headers.items()
],
}
)
await send({"type": "http.response.body", "body": response.body})
view.view_class = cls
view.__doc__ = cls.__doc__
@ -148,6 +151,48 @@ class AsgiView(HTTPMethodView):
return view
class AsgiStream:
def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"):
self.stream_fn = stream_fn
self.status = status
self.headers = headers or {}
self.content_type = content_type
async def asgi_send(self, send):
# Remove any existing content-type header
headers = dict(
[(k, v) for k, v in self.headers.items() if k.lower() != "content-type"]
)
headers["content-type"] = self.content_type
await send(
{
"type": "http.response.start",
"status": self.status,
"headers": [
[key.encode("utf-8"), value.encode("utf-8")]
for key, value in headers.items()
],
}
)
w = AsgiWriter(send)
await self.stream_fn(w)
await send({"type": "http.response.body", "body": b""})
class AsgiWriter:
def __init__(self, send):
self.send = send
async def write(self, chunk):
await self.send(
{
"type": "http.response.body",
"body": chunk.encode("utf8"),
"more_body": True,
}
)
class BaseView(AsgiView):
ds = None
@ -383,13 +428,13 @@ class DataView(BaseView):
if not first:
data, _, _ = await self.data(request, database, hash, **kwargs)
if first:
writer.writerow(headings)
await writer.writerow(headings)
first = False
next = data.get("next")
for row in data["rows"]:
if not expanded_columns:
# Simple path
writer.writerow(row)
await writer.writerow(row)
else:
# Look for {"value": "label": } dicts and expand
new_row = []
@ -399,10 +444,10 @@ class DataView(BaseView):
new_row.append(cell["label"])
else:
new_row.append(cell)
writer.writerow(new_row)
await writer.writerow(new_row)
except Exception as e:
print("caught this", e)
r.write(str(e))
await r.write(str(e))
return
content_type = "text/plain; charset=utf-8"
@ -416,7 +461,7 @@ class DataView(BaseView):
)
headers["Content-Disposition"] = disposition
return response.stream(stream_fn, headers=headers, content_type=content_type)
return AsgiStream(stream_fn, headers=headers, content_type=content_type)
async def get_format(self, request, database, args):
""" Determine the format of the response from the request, from URL

View file

@ -46,7 +46,7 @@ def test_table_csv(app_client):
response = app_client.get("/fixtures/simple_primary_key.csv")
assert response.status == 200
assert not response.headers.get("Access-Control-Allow-Origin")
assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
assert "text/plain; charset=utf-8" == response.headers["content-type"]
assert EXPECTED_TABLE_CSV == response.text
@ -59,7 +59,7 @@ def test_table_csv_cors_headers(app_client_with_cors):
def test_table_csv_with_labels(app_client):
response = app_client.get("/fixtures/facetable.csv?_labels=1")
assert response.status == 200
assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
assert "text/plain; charset=utf-8" == response.headers["content-type"]
assert EXPECTED_TABLE_WITH_LABELS_CSV == response.text
@ -68,14 +68,14 @@ def test_custom_sql_csv(app_client):
"/fixtures.csv?sql=select+content+from+simple_primary_key+limit+2"
)
assert response.status == 200
assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
assert "text/plain; charset=utf-8" == response.headers["content-type"]
assert EXPECTED_CUSTOM_CSV == response.text
def test_table_csv_download(app_client):
response = app_client.get("/fixtures/simple_primary_key.csv?_dl=1")
assert response.status == 200
assert "text/csv; charset=utf-8" == response.headers["Content-Type"]
assert "text/csv; charset=utf-8" == response.headers["content-type"]
expected_disposition = 'attachment; filename="simple_primary_key.csv"'
assert expected_disposition == response.headers["Content-Disposition"]