From ff9efa668ebc33f17ef9b30139960e29906a18fb Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 23 Jun 2019 06:50:02 -0700 Subject: [PATCH] Implemente AsgiStream, CSV tests all now pass #272 --- datasette/utils.py | 4 +-- datasette/views/base.py | 77 ++++++++++++++++++++++++++++++++--------- tests/test_csv.py | 8 ++--- 3 files changed, 67 insertions(+), 22 deletions(-) diff --git a/datasette/utils.py b/datasette/utils.py index 58746be4..5ed8dd12 100644 --- a/datasette/utils.py +++ b/datasette/utils.py @@ -697,13 +697,13 @@ class LimitedWriter: self.limit_bytes = limit_mb * 1024 * 1024 self.bytes_count = 0 - def write(self, bytes): + async def write(self, bytes): self.bytes_count += len(bytes) if self.limit_bytes and (self.bytes_count > self.limit_bytes): raise WriteLimitExceeded( "CSV contains more than {} bytes".format(self.limit_bytes) ) - self.writer.write(bytes) + await self.writer.write(bytes) _infinities = {float("inf"), float("-inf")} diff --git a/datasette/views/base.py b/datasette/views/base.py index 96bc996a..0b02a13b 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -129,17 +129,20 @@ class AsgiView(HTTPMethodView): response = await self.dispatch_request( request, **scope["url_route"]["kwargs"] ) - await send( - { - "type": "http.response.start", - "status": response.status, - "headers": [ - [key.encode("utf-8"), value.encode("utf-8")] - for key, value in response.headers.items() - ], - } - ) - await send({"type": "http.response.body", "body": response.body}) + if hasattr(response, "asgi_send"): + await response.asgi_send(send) + else: + await send( + { + "type": "http.response.start", + "status": response.status, + "headers": [ + [key.encode("utf-8"), value.encode("utf-8")] + for key, value in response.headers.items() + ], + } + ) + await send({"type": "http.response.body", "body": response.body}) view.view_class = cls view.__doc__ = cls.__doc__ @@ -148,6 +151,48 @@ class AsgiView(HTTPMethodView): return view +class AsgiStream: + def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"): + self.stream_fn = stream_fn + self.status = status + self.headers = headers or {} + self.content_type = content_type + + async def asgi_send(self, send): + # Remove any existing content-type header + headers = dict( + [(k, v) for k, v in self.headers.items() if k.lower() != "content-type"] + ) + headers["content-type"] = self.content_type + await send( + { + "type": "http.response.start", + "status": self.status, + "headers": [ + [key.encode("utf-8"), value.encode("utf-8")] + for key, value in headers.items() + ], + } + ) + w = AsgiWriter(send) + await self.stream_fn(w) + await send({"type": "http.response.body", "body": b""}) + + +class AsgiWriter: + def __init__(self, send): + self.send = send + + async def write(self, chunk): + await self.send( + { + "type": "http.response.body", + "body": chunk.encode("utf8"), + "more_body": True, + } + ) + + class BaseView(AsgiView): ds = None @@ -383,13 +428,13 @@ class DataView(BaseView): if not first: data, _, _ = await self.data(request, database, hash, **kwargs) if first: - writer.writerow(headings) + await writer.writerow(headings) first = False next = data.get("next") for row in data["rows"]: if not expanded_columns: # Simple path - writer.writerow(row) + await writer.writerow(row) else: # Look for {"value": "label": } dicts and expand new_row = [] @@ -399,10 +444,10 @@ class DataView(BaseView): new_row.append(cell["label"]) else: new_row.append(cell) - writer.writerow(new_row) + await writer.writerow(new_row) except Exception as e: print("caught this", e) - r.write(str(e)) + await r.write(str(e)) return content_type = "text/plain; charset=utf-8" @@ -416,7 +461,7 @@ class DataView(BaseView): ) headers["Content-Disposition"] = disposition - return response.stream(stream_fn, headers=headers, content_type=content_type) + return AsgiStream(stream_fn, headers=headers, content_type=content_type) async def get_format(self, request, database, args): """ Determine the format of the response from the request, from URL diff --git a/tests/test_csv.py b/tests/test_csv.py index cf0e6732..c3cdc241 100644 --- a/tests/test_csv.py +++ b/tests/test_csv.py @@ -46,7 +46,7 @@ def test_table_csv(app_client): response = app_client.get("/fixtures/simple_primary_key.csv") assert response.status == 200 assert not response.headers.get("Access-Control-Allow-Origin") - assert "text/plain; charset=utf-8" == response.headers["Content-Type"] + assert "text/plain; charset=utf-8" == response.headers["content-type"] assert EXPECTED_TABLE_CSV == response.text @@ -59,7 +59,7 @@ def test_table_csv_cors_headers(app_client_with_cors): def test_table_csv_with_labels(app_client): response = app_client.get("/fixtures/facetable.csv?_labels=1") assert response.status == 200 - assert "text/plain; charset=utf-8" == response.headers["Content-Type"] + assert "text/plain; charset=utf-8" == response.headers["content-type"] assert EXPECTED_TABLE_WITH_LABELS_CSV == response.text @@ -68,14 +68,14 @@ def test_custom_sql_csv(app_client): "/fixtures.csv?sql=select+content+from+simple_primary_key+limit+2" ) assert response.status == 200 - assert "text/plain; charset=utf-8" == response.headers["Content-Type"] + assert "text/plain; charset=utf-8" == response.headers["content-type"] assert EXPECTED_CUSTOM_CSV == response.text def test_table_csv_download(app_client): response = app_client.get("/fixtures/simple_primary_key.csv?_dl=1") assert response.status == 200 - assert "text/csv; charset=utf-8" == response.headers["Content-Type"] + assert "text/csv; charset=utf-8" == response.headers["content-type"] expected_disposition = 'attachment; filename="simple_primary_key.csv"' assert expected_disposition == response.headers["Content-Disposition"]