diff --git a/datasette/app.py b/datasette/app.py
index f9bf91a8..54cf02f8 100644
--- a/datasette/app.py
+++ b/datasette/app.py
@@ -1,4 +1,5 @@
import asyncio
+import asgi_csrf
import collections
import datetime
import hashlib
@@ -884,7 +885,14 @@ class Datasette:
await database.table_counts(limit=60 * 60 * 1000)
asgi = AsgiLifespan(
- AsgiTracer(DatasetteRouter(self, routes)), on_startup=setup_db
+ AsgiTracer(
+ asgi_csrf.asgi_csrf(
+ DatasetteRouter(self, routes),
+ signing_secret=self._secret,
+ cookie_name="ds_csrftoken",
+ )
+ ),
+ on_startup=setup_db,
)
for wrapper in pm.hook.asgi_wrapper(datasette=self):
asgi = wrapper(asgi)
diff --git a/datasette/templates/messages_debug.html b/datasette/templates/messages_debug.html
index b2e1bc7c..e83d2a2f 100644
--- a/datasette/templates/messages_debug.html
+++ b/datasette/templates/messages_debug.html
@@ -8,7 +8,7 @@
Set a message:
-
diff --git a/datasette/templates/query.html b/datasette/templates/query.html
index 52896e96..a7cb6647 100644
--- a/datasette/templates/query.html
+++ b/datasette/templates/query.html
@@ -52,6 +52,7 @@
{% endif %}
+ {% if canned_query %}{% endif %}
diff --git a/datasette/utils/__init__.py b/datasette/utils/__init__.py
index 69e288e6..059db184 100644
--- a/datasette/utils/__init__.py
+++ b/datasette/utils/__init__.py
@@ -772,6 +772,9 @@ class MultiParams:
new_data.setdefault(key, []).append(value)
self._data = new_data
+ def __repr__(self):
+ return "".format(self._data)
+
def __contains__(self, key):
return key in self._data
diff --git a/datasette/views/base.py b/datasette/views/base.py
index 2402406a..315c96fe 100644
--- a/datasette/views/base.py
+++ b/datasette/views/base.py
@@ -95,6 +95,7 @@ class BaseView(AsgiView):
**context,
**{
"database_url": self.database_url,
+ "csrftoken": request.scope["csrftoken"],
"database_color": self.database_color,
"show_messages": lambda: self.ds._show_messages(request),
"select_templates": [
diff --git a/setup.py b/setup.py
index 93628266..c0316deb 100644
--- a/setup.py
+++ b/setup.py
@@ -53,6 +53,7 @@ setup(
"uvicorn~=0.11",
"aiofiles>=0.4,<0.6",
"janus>=0.4,<0.6",
+ "asgi-csrf>=0.4",
"PyYAML~=5.3",
"mergedeep>=1.1.1,<1.4.0",
"itsdangerous~=1.1",
diff --git a/tests/fixtures.py b/tests/fixtures.py
index 78a54c68..a64a8295 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -1,5 +1,5 @@
from datasette.app import Datasette
-from datasette.utils import sqlite3
+from datasette.utils import sqlite3, MultiParams
from asgiref.testing import ApplicationCommunicator
from asgiref.sync import async_to_sync
from http.cookies import SimpleCookie
@@ -60,10 +60,35 @@ class TestClient:
@async_to_sync
async def post(
- self, path, post_data=None, allow_redirects=True, redirect_count=0, cookies=None
+ self,
+ path,
+ post_data=None,
+ allow_redirects=True,
+ redirect_count=0,
+ content_type="application/x-www-form-urlencoded",
+ cookies=None,
+ csrftoken_from=None,
):
+ cookies = cookies or {}
+ post_data = post_data or {}
+ # Maybe fetch a csrftoken first
+ if csrftoken_from is not None:
+ if csrftoken_from is True:
+ csrftoken_from = path
+ token_response = await self._request(csrftoken_from)
+ # Check this had a Vary: Cookie header
+ assert "Cookie" == token_response.headers["vary"]
+ csrftoken = token_response.cookies["ds_csrftoken"]
+ cookies["ds_csrftoken"] = csrftoken
+ post_data["csrftoken"] = csrftoken
return await self._request(
- path, allow_redirects, redirect_count, "POST", cookies, post_data
+ path,
+ allow_redirects,
+ redirect_count,
+ "POST",
+ cookies,
+ post_data,
+ content_type,
)
async def _request(
@@ -74,6 +99,7 @@ class TestClient:
method="GET",
cookies=None,
post_data=None,
+ content_type=None,
):
query_string = b""
if "?" in path:
@@ -84,6 +110,8 @@ class TestClient:
else:
raw_path = quote(path, safe="/:,").encode("latin-1")
headers = [[b"host", b"localhost"]]
+ if content_type:
+ headers.append((b"content-type", content_type.encode("utf-8")))
if cookies:
sc = SimpleCookie()
for key, value in cookies.items():
@@ -111,7 +139,7 @@ class TestClient:
start = await instance.receive_output(2)
messages.append(start)
assert start["type"] == "http.response.start"
- headers = dict(
+ response_headers = MultiParams(
[(k.decode("utf8"), v.decode("utf8")) for k, v in start["headers"]]
)
status = start["status"]
@@ -124,7 +152,7 @@ class TestClient:
body += message["body"]
if not message.get("more_body"):
break
- response = TestResponse(status, headers, body)
+ response = TestResponse(status, response_headers, body)
if allow_redirects and response.status in (301, 302):
assert (
redirect_count < self.max_redirects
diff --git a/tests/test_canned_write.py b/tests/test_canned_write.py
index 692d726e..be838063 100644
--- a/tests/test_canned_write.py
+++ b/tests/test_canned_write.py
@@ -40,7 +40,7 @@ def canned_write_client():
def test_insert(canned_write_client):
response = canned_write_client.post(
- "/data/add_name", {"name": "Hello"}, allow_redirects=False
+ "/data/add_name", {"name": "Hello"}, allow_redirects=False, csrftoken_from=True,
)
assert 302 == response.status
assert "/data/add_name?success" == response.headers["Location"]
@@ -52,7 +52,7 @@ def test_insert(canned_write_client):
def test_custom_success_message(canned_write_client):
response = canned_write_client.post(
- "/data/delete_name", {"rowid": 1}, allow_redirects=False
+ "/data/delete_name", {"rowid": 1}, allow_redirects=False, csrftoken_from=True
)
assert 302 == response.status
messages = canned_write_client.ds.unsign(
@@ -62,11 +62,12 @@ def test_custom_success_message(canned_write_client):
def test_insert_error(canned_write_client):
- canned_write_client.post("/data/add_name", {"name": "Hello"})
+ canned_write_client.post("/data/add_name", {"name": "Hello"}, csrftoken_from=True)
response = canned_write_client.post(
"/data/add_name_specify_id",
{"rowid": 1, "name": "Should fail"},
allow_redirects=False,
+ csrftoken_from=True,
)
assert 302 == response.status
assert "/data/add_name_specify_id?error" == response.headers["Location"]
@@ -82,6 +83,7 @@ def test_insert_error(canned_write_client):
"/data/add_name_specify_id",
{"rowid": 1, "name": "Should fail"},
allow_redirects=False,
+ csrftoken_from=True,
)
assert [["ERROR", 3]] == canned_write_client.ds.unsign(
response.cookies["ds_messages"], "messages"
diff --git a/tests/test_utils.py b/tests/test_utils.py
index a7968e54..cf714215 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -439,15 +439,18 @@ def test_call_with_supported_arguments():
utils.call_with_supported_arguments(foo, a=1)
-@pytest.mark.parametrize("data,should_raise", [
- ([["foo", "bar"], ["foo", "baz"]], False),
- ([("foo", "bar"), ("foo", "baz")], False),
- ((["foo", "bar"], ["foo", "baz"]), False),
- ([["foo", "bar"], ["foo", "baz", "bax"]], True),
- ({"foo": ["bar", "baz"]}, False),
- ({"foo": ("bar", "baz")}, False),
- ({"foo": "bar"}, True),
-])
+@pytest.mark.parametrize(
+ "data,should_raise",
+ [
+ ([["foo", "bar"], ["foo", "baz"]], False),
+ ([("foo", "bar"), ("foo", "baz")], False),
+ ((["foo", "bar"], ["foo", "baz"]), False),
+ ([["foo", "bar"], ["foo", "baz", "bax"]], True),
+ ({"foo": ["bar", "baz"]}, False),
+ ({"foo": ("bar", "baz")}, False),
+ ({"foo": "bar"}, True),
+ ]
+)
def test_multi_params(data, should_raise):
if should_raise:
with pytest.raises(AssertionError):