mirror of
https://github.com/simonw/datasette.git
synced 2026-05-27 12:34:37 +02:00
Normalize headers in CSRF checks, refs #2689
This commit is contained in:
parent
028cc2446f
commit
a973e3ffa1
2 changed files with 110 additions and 11 deletions
|
|
@ -16,6 +16,38 @@ from .utils.asgi import asgi_send
|
|||
|
||||
# HTTP methods conventionally treated as "safe" (read-only).
# NOTE(review): presumably requests using these methods skip the cross-origin
# check -- the code that consults this set is outside this view; confirm.
SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})
|
||||
|
||||
# Default port for each URL scheme. _origin_tuple() looks the scheme up here
# to fill in the port when an origin-like string does not specify one.
DEFAULT_PORTS = {"http": 80, "https": 443, "ws": 80, "wss": 443}
|
||||
|
||||
|
||||
def _normalize_headers(raw_headers):
|
||||
"""Lowercase header names; for duplicates, last value wins."""
|
||||
result = {}
|
||||
for name, value in raw_headers:
|
||||
if isinstance(name, str):
|
||||
name = name.encode("latin-1")
|
||||
if isinstance(value, str):
|
||||
value = value.encode("latin-1")
|
||||
result[name.lower()] = value
|
||||
return result
|
||||
|
||||
|
||||
def _origin_tuple(value):
|
||||
"""
|
||||
Parse an origin-like string into ``(scheme, host, port)`` with default
|
||||
ports filled in. Raises ``ValueError`` for malformed input.
|
||||
"""
|
||||
parsed = urllib.parse.urlsplit(value)
|
||||
scheme = (parsed.scheme or "").lower()
|
||||
host = (parsed.hostname or "").lower()
|
||||
if not scheme or not host:
|
||||
raise ValueError("missing scheme or host in {!r}".format(value))
|
||||
port = parsed.port # may raise ValueError on bad ports
|
||||
if port is None:
|
||||
port = DEFAULT_PORTS.get(scheme)
|
||||
if port is None:
|
||||
raise ValueError("unknown default port for scheme {!r}".format(scheme))
|
||||
return scheme, host, port
|
||||
|
||||
|
||||
def _install_legacy_csrftoken(scope):
|
||||
"""
|
||||
|
|
@ -61,19 +93,19 @@ class CrossOriginProtectionMiddleware:
|
|||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
headers = dict(scope.get("headers") or [])
|
||||
headers = _normalize_headers(scope.get("headers") or [])
|
||||
|
||||
authorization = headers.get(b"authorization", b"").decode("latin-1")
|
||||
cookie_header = headers.get(b"cookie")
|
||||
# Bearer-token requests are not ambient browser credentials, so they
|
||||
# are not CSRF-vulnerable. Narrowly exempt them from the header check
|
||||
# before evaluating Sec-Fetch-Site / Origin. Only "Bearer" is exempt;
|
||||
# schemes like Basic or Digest can be browser-managed and ambient.
|
||||
authorization = headers.get(b"authorization", b"").decode("latin-1")
|
||||
cookie_header = headers.get(b"cookie")
|
||||
# If the request also carries a Cookie header, ambient cookie auth
|
||||
# could be in play, so do NOT treat it as exempt.
|
||||
if authorization and not cookie_header:
|
||||
scheme = authorization.split(None, 1)[0].lower()
|
||||
if scheme == "bearer":
|
||||
parts = authorization.split(None, 1)
|
||||
if parts and parts[0].lower() == "bearer":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
|
|
@ -104,12 +136,20 @@ class CrossOriginProtectionMiddleware:
|
|||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
# Fallback for older browsers: Origin host must match Host header
|
||||
parsed = urllib.parse.urlparse(origin)
|
||||
origin_host = parsed.hostname or ""
|
||||
if parsed.port:
|
||||
origin_host = "{}:{}".format(origin_host, parsed.port)
|
||||
if origin_host == host:
|
||||
# Fallback for older browsers: Origin must match the request's own
|
||||
# scheme + host + port. Compare full origin tuples, not host alone.
|
||||
request_scheme = self._request_scheme(scope)
|
||||
try:
|
||||
origin_tuple = _origin_tuple(origin)
|
||||
expected_tuple = _origin_tuple("{}://{}".format(request_scheme, host))
|
||||
except ValueError:
|
||||
await self._forbid(
|
||||
send,
|
||||
"Malformed Origin {!r} or Host {!r}".format(origin, host),
|
||||
)
|
||||
return
|
||||
|
||||
if origin_tuple == expected_tuple:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
|
|
@ -118,6 +158,15 @@ class CrossOriginProtectionMiddleware:
|
|||
"Origin {!r} does not match Host {!r}".format(origin, host),
|
||||
)
|
||||
|
||||
def _request_scheme(self, scope):
|
||||
if self.datasette is not None:
|
||||
try:
|
||||
if self.datasette.setting("force_https_urls"):
|
||||
return "https"
|
||||
except Exception:
|
||||
pass
|
||||
return scope.get("scheme") or "http"
|
||||
|
||||
async def _forbid(self, send, reason):
|
||||
await asgi_send(
|
||||
send,
|
||||
|
|
|
|||
|
|
@ -266,6 +266,56 @@ async def test_bearer_with_cookie_does_not_bypass():
|
|||
assert await _run_middleware(scope) == ("blocked", 403)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_origin_scheme_must_match():
    """An http Origin sent with an https request is blocked despite equal hosts."""
    request_headers = {"origin": "http://example.com", "host": "example.com"}
    scope = _http_scope(request_headers)
    scope["scheme"] = "https"
    outcome = await _run_middleware(scope)
    assert outcome == ("blocked", 403)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_origin_port_must_match():
    """An Origin with an explicit port 8001 is blocked when Host has no port."""
    request_headers = {"origin": "http://example.com:8001", "host": "example.com"}
    scope = _http_scope(request_headers)
    scope["scheme"] = "http"
    outcome = await _run_middleware(scope)
    assert outcome == ("blocked", 403)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_origin_default_port_normalized():
    """An explicit :80 on an http Origin matches a Host header with no port."""
    # http://example.com:80 == http://example.com
    request_headers = {"origin": "http://example.com:80", "host": "example.com"}
    scope = _http_scope(request_headers)
    scope["scheme"] = "http"
    outcome = await _run_middleware(scope)
    assert outcome == ("allowed",)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "headers",
    [
        # Whitespace-only Authorization value alongside a mismatched Origin.
        {"authorization": " ", "host": "example.com", "origin": "http://evil"},
        # Origin whose port is not a number.
        {"origin": "http://example.com:notaport", "host": "example.com"},
        # Origin with no scheme at all.
        {"origin": "not-a-url", "host": "example.com"},
    ],
)
async def test_malformed_headers_do_not_500(headers):
    """Malformed header values yield a clean 403, not an unhandled exception."""
    outcome = await _run_middleware(_http_scope(headers))
    verdict = outcome[0]
    status_code = outcome[1]
    assert verdict == "blocked"
    assert status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_uppercase_header_names_normalized():
    """A mixed-case Sec-Fetch-Site header name is still recognised."""
    # ASGI servers should lowercase, but middleware normalizes defensively.
    mixed_case_headers = [(b"Sec-Fetch-Site", b"same-origin")]
    scope = {
        "type": "http",
        "method": "POST",
        "headers": mixed_case_headers,
    }
    outcome = await _run_middleware(scope)
    assert outcome == ("allowed",)
|
||||
|
||||
|
||||
def test_legacy_skip_csrf_hookimpl_does_not_break_loading():
|
||||
# Plugins that still define skip_csrf must load cleanly - pluggy ignores
|
||||
# unknown hook implementations - even though the hook is no longer
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue