mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
CSRF protection (#798)
Closes #793. * Rename RequestParameters to MultiParams, refs #799 * Allow tuples as well as lists in MultiParams, refs #799 * Use csrftokens when running tests, refs #799 * Use new csrftoken() function, refs https://github.com/simonw/asgi-csrf/issues/7 * Check for Vary: Cookie hedaer, refs https://github.com/simonw/asgi-csrf/issues/8
This commit is contained in:
parent
d96ac1d52c
commit
84a9c4ff75
9 changed files with 67 additions and 19 deletions
|
|
@ -1,4 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import asgi_csrf
|
||||||
import collections
|
import collections
|
||||||
import datetime
|
import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
@ -884,7 +885,14 @@ class Datasette:
|
||||||
await database.table_counts(limit=60 * 60 * 1000)
|
await database.table_counts(limit=60 * 60 * 1000)
|
||||||
|
|
||||||
asgi = AsgiLifespan(
|
asgi = AsgiLifespan(
|
||||||
AsgiTracer(DatasetteRouter(self, routes)), on_startup=setup_db
|
AsgiTracer(
|
||||||
|
asgi_csrf.asgi_csrf(
|
||||||
|
DatasetteRouter(self, routes),
|
||||||
|
signing_secret=self._secret,
|
||||||
|
cookie_name="ds_csrftoken",
|
||||||
|
)
|
||||||
|
),
|
||||||
|
on_startup=setup_db,
|
||||||
)
|
)
|
||||||
for wrapper in pm.hook.asgi_wrapper(datasette=self):
|
for wrapper in pm.hook.asgi_wrapper(datasette=self):
|
||||||
asgi = wrapper(asgi)
|
asgi = wrapper(asgi)
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@
|
||||||
|
|
||||||
<p>Set a message:</p>
|
<p>Set a message:</p>
|
||||||
|
|
||||||
<form action="/-/messages" method="POST">
|
<form action="/-/messages" method="post">
|
||||||
<div>
|
<div>
|
||||||
<input type="text" name="message" style="width: 40%">
|
<input type="text" name="message" style="width: 40%">
|
||||||
<div class="select-wrapper">
|
<div class="select-wrapper">
|
||||||
|
|
@ -19,6 +19,7 @@
|
||||||
<option>all</option>
|
<option>all</option>
|
||||||
</select>
|
</select>
|
||||||
</div>
|
</div>
|
||||||
|
<input type="hidden" name="csrftoken" value="{{ csrftoken() }}">
|
||||||
<input type="submit" value="Add message">
|
<input type="submit" value="Add message">
|
||||||
</div>
|
</div>
|
||||||
</form>
|
</form>
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,7 @@
|
||||||
{% endif %}
|
{% endif %}
|
||||||
<p>
|
<p>
|
||||||
<button id="sql-format" type="button" hidden>Format SQL</button>
|
<button id="sql-format" type="button" hidden>Format SQL</button>
|
||||||
|
{% if canned_query %}<input type="hidden" name="csrftoken" value="{{ csrftoken() }}">{% endif %}
|
||||||
<input type="submit" value="Run SQL">
|
<input type="submit" value="Run SQL">
|
||||||
</p>
|
</p>
|
||||||
</form>
|
</form>
|
||||||
|
|
|
||||||
|
|
@ -772,6 +772,9 @@ class MultiParams:
|
||||||
new_data.setdefault(key, []).append(value)
|
new_data.setdefault(key, []).append(value)
|
||||||
self._data = new_data
|
self._data = new_data
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<MultiParams: {}>".format(self._data)
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
return key in self._data
|
return key in self._data
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -95,6 +95,7 @@ class BaseView(AsgiView):
|
||||||
**context,
|
**context,
|
||||||
**{
|
**{
|
||||||
"database_url": self.database_url,
|
"database_url": self.database_url,
|
||||||
|
"csrftoken": request.scope["csrftoken"],
|
||||||
"database_color": self.database_color,
|
"database_color": self.database_color,
|
||||||
"show_messages": lambda: self.ds._show_messages(request),
|
"show_messages": lambda: self.ds._show_messages(request),
|
||||||
"select_templates": [
|
"select_templates": [
|
||||||
|
|
|
||||||
1
setup.py
1
setup.py
|
|
@ -53,6 +53,7 @@ setup(
|
||||||
"uvicorn~=0.11",
|
"uvicorn~=0.11",
|
||||||
"aiofiles>=0.4,<0.6",
|
"aiofiles>=0.4,<0.6",
|
||||||
"janus>=0.4,<0.6",
|
"janus>=0.4,<0.6",
|
||||||
|
"asgi-csrf>=0.4",
|
||||||
"PyYAML~=5.3",
|
"PyYAML~=5.3",
|
||||||
"mergedeep>=1.1.1,<1.4.0",
|
"mergedeep>=1.1.1,<1.4.0",
|
||||||
"itsdangerous~=1.1",
|
"itsdangerous~=1.1",
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from datasette.app import Datasette
|
from datasette.app import Datasette
|
||||||
from datasette.utils import sqlite3
|
from datasette.utils import sqlite3, MultiParams
|
||||||
from asgiref.testing import ApplicationCommunicator
|
from asgiref.testing import ApplicationCommunicator
|
||||||
from asgiref.sync import async_to_sync
|
from asgiref.sync import async_to_sync
|
||||||
from http.cookies import SimpleCookie
|
from http.cookies import SimpleCookie
|
||||||
|
|
@ -60,10 +60,35 @@ class TestClient:
|
||||||
|
|
||||||
@async_to_sync
|
@async_to_sync
|
||||||
async def post(
|
async def post(
|
||||||
self, path, post_data=None, allow_redirects=True, redirect_count=0, cookies=None
|
self,
|
||||||
|
path,
|
||||||
|
post_data=None,
|
||||||
|
allow_redirects=True,
|
||||||
|
redirect_count=0,
|
||||||
|
content_type="application/x-www-form-urlencoded",
|
||||||
|
cookies=None,
|
||||||
|
csrftoken_from=None,
|
||||||
):
|
):
|
||||||
|
cookies = cookies or {}
|
||||||
|
post_data = post_data or {}
|
||||||
|
# Maybe fetch a csrftoken first
|
||||||
|
if csrftoken_from is not None:
|
||||||
|
if csrftoken_from is True:
|
||||||
|
csrftoken_from = path
|
||||||
|
token_response = await self._request(csrftoken_from)
|
||||||
|
# Check this had a Vary: Cookie header
|
||||||
|
assert "Cookie" == token_response.headers["vary"]
|
||||||
|
csrftoken = token_response.cookies["ds_csrftoken"]
|
||||||
|
cookies["ds_csrftoken"] = csrftoken
|
||||||
|
post_data["csrftoken"] = csrftoken
|
||||||
return await self._request(
|
return await self._request(
|
||||||
path, allow_redirects, redirect_count, "POST", cookies, post_data
|
path,
|
||||||
|
allow_redirects,
|
||||||
|
redirect_count,
|
||||||
|
"POST",
|
||||||
|
cookies,
|
||||||
|
post_data,
|
||||||
|
content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _request(
|
async def _request(
|
||||||
|
|
@ -74,6 +99,7 @@ class TestClient:
|
||||||
method="GET",
|
method="GET",
|
||||||
cookies=None,
|
cookies=None,
|
||||||
post_data=None,
|
post_data=None,
|
||||||
|
content_type=None,
|
||||||
):
|
):
|
||||||
query_string = b""
|
query_string = b""
|
||||||
if "?" in path:
|
if "?" in path:
|
||||||
|
|
@ -84,6 +110,8 @@ class TestClient:
|
||||||
else:
|
else:
|
||||||
raw_path = quote(path, safe="/:,").encode("latin-1")
|
raw_path = quote(path, safe="/:,").encode("latin-1")
|
||||||
headers = [[b"host", b"localhost"]]
|
headers = [[b"host", b"localhost"]]
|
||||||
|
if content_type:
|
||||||
|
headers.append((b"content-type", content_type.encode("utf-8")))
|
||||||
if cookies:
|
if cookies:
|
||||||
sc = SimpleCookie()
|
sc = SimpleCookie()
|
||||||
for key, value in cookies.items():
|
for key, value in cookies.items():
|
||||||
|
|
@ -111,7 +139,7 @@ class TestClient:
|
||||||
start = await instance.receive_output(2)
|
start = await instance.receive_output(2)
|
||||||
messages.append(start)
|
messages.append(start)
|
||||||
assert start["type"] == "http.response.start"
|
assert start["type"] == "http.response.start"
|
||||||
headers = dict(
|
response_headers = MultiParams(
|
||||||
[(k.decode("utf8"), v.decode("utf8")) for k, v in start["headers"]]
|
[(k.decode("utf8"), v.decode("utf8")) for k, v in start["headers"]]
|
||||||
)
|
)
|
||||||
status = start["status"]
|
status = start["status"]
|
||||||
|
|
@ -124,7 +152,7 @@ class TestClient:
|
||||||
body += message["body"]
|
body += message["body"]
|
||||||
if not message.get("more_body"):
|
if not message.get("more_body"):
|
||||||
break
|
break
|
||||||
response = TestResponse(status, headers, body)
|
response = TestResponse(status, response_headers, body)
|
||||||
if allow_redirects and response.status in (301, 302):
|
if allow_redirects and response.status in (301, 302):
|
||||||
assert (
|
assert (
|
||||||
redirect_count < self.max_redirects
|
redirect_count < self.max_redirects
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ def canned_write_client():
|
||||||
|
|
||||||
def test_insert(canned_write_client):
|
def test_insert(canned_write_client):
|
||||||
response = canned_write_client.post(
|
response = canned_write_client.post(
|
||||||
"/data/add_name", {"name": "Hello"}, allow_redirects=False
|
"/data/add_name", {"name": "Hello"}, allow_redirects=False, csrftoken_from=True,
|
||||||
)
|
)
|
||||||
assert 302 == response.status
|
assert 302 == response.status
|
||||||
assert "/data/add_name?success" == response.headers["Location"]
|
assert "/data/add_name?success" == response.headers["Location"]
|
||||||
|
|
@ -52,7 +52,7 @@ def test_insert(canned_write_client):
|
||||||
|
|
||||||
def test_custom_success_message(canned_write_client):
|
def test_custom_success_message(canned_write_client):
|
||||||
response = canned_write_client.post(
|
response = canned_write_client.post(
|
||||||
"/data/delete_name", {"rowid": 1}, allow_redirects=False
|
"/data/delete_name", {"rowid": 1}, allow_redirects=False, csrftoken_from=True
|
||||||
)
|
)
|
||||||
assert 302 == response.status
|
assert 302 == response.status
|
||||||
messages = canned_write_client.ds.unsign(
|
messages = canned_write_client.ds.unsign(
|
||||||
|
|
@ -62,11 +62,12 @@ def test_custom_success_message(canned_write_client):
|
||||||
|
|
||||||
|
|
||||||
def test_insert_error(canned_write_client):
|
def test_insert_error(canned_write_client):
|
||||||
canned_write_client.post("/data/add_name", {"name": "Hello"})
|
canned_write_client.post("/data/add_name", {"name": "Hello"}, csrftoken_from=True)
|
||||||
response = canned_write_client.post(
|
response = canned_write_client.post(
|
||||||
"/data/add_name_specify_id",
|
"/data/add_name_specify_id",
|
||||||
{"rowid": 1, "name": "Should fail"},
|
{"rowid": 1, "name": "Should fail"},
|
||||||
allow_redirects=False,
|
allow_redirects=False,
|
||||||
|
csrftoken_from=True,
|
||||||
)
|
)
|
||||||
assert 302 == response.status
|
assert 302 == response.status
|
||||||
assert "/data/add_name_specify_id?error" == response.headers["Location"]
|
assert "/data/add_name_specify_id?error" == response.headers["Location"]
|
||||||
|
|
@ -82,6 +83,7 @@ def test_insert_error(canned_write_client):
|
||||||
"/data/add_name_specify_id",
|
"/data/add_name_specify_id",
|
||||||
{"rowid": 1, "name": "Should fail"},
|
{"rowid": 1, "name": "Should fail"},
|
||||||
allow_redirects=False,
|
allow_redirects=False,
|
||||||
|
csrftoken_from=True,
|
||||||
)
|
)
|
||||||
assert [["ERROR", 3]] == canned_write_client.ds.unsign(
|
assert [["ERROR", 3]] == canned_write_client.ds.unsign(
|
||||||
response.cookies["ds_messages"], "messages"
|
response.cookies["ds_messages"], "messages"
|
||||||
|
|
|
||||||
|
|
@ -439,7 +439,9 @@ def test_call_with_supported_arguments():
|
||||||
utils.call_with_supported_arguments(foo, a=1)
|
utils.call_with_supported_arguments(foo, a=1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("data,should_raise", [
|
@pytest.mark.parametrize(
|
||||||
|
"data,should_raise",
|
||||||
|
[
|
||||||
([["foo", "bar"], ["foo", "baz"]], False),
|
([["foo", "bar"], ["foo", "baz"]], False),
|
||||||
([("foo", "bar"), ("foo", "baz")], False),
|
([("foo", "bar"), ("foo", "baz")], False),
|
||||||
((["foo", "bar"], ["foo", "baz"]), False),
|
((["foo", "bar"], ["foo", "baz"]), False),
|
||||||
|
|
@ -447,7 +449,8 @@ def test_call_with_supported_arguments():
|
||||||
({"foo": ["bar", "baz"]}, False),
|
({"foo": ["bar", "baz"]}, False),
|
||||||
({"foo": ("bar", "baz")}, False),
|
({"foo": ("bar", "baz")}, False),
|
||||||
({"foo": "bar"}, True),
|
({"foo": "bar"}, True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def test_multi_params(data, should_raise):
|
def test_multi_params(data, should_raise):
|
||||||
if should_raise:
|
if should_raise:
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue