mirror of
https://github.com/simonw/datasette.git
synced 2026-05-27 20:36:17 +02:00
334 lines
11 KiB
Python
334 lines
11 KiB
Python
"""
Tests for the header-based CSRF (Cross-Origin) protection middleware.

Datasette uses the Sec-Fetch-Site + Origin header approach described in
Filippo Valsorda's article (https://words.filippo.io/csrf/) and implemented
in Go 1.25's http.CrossOriginProtection. This replaces the previous
token-based asgi-csrf mechanism.
"""
|
|
|
|
import pluggy
|
|
import pytest
|
|
|
|
from datasette import hookimpl
|
|
from datasette.csrf import CrossOriginProtectionMiddleware, _install_legacy_csrftoken
|
|
|
|
|
|
async def _post(bare_ds, **kwargs):
    """
    POST to /-/messages through ``bare_ds.client`` and return the response.

    A default form body is supplied when the caller passes no ``data``;
    every other keyword argument is forwarded to ``client.post`` untouched.
    """
    if "data" not in kwargs:
        kwargs["data"] = {"message": "hello", "message_class": "info"}
    response = await bare_ds.client.post("/-/messages", **kwargs)
    return response
|
|
|
|
|
|
async def _run_middleware(scope):
    """
    Drive CrossOriginProtectionMiddleware with ``scope`` and report the outcome.

    Returns ("allowed",) when the wrapped app was invoked, otherwise
    ("blocked", status) where status comes from the first
    "http.response.start" message the middleware sent.
    """

    class FakeDs:
        # Minimal stand-in: the only method defined is render_template.
        async def render_template(self, name, ctx):
            return "BLOCKED"

    inner_reached = False
    messages = []

    async def app(scope, receive, send):
        nonlocal inner_reached
        inner_reached = True

    async def capture(msg):
        messages.append(msg)

    middleware = CrossOriginProtectionMiddleware(app, FakeDs())
    await middleware(scope, None, capture)

    if inner_reached:
        return ("allowed",)

    response_starts = [m for m in messages if m["type"] == "http.response.start"]
    return ("blocked", response_starts[0]["status"])
|
|
|
|
|
|
def _http_scope(headers, method="POST"):
    """
    Build a minimal ASGI-style "http" scope dict.

    ``headers`` is a mapping of str -> str; each name and value is encoded
    to bytes and emitted as a (name, value) tuple, in mapping order.
    """
    encoded_headers = []
    for name, value in headers.items():
        encoded_headers.append((name.encode(), value.encode()))
    scope = {"type": "http", "method": method}
    scope["headers"] = encoded_headers
    return scope
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("method", ["GET", "HEAD", "OPTIONS"])
async def test_safe_methods_always_pass(bare_ds, method):
    # GET/HEAD/OPTIONS are sent with deliberately hostile cross-site headers;
    # the response must not be a 403 that mentions "origin".
    hostile_headers = {"sec-fetch-site": "cross-site", "origin": "http://evil.example"}
    resp = await bare_ds.client.request(method, "/-/messages", headers=hostile_headers)
    assert not (resp.status_code == 403 and "origin" in resp.text.lower())
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("sec_fetch_site", ["same-origin", "none"])
async def test_post_with_trusted_sec_fetch_site_allowed(bare_ds, sec_fetch_site):
    # Trusted values: "same-origin" is first-party, "none" is a
    # user-initiated direct navigation.
    trusted_headers = {"sec-fetch-site": sec_fetch_site}
    resp = await _post(bare_ds, headers=trusted_headers)
    assert resp.status_code != 403
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("sec_fetch_site", ["cross-site", "same-site", "cross-origin"])
async def test_post_with_untrusted_sec_fetch_site_blocked(bare_ds, sec_fetch_site):
    # "same-site" is rejected as well, so a sibling subdomain cannot
    # sidestep the CSRF check.
    request_kwargs = {
        "data": {"message": "hi"},
        "headers": {"sec-fetch-site": sec_fetch_site},
    }
    resp = await _post(bare_ds, **request_kwargs)
    assert resp.status_code == 403
    content_type = resp.headers["content-type"]
    assert content_type.startswith("text/html")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_post_with_no_browser_headers_allowed(bare_ds):
    # A request carrying neither Sec-Fetch-Site nor Origin (curl, requests,
    # server-to-server) is not a browser request, so it is let through.
    resp = await _post(bare_ds)
    assert resp.status_code != 403
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_post_with_matching_origin_allowed(bare_ds):
    # Older browsers send no Sec-Fetch-Site; the fallback is that Origin
    # has to match Host.
    origin_only = {"origin": "http://localhost"}
    resp = await _post(bare_ds, headers=origin_only)
    assert resp.status_code != 403
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_post_with_mismatched_origin_blocked(bare_ds):
    # An Origin header naming a foreign host must produce a 403.
    foreign_origin = {"origin": "http://evil.example.com"}
    resp = await _post(bare_ds, data={"message": "hi"}, headers=foreign_origin)
    assert resp.status_code == 403
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_csrf_error_page_renders(bare_ds):
    # The 403 body for a cross-site POST should mention "origin".
    cross_site = {"sec-fetch-site": "cross-site"}
    resp = await _post(bare_ds, data={"message": "hi"}, headers=cross_site)
    assert resp.status_code == 403
    page = resp.text.lower()
    assert "origin" in page
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_csrf_error_page_title_has_no_typo(bare_ds):
    # The error page title must be exact, with no stray ")" after it.
    cross_site = {"sec-fetch-site": "cross-site"}
    resp = await _post(bare_ds, data={"message": "hi"}, headers=cross_site)
    page = resp.text
    assert "<title>CSRF check failed</title>" in page
    assert "CSRF check failed)" not in page
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("scope_type", ["websocket", "lifespan"])
async def test_non_http_scope_passes_through(scope_type):
    # Non-"http" scopes must reach the wrapped app exactly once.
    seen_types = []

    async def app(scope, receive, send):
        seen_types.append(scope["type"])

    middleware = CrossOriginProtectionMiddleware(app, datasette=None)
    non_http_scope = {"type": scope_type}
    await middleware(non_http_scope, None, None)
    assert seen_types == [scope_type]
|
|
|
|
|
|
# (label, request headers, expected _run_middleware() result) triples.
_MIDDLEWARE_UNIT_CASES = [
    (
        "plain cross-site blocked",
        {"sec-fetch-site": "cross-site", "host": "example.com"},
        ("blocked", 403),
    ),
    (
        "basic auth does not bypass",
        {
            "sec-fetch-site": "cross-site",
            "host": "example.com",
            "authorization": "Basic dXNlcjpwYXNz",
        },
        ("blocked", 403),
    ),
    (
        "bearer auth bypasses",
        {
            "sec-fetch-site": "cross-site",
            "origin": "https://evil.example",
            "host": "example.com",
            "authorization": "Bearer dstok_abc",
        },
        ("allowed",),
    ),
    (
        "bearer scheme case-insensitive",
        {
            "sec-fetch-site": "cross-site",
            "host": "example.com",
            "authorization": "bearer dstok_abc",
        },
        ("allowed",),
    ),
    (
        "non-browser (no Sec-Fetch-Site, no Origin) allowed",
        {"host": "example.com"},
        ("allowed",),
    ),
]


@pytest.mark.asyncio
@pytest.mark.parametrize("label,headers,expected", _MIDDLEWARE_UNIT_CASES)
async def test_middleware_unit(label, headers, expected):
    # Each case runs the middleware directly against a synthetic POST scope;
    # ``label`` only describes the case and is not used in the assertion.
    outcome = await _run_middleware(_http_scope(headers))
    assert outcome == expected
|
|
|
|
|
|
def test_legacy_csrftoken_scope_value_nonempty(app_client):
    # GET /post/ invokes request.scope["csrftoken"](); it must return 200
    # with a non-empty body of at least 20 characters.
    resp = app_client.get("/post/")
    assert resp.status == 200
    token_text = resp.text.strip()
    assert token_text != ""
    assert len(token_text) >= 20
|
|
|
|
|
|
def test_legacy_csrftoken_no_ds_csrftoken_cookie(app_client):
    # Fetching /post/ must not set a ds_csrftoken cookie.
    resp = app_client.get("/post/")
    cookies = resp.cookies
    assert "ds_csrftoken" not in cookies
|
|
|
|
|
|
def test_legacy_csrftoken_varies_across_requests(app_client):
    # Two separate GETs of /post/ must yield different bodies.
    first_body = app_client.get("/post/").text.strip()
    second_body = app_client.get("/post/").text.strip()
    assert first_body != second_body
