diff --git a/datasette/app.py b/datasette/app.py index f9bf91a8..54cf02f8 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1,4 +1,5 @@ import asyncio +import asgi_csrf import collections import datetime import hashlib @@ -884,7 +885,14 @@ class Datasette: await database.table_counts(limit=60 * 60 * 1000) asgi = AsgiLifespan( - AsgiTracer(DatasetteRouter(self, routes)), on_startup=setup_db + AsgiTracer( + asgi_csrf.asgi_csrf( + DatasetteRouter(self, routes), + signing_secret=self._secret, + cookie_name="ds_csrftoken", + ) + ), + on_startup=setup_db, ) for wrapper in pm.hook.asgi_wrapper(datasette=self): asgi = wrapper(asgi) diff --git a/datasette/templates/messages_debug.html b/datasette/templates/messages_debug.html index b2e1bc7c..e83d2a2f 100644 --- a/datasette/templates/messages_debug.html +++ b/datasette/templates/messages_debug.html @@ -8,7 +8,7 @@

Set a message:

-
+
@@ -19,6 +19,7 @@
+
diff --git a/datasette/templates/query.html b/datasette/templates/query.html index 52896e96..a7cb6647 100644 --- a/datasette/templates/query.html +++ b/datasette/templates/query.html @@ -52,6 +52,7 @@ {% endif %}

+ {% if canned_query %}{% endif %}

