diff --git a/datasette/utils/asgi.py b/datasette/utils/asgi.py index 55eba1bb..93084365 100644 --- a/datasette/utils/asgi.py +++ b/datasette/utils/asgi.py @@ -155,6 +155,10 @@ class Request: body = await self.post_body() return dict(parse_qsl(body.decode("utf-8"), keep_blank_values=True)) + async def json(self): + body = await self.post_body() + return json.loads(body) + async def form( self, files: bool = False, diff --git a/datasette/views/database.py b/datasette/views/database.py index 6afd9734..cd9565c6 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -1093,9 +1093,8 @@ class TableCreateView(BaseView): ): return _error(["Permission denied"], 403) - body = await request.post_body() try: - data = json.loads(body) + data = await request.json() except json.JSONDecodeError as e: return _error(["Invalid JSON: {}".format(e)]) diff --git a/datasette/views/row.py b/datasette/views/row.py index e6eaa92b..4d61eb91 100644 --- a/datasette/views/row.py +++ b/datasette/views/row.py @@ -418,9 +418,8 @@ class RowUpdateView(BaseView): if not ok: return resolved - body = await request.post_body() try: - data = json.loads(body) + data = await request.json() except json.JSONDecodeError as e: return _error(["Invalid JSON: {}".format(e)]) diff --git a/datasette/views/table.py b/datasette/views/table.py index de7b0216..c5448c85 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -673,9 +673,8 @@ class TableInsertView(BaseView): if not request.headers.get("content-type").startswith("application/json"): # TODO: handle form-encoded data return _errors(["Invalid content-type, must be application/json"]) - body = await request.post_body() try: - data = json.loads(body) + data = await request.json() except json.JSONDecodeError as e: return _errors(["Invalid JSON: {}".format(e)]) if not isinstance(data, dict): @@ -974,7 +973,7 @@ class TableSetColumnTypeView(BaseView): return _error(["Invalid content-type, must be application/json"], 400) try: - data = json.loads(await request.post_body()) + data = await request.json() except json.JSONDecodeError as e: return _error(["Invalid JSON: {}".format(e)], 400) @@ -1091,7 +1090,7 @@ class TableDropView(BaseView): return _error(["Database is immutable"], 403) confirm = False try: - data = json.loads(await request.post_body()) + data = await request.json() confirm = data.get("confirm") except json.JSONDecodeError: pass diff --git a/docs/internals.rst b/docs/internals.rst index 641286f8..06027f42 100644 --- a/docs/internals.rst +++ b/docs/internals.rst @@ -106,6 +106,9 @@ The object also has the following awaitable methods: ``await request.post_vars()`` - dictionary Returns a dictionary of form variables that were submitted in the request body via ``POST`` using ``application/x-www-form-urlencoded`` encoding. For multipart forms or file uploads, use ``request.form()`` instead. +``await request.json()`` - Any + Returns the parsed JSON body of a request submitted by ``POST``. + ``await request.post_body()`` - bytes Returns the un-parsed body of a request submitted by ``POST`` - useful for things like incoming JSON data. diff --git a/tests/test_internals_request.py b/tests/test_internals_request.py index d1ca1f46..9c448186 100644 --- a/tests/test_internals_request.py +++ b/tests/test_internals_request.py @@ -55,6 +55,57 @@ async def test_request_post_body(): assert data == json.loads(body) +@pytest.mark.asyncio +async def test_request_json(): + scope = { + "http_version": "1.1", + "method": "POST", + "path": "/", + "raw_path": b"/", + "query_string": b"", + "scheme": "http", + "type": "http", + "headers": [[b"content-type", b"application/json"]], + } + + data = {"hello": "world", "items": [1, 2, 3]} + + async def receive(): + return { + "type": "http.request", + "body": json.dumps(data).encode("utf-8"), + "more_body": False, + } + + request = Request(scope, receive) + assert data == await request.json() + + +@pytest.mark.asyncio +async def test_request_json_invalid(): + scope = { + "http_version": "1.1", + "method": "POST", + "path": "/", + "raw_path": b"/", + "query_string": b"", + "scheme": "http", + "type": "http", + "headers": [[b"content-type", b"application/json"]], + } + + async def receive(): + return { + "type": "http.request", + "body": b"this is not JSON", + "more_body": False, + } + + request = Request(scope, receive) + with pytest.raises(json.JSONDecodeError): + await request.json() + + def test_request_args(): request = Request.fake("/foo?multi=1&multi=2&single=3") assert "1" == request.args.get("multi")