|
|
|
|
|
|
def test_legacy_csrftoken_stable_within_request():
    # Within a single scope, calling scope["csrftoken"] twice must give
    # equal values.
    request_scope = {}
    _install_legacy_csrftoken(request_scope)
    first_value = request_scope["csrftoken"]()
    second_value = request_scope["csrftoken"]()
    assert first_value == second_value
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_cross_site_post_blocked_even_with_ds_csrftoken_cookie(bare_ds):
    # Supplying a leftover ds_csrftoken cookie together with a matching
    # csrftoken form field must NOT get a cross-site POST past the
    # header-based check.
    form = {"message": "hi", "message_class": "info", "csrftoken": "abc"}
    request_kwargs = {
        "data": form,
        "headers": {"sec-fetch-site": "cross-site"},
        "cookies": {"ds_csrftoken": "abc"},
    }
    resp = await _post(bare_ds, **request_kwargs)
    assert resp.status_code == 403
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_bearer_invalid_token_not_csrf_error(bare_ds):
    # A cross-site POST with a bogus bearer token has to clear the CSRF
    # middleware; if it is refused, that refusal must come from
    # auth/permission handling rather than from the CSRF check.
    bearer_headers = {
        "sec-fetch-site": "cross-site",
        "authorization": "Bearer totally-invalid-token",
    }
    resp = await _post(bare_ds, headers=bearer_headers)
    if resp.status_code != 403:
        return
    # A 403 is acceptable only when its body is not the CSRF error page.
    assert "origin" not in resp.text.lower()
    assert "sec-fetch-site" not in resp.text.lower()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_cross_site_post_without_auth_still_blocked(bare_ds):
    # With no Authorization header, a cross-site POST is rejected with 403.
    cross_site = {"sec-fetch-site": "cross-site"}
    resp = await _post(bare_ds, data={"message": "hi"}, headers=cross_site)
    assert resp.status_code == 403
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_bearer_with_cookie_does_not_bypass():
    # When a Cookie header accompanies the Bearer token, ambient cookie
    # auth is in play, so the request is not exempt from the check.
    request_headers = {
        "sec-fetch-site": "cross-site",
        "host": "example.com",
        "authorization": "Bearer dstok_abc",
        "cookie": "ds_actor=anything",
    }
    outcome = await _run_middleware(_http_scope(request_headers))
    assert outcome == ("blocked", 403)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_origin_scheme_must_match():
    # Same host but an http Origin on an https request: must be blocked.
    base_scope = _http_scope({"origin": "http://example.com", "host": "example.com"})
    https_scope = {**base_scope, "scheme": "https"}
    outcome = await _run_middleware(https_scope)
    assert outcome == ("blocked", 403)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_origin_port_must_match():
    # Origin carries port 8001 while Host has no port: must be blocked.
    base_scope = _http_scope(
        {"origin": "http://example.com:8001", "host": "example.com"}
    )
    http_scope = {**base_scope, "scheme": "http"}
    outcome = await _run_middleware(http_scope)
    assert outcome == ("blocked", 403)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_origin_default_port_normalized():
    # An explicit :80 on an http Origin is treated the same as no port,
    # i.e. http://example.com:80 == http://example.com
    base_scope = _http_scope({"origin": "http://example.com:80", "host": "example.com"})
    http_scope = {**base_scope, "scheme": "http"}
    outcome = await _run_middleware(http_scope)
    assert outcome == ("allowed",)
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "headers",
    [
        {"authorization": " ", "host": "example.com", "origin": "http://evil"},
        {"origin": "http://example.com:notaport", "host": "example.com"},
        {"origin": "not-a-url", "host": "example.com"},
    ],
)
async def test_malformed_headers_do_not_500(headers):
    # Garbage header values must end in an orderly 403 rather than an
    # unhandled exception escaping the middleware.
    outcome = await _run_middleware(_http_scope(headers))
    assert outcome[0] == "blocked"
    assert outcome[1] == 403
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_uppercase_header_names_normalized():
    # ASGI servers are expected to lowercase header names, but the
    # middleware should cope with mixed-case names anyway.
    mixed_case_headers = [(b"Sec-Fetch-Site", b"same-origin")]
    scope = {"type": "http", "method": "POST", "headers": mixed_case_headers}
    outcome = await _run_middleware(scope)
    assert outcome == ("allowed",)
|
|
|
|
|
|
def test_legacy_skip_csrf_hookimpl_does_not_break_loading():
    # A plugin that still implements skip_csrf has to register without
    # error even though core no longer consults that hook (pluggy ignores
    # hook implementations it has no spec for). A disposable PluginManager
    # is used so the registration does not leak a _HookCaller onto the
    # real datasette.pm.
    class LegacyPlugin:
        __name__ = "legacy-skip-csrf-plugin"

        @hookimpl
        def skip_csrf(self, datasette, scope):
            return True

    scratch_pm = pluggy.PluginManager("datasette")
    legacy_plugin = LegacyPlugin()
    scratch_pm.register(legacy_plugin, name=LegacyPlugin.__name__)
    assert scratch_pm.is_registered(legacy_plugin)
|