diff --git a/datasette/utils/__init__.py b/datasette/utils/__init__.py index 69e288e6..059db184 100644 --- a/datasette/utils/__init__.py +++ b/datasette/utils/__init__.py @@ -772,6 +772,9 @@ class MultiParams: new_data.setdefault(key, []).append(value) self._data = new_data + def __repr__(self): + return "".format(self._data) + def __contains__(self, key): return key in self._data diff --git a/datasette/views/base.py b/datasette/views/base.py index 2402406a..315c96fe 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -95,6 +95,7 @@ class BaseView(AsgiView): **context, **{ "database_url": self.database_url, + "csrftoken": request.scope["csrftoken"], "database_color": self.database_color, "show_messages": lambda: self.ds._show_messages(request), "select_templates": [ diff --git a/setup.py b/setup.py index 93628266..c0316deb 100644 --- a/setup.py +++ b/setup.py @@ -53,6 +53,7 @@ setup( "uvicorn~=0.11", "aiofiles>=0.4,<0.6", "janus>=0.4,<0.6", + "asgi-csrf>=0.4", "PyYAML~=5.3", "mergedeep>=1.1.1,<1.4.0", "itsdangerous~=1.1", diff --git a/tests/fixtures.py b/tests/fixtures.py index 78a54c68..a64a8295 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,5 +1,5 @@ from datasette.app import Datasette -from datasette.utils import sqlite3 +from datasette.utils import sqlite3, MultiParams from asgiref.testing import ApplicationCommunicator from asgiref.sync import async_to_sync from http.cookies import SimpleCookie @@ -60,10 +60,35 @@ class TestClient: @async_to_sync async def post( - self, path, post_data=None, allow_redirects=True, redirect_count=0, cookies=None + self, + path, + post_data=None, + allow_redirects=True, + redirect_count=0, + content_type="application/x-www-form-urlencoded", + cookies=None, + csrftoken_from=None, ): + cookies = cookies or {} + post_data = post_data or {} + # Maybe fetch a csrftoken first + if csrftoken_from is not None: + if csrftoken_from is True: + csrftoken_from = path + token_response = await self._request(csrftoken_from) + # Check this had a Vary: Cookie header + assert "Cookie" == token_response.headers["vary"] + csrftoken = token_response.cookies["ds_csrftoken"] + cookies["ds_csrftoken"] = csrftoken + post_data["csrftoken"] = csrftoken return await self._request( - path, allow_redirects, redirect_count, "POST", cookies, post_data + path, + allow_redirects, + redirect_count, + "POST", + cookies, + post_data, + content_type, ) async def _request( @@ -74,6 +99,7 @@ class TestClient: method="GET", cookies=None, post_data=None, + content_type=None, ): query_string = b"" if "?" in path: @@ -84,6 +110,8 @@ class TestClient: else: raw_path = quote(path, safe="/:,").encode("latin-1") headers = [[b"host", b"localhost"]] + if content_type: + headers.append((b"content-type", content_type.encode("utf-8"))) if cookies: sc = SimpleCookie() for key, value in cookies.items(): @@ -111,7 +139,7 @@ class TestClient: start = await instance.receive_output(2) messages.append(start) assert start["type"] == "http.response.start" - headers = dict( + response_headers = MultiParams( [(k.decode("utf8"), v.decode("utf8")) for k, v in start["headers"]] ) status = start["status"] @@ -124,7 +152,7 @@ class TestClient: body += message["body"] if not message.get("more_body"): break - response = TestResponse(status, headers, body) + response = TestResponse(status, response_headers, body) if allow_redirects and response.status in (301, 302): assert ( redirect_count < self.max_redirects diff --git a/tests/test_canned_write.py b/tests/test_canned_write.py index 692d726e..be838063 100644 --- a/tests/test_canned_write.py +++ b/tests/test_canned_write.py @@ -40,7 +40,7 @@ def canned_write_client(): def test_insert(canned_write_client): response = canned_write_client.post( - "/data/add_name", {"name": "Hello"}, allow_redirects=False + "/data/add_name", {"name": "Hello"}, allow_redirects=False, csrftoken_from=True, ) assert 302 == response.status assert "/data/add_name?success" == response.headers["Location"] @@ -52,7 +52,7 @@ def test_insert(canned_write_client): def test_custom_success_message(canned_write_client): response = canned_write_client.post( - "/data/delete_name", {"rowid": 1}, allow_redirects=False + "/data/delete_name", {"rowid": 1}, allow_redirects=False, csrftoken_from=True ) assert 302 == response.status messages = canned_write_client.ds.unsign( @@ -62,11 +62,12 @@ def test_custom_success_message(canned_write_client): def test_insert_error(canned_write_client): - canned_write_client.post("/data/add_name", {"name": "Hello"}) + canned_write_client.post("/data/add_name", {"name": "Hello"}, csrftoken_from=True) response = canned_write_client.post( "/data/add_name_specify_id", {"rowid": 1, "name": "Should fail"}, allow_redirects=False, + csrftoken_from=True, ) assert 302 == response.status assert "/data/add_name_specify_id?error" == response.headers["Location"] @@ -82,6 +83,7 @@ def test_insert_error(canned_write_client): "/data/add_name_specify_id", {"rowid": 1, "name": "Should fail"}, allow_redirects=False, + csrftoken_from=True, ) assert [["ERROR", 3]] == canned_write_client.ds.unsign( response.cookies["ds_messages"], "messages" diff --git a/tests/test_utils.py b/tests/test_utils.py index a7968e54..cf714215 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -439,15 +439,18 @@ def test_call_with_supported_arguments(): utils.call_with_supported_arguments(foo, a=1) -@pytest.mark.parametrize("data,should_raise", [ - ([["foo", "bar"], ["foo", "baz"]], False), - ([("foo", "bar"), ("foo", "baz")], False), - ((["foo", "bar"], ["foo", "baz"]), False), - ([["foo", "bar"], ["foo", "baz", "bax"]], True), - ({"foo": ["bar", "baz"]}, False), - ({"foo": ("bar", "baz")}, False), - ({"foo": "bar"}, True), -]) +@pytest.mark.parametrize( + "data,should_raise", + [ + ([["foo", "bar"], ["foo", "baz"]], False), + ([("foo", "bar"), ("foo", "baz")], False), + ((["foo", "bar"], ["foo", "baz"]), False), + ([["foo", "bar"], ["foo", "baz", "bax"]], True), + ({"foo": ["bar", "baz"]}, False), + ({"foo": ("bar", "baz")}, False), + ({"foo": "bar"}, True), + ] +) def test_multi_params(data, should_raise): if should_raise: with pytest.raises(AssertionError):