Normalize headers in CSRF checks, refs #2689

This commit is contained in:
Simon Willison 2026-04-14 19:24:31 -07:00
commit a973e3ffa1
2 changed files with 110 additions and 11 deletions

View file

@ -16,6 +16,38 @@ from .utils.asgi import asgi_send
# HTTP methods conventionally regarded as "safe" (read-only).
# NOTE(review): presumably these skip the cross-origin check entirely; the
# code that consults this set is not visible here - confirm.
SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})
# Default port per URL scheme; _origin_tuple() falls back to this mapping
# when a parsed URL carries no explicit port.
DEFAULT_PORTS = {"http": 80, "https": 443, "ws": 80, "wss": 443}
def _normalize_headers(raw_headers):
"""Lowercase header names; for duplicates, last value wins."""
result = {}
for name, value in raw_headers:
if isinstance(name, str):
name = name.encode("latin-1")
if isinstance(value, str):
value = value.encode("latin-1")
result[name.lower()] = value
return result
def _origin_tuple(value):
"""
Parse an origin-like string into ``(scheme, host, port)`` with default
ports filled in. Raises ``ValueError`` for malformed input.
"""
parsed = urllib.parse.urlsplit(value)
scheme = (parsed.scheme or "").lower()
host = (parsed.hostname or "").lower()
if not scheme or not host:
raise ValueError("missing scheme or host in {!r}".format(value))
port = parsed.port # may raise ValueError on bad ports
if port is None:
port = DEFAULT_PORTS.get(scheme)
if port is None:
raise ValueError("unknown default port for scheme {!r}".format(scheme))
return scheme, host, port
def _install_legacy_csrftoken(scope):
"""
@ -61,19 +93,19 @@ class CrossOriginProtectionMiddleware:
await self.app(scope, receive, send)
return
headers = dict(scope.get("headers") or [])
headers = _normalize_headers(scope.get("headers") or [])
authorization = headers.get(b"authorization", b"").decode("latin-1")
cookie_header = headers.get(b"cookie")
# Bearer-token requests are not ambient browser credentials, so they
# are not CSRF-vulnerable. Narrowly exempt them from the header check
# before evaluating Sec-Fetch-Site / Origin. Only "Bearer" is exempt;
# schemes like Basic or Digest can be browser-managed and ambient.
authorization = headers.get(b"authorization", b"").decode("latin-1")
cookie_header = headers.get(b"cookie")
# If the request also carries a Cookie header, ambient cookie auth
# could be in play, so do NOT treat it as exempt.
if authorization and not cookie_header:
scheme = authorization.split(None, 1)[0].lower()
if scheme == "bearer":
parts = authorization.split(None, 1)
if parts and parts[0].lower() == "bearer":
await self.app(scope, receive, send)
return
@ -104,12 +136,20 @@ class CrossOriginProtectionMiddleware:
await self.app(scope, receive, send)
return
# Fallback for older browsers: Origin host must match Host header
parsed = urllib.parse.urlparse(origin)
origin_host = parsed.hostname or ""
if parsed.port:
origin_host = "{}:{}".format(origin_host, parsed.port)
if origin_host == host:
# Fallback for older browsers: Origin must match the request's own
# scheme + host + port. Compare full origin tuples, not host alone.
request_scheme = self._request_scheme(scope)
try:
origin_tuple = _origin_tuple(origin)
expected_tuple = _origin_tuple("{}://{}".format(request_scheme, host))
except ValueError:
await self._forbid(
send,
"Malformed Origin {!r} or Host {!r}".format(origin, host),
)
return
if origin_tuple == expected_tuple:
await self.app(scope, receive, send)
return
@ -118,6 +158,15 @@ class CrossOriginProtectionMiddleware:
"Origin {!r} does not match Host {!r}".format(origin, host),
)
def _request_scheme(self, scope):
if self.datasette is not None:
try:
if self.datasette.setting("force_https_urls"):
return "https"
except Exception:
pass
return scope.get("scheme") or "http"
async def _forbid(self, send, reason):
await asgi_send(
send,

View file

@ -266,6 +266,56 @@ async def test_bearer_with_cookie_does_not_bypass():
assert await _run_middleware(scope) == ("blocked", 403)
@pytest.mark.asyncio
async def test_origin_scheme_must_match():
    """An http Origin is rejected on an https request, even when the host matches."""
    headers = {"origin": "http://example.com", "host": "example.com"}
    https_scope = _http_scope(headers)
    https_scope["scheme"] = "https"
    outcome = await _run_middleware(https_scope)
    assert outcome == ("blocked", 403)
@pytest.mark.asyncio
async def test_origin_port_must_match():
    """An Origin on port 8001 is rejected when the Host header names no port."""
    headers = {"origin": "http://example.com:8001", "host": "example.com"}
    http_scope = _http_scope(headers)
    http_scope["scheme"] = "http"
    outcome = await _run_middleware(http_scope)
    assert outcome == ("blocked", 403)
@pytest.mark.asyncio
async def test_origin_default_port_normalized():
    """An explicit default port is equivalent to none: http://example.com:80 == http://example.com."""
    headers = {"origin": "http://example.com:80", "host": "example.com"}
    http_scope = _http_scope(headers)
    http_scope["scheme"] = "http"
    outcome = await _run_middleware(http_scope)
    assert outcome == ("allowed",)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "headers",
    [
        # Whitespace-only Authorization value: there is no scheme token to read.
        {"authorization": " ", "host": "example.com", "origin": "http://evil"},
        # Origin whose port is not a number.
        {"origin": "http://example.com:notaport", "host": "example.com"},
        # Origin with neither scheme nor host.
        {"origin": "not-a-url", "host": "example.com"},
    ],
)
async def test_malformed_headers_do_not_500(headers):
    """Malformed header values produce a clean 403 rather than an unhandled exception."""
    scope = _http_scope(headers)
    outcome = await _run_middleware(scope)
    assert outcome[0] == "blocked"
    assert outcome[1] == 403
@pytest.mark.asyncio
async def test_uppercase_header_names_normalized():
    """Header names that are not lowercase are still recognised by the middleware."""
    # ASGI servers should already lowercase header names; the middleware
    # normalizes them again defensively.
    mixed_case_headers = [(b"Sec-Fetch-Site", b"same-origin")]
    scope = {"type": "http", "method": "POST", "headers": mixed_case_headers}
    outcome = await _run_middleware(scope)
    assert outcome == ("allowed",)
def test_legacy_skip_csrf_hookimpl_does_not_break_loading():
# Plugins that still define skip_csrf must load cleanly - pluggy ignores
# unknown hook implementations - even though the hook is no longer