From a973e3ffa119c2563c7fb7cca0feef4797eb5b8a Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Tue, 14 Apr 2026 19:24:31 -0700
Subject: [PATCH] Normalize headers in CSRF checks, refs #2689

---
Reviewer notes (this text sits between "---" and the diffstat, so it is
not part of the commit message):
- datasette/csrf.py gains DEFAULT_PORTS, _normalize_headers() and
  _origin_tuple(); the middleware builds its header dict through
  _normalize_headers() instead of dict().
- The Origin fallback now compares (scheme, host, port) tuples; a
  ValueError while parsing Origin or Host is routed to self._forbid().
- The Bearer exemption checks that authorization.split() produced a part
  before indexing it.
- NOTE(review): this copy was restored from a whitespace-collapsed paste;
  indentation and the position of blank context lines were inferred from
  the hunk line counts - TODO confirm against the upstream commit before
  applying.

 datasette/csrf.py             | 71 +++++++++++++++++++++++++++++------
 tests/test_csrf_middleware.py | 50 ++++++++++++++++++++++++
 2 files changed, 110 insertions(+), 11 deletions(-)

diff --git a/datasette/csrf.py b/datasette/csrf.py
index 845c8fb4..df239aee 100644
--- a/datasette/csrf.py
+++ b/datasette/csrf.py
@@ -16,6 +16,38 @@ from .utils.asgi import asgi_send
 
 SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})
 
+DEFAULT_PORTS = {"http": 80, "https": 443, "ws": 80, "wss": 443}
+
+
+def _normalize_headers(raw_headers):
+    """Lowercase header names; for duplicates, last value wins."""
+    result = {}
+    for name, value in raw_headers:
+        if isinstance(name, str):
+            name = name.encode("latin-1")
+        if isinstance(value, str):
+            value = value.encode("latin-1")
+        result[name.lower()] = value
+    return result
+
+
+def _origin_tuple(value):
+    """
+    Parse an origin-like string into ``(scheme, host, port)`` with default
+    ports filled in. Raises ``ValueError`` for malformed input.
+    """
+    parsed = urllib.parse.urlsplit(value)
+    scheme = (parsed.scheme or "").lower()
+    host = (parsed.hostname or "").lower()
+    if not scheme or not host:
+        raise ValueError("missing scheme or host in {!r}".format(value))
+    port = parsed.port  # may raise ValueError on bad ports
+    if port is None:
+        port = DEFAULT_PORTS.get(scheme)
+    if port is None:
+        raise ValueError("unknown default port for scheme {!r}".format(scheme))
+    return scheme, host, port
+
 
 def _install_legacy_csrftoken(scope):
     """
@@ -61,19 +93,19 @@ class CrossOriginProtectionMiddleware:
             await self.app(scope, receive, send)
             return
 
-        headers = dict(scope.get("headers") or [])
+        headers = _normalize_headers(scope.get("headers") or [])
+        authorization = headers.get(b"authorization", b"").decode("latin-1")
+        cookie_header = headers.get(b"cookie")
 
         # Bearer-token requests are not ambient browser credentials, so they
         # are not CSRF-vulnerable. Narrowly exempt them from the header check
         # before evaluating Sec-Fetch-Site / Origin. Only "Bearer" is exempt;
         # schemes like Basic or Digest can be browser-managed and ambient.
-        authorization = headers.get(b"authorization", b"").decode("latin-1")
-        cookie_header = headers.get(b"cookie")
         # If the request also carries a Cookie header, ambient cookie auth
         # could be in play, so do NOT treat it as exempt.
         if authorization and not cookie_header:
-            scheme = authorization.split(None, 1)[0].lower()
-            if scheme == "bearer":
+            parts = authorization.split(None, 1)
+            if parts and parts[0].lower() == "bearer":
                 await self.app(scope, receive, send)
                 return
 
@@ -104,12 +136,20 @@ class CrossOriginProtectionMiddleware:
             await self.app(scope, receive, send)
             return
 
-        # Fallback for older browsers: Origin host must match Host header
-        parsed = urllib.parse.urlparse(origin)
-        origin_host = parsed.hostname or ""
-        if parsed.port:
-            origin_host = "{}:{}".format(origin_host, parsed.port)
-        if origin_host == host:
+        # Fallback for older browsers: Origin must match the request's own
+        # scheme + host + port. Compare full origin tuples, not host alone.
+        request_scheme = self._request_scheme(scope)
+        try:
+            origin_tuple = _origin_tuple(origin)
+            expected_tuple = _origin_tuple("{}://{}".format(request_scheme, host))
+        except ValueError:
+            await self._forbid(
+                send,
+                "Malformed Origin {!r} or Host {!r}".format(origin, host),
+            )
+            return
+
+        if origin_tuple == expected_tuple:
             await self.app(scope, receive, send)
             return
 
@@ -118,6 +158,15 @@ class CrossOriginProtectionMiddleware:
             "Origin {!r} does not match Host {!r}".format(origin, host),
         )
 
+    def _request_scheme(self, scope):
+        if self.datasette is not None:
+            try:
+                if self.datasette.setting("force_https_urls"):
+                    return "https"
+            except Exception:
+                pass
+        return scope.get("scheme") or "http"
+
     async def _forbid(self, send, reason):
         await asgi_send(
             send,
diff --git a/tests/test_csrf_middleware.py b/tests/test_csrf_middleware.py
index 820df1e7..2fcfb216 100644
--- a/tests/test_csrf_middleware.py
+++ b/tests/test_csrf_middleware.py
@@ -266,6 +266,56 @@ async def test_bearer_with_cookie_does_not_bypass():
     assert await _run_middleware(scope) == ("blocked", 403)
 
 
+@pytest.mark.asyncio
+async def test_origin_scheme_must_match():
+    # http Origin against an https request must be blocked even when host matches.
+    scope = _http_scope({"origin": "http://example.com", "host": "example.com"})
+    scope["scheme"] = "https"
+    assert await _run_middleware(scope) == ("blocked", 403)
+
+
+@pytest.mark.asyncio
+async def test_origin_port_must_match():
+    scope = _http_scope({"origin": "http://example.com:8001", "host": "example.com"})
+    scope["scheme"] = "http"
+    assert await _run_middleware(scope) == ("blocked", 403)
+
+
+@pytest.mark.asyncio
+async def test_origin_default_port_normalized():
+    # http://example.com:80 == http://example.com
+    scope = _http_scope({"origin": "http://example.com:80", "host": "example.com"})
+    scope["scheme"] = "http"
+    assert await _run_middleware(scope) == ("allowed",)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "headers",
+    [
+        {"authorization": " ", "host": "example.com", "origin": "http://evil"},
+        {"origin": "http://example.com:notaport", "host": "example.com"},
+        {"origin": "not-a-url", "host": "example.com"},
+    ],
+)
+async def test_malformed_headers_do_not_500(headers):
+    # Should be a clean 403, not an unhandled exception.
+    result = await _run_middleware(_http_scope(headers))
+    assert result[0] == "blocked"
+    assert result[1] == 403
+
+
+@pytest.mark.asyncio
+async def test_uppercase_header_names_normalized():
+    # ASGI servers should lowercase, but middleware normalizes defensively.
+    scope = {
+        "type": "http",
+        "method": "POST",
+        "headers": [(b"Sec-Fetch-Site", b"same-origin")],
+    }
+    assert await _run_middleware(scope) == ("allowed",)
+
+
 def test_legacy_skip_csrf_hookimpl_does_not_break_loading():
     # Plugins that still define skip_csrf must load cleanly - pluggy ignores
     # unknown hook implementations - even though the hook is no longer