datasette/datasette/utils/asgi.py
Simon Willison 40a37307de
Add request.form() for multipart form data and file uploads
* Add request.form() for multipart form data and file uploads

New Request.form() method that handles both application/x-www-form-urlencoded
and multipart/form-data content types with streaming parsing.

Features:
- Streaming multipart parser that doesn't buffer the entire body in memory
- Files spill to disk above a 1MB threshold via SpooledTemporaryFile
- files=False (default) discards uploaded file content; files=True stores the files
- Security limits: max_request_size, max_file_size, max_fields, max_files
- FormData container with dict-like access and getlist() for multiple values
- UploadedFile class with async read(), seek(), filename, content_type, size
- Support for RFC 5987 filename* encoding for international filenames

Uses multipart-form-data-conformance test suite for validation.

* Update views to use request.form() and document new API

- Migrate PermissionsDebugView, MessagesDebugView, and CreateTokenView
  from post_vars() to form()
- Add documentation for request.form(), FormData, and UploadedFile classes

Centralize multipart defaults and expose stricter limits via Request.form().

Enforce header, part, file, and disk space limits even when files are discarded; detect truncated bodies and client disconnects; and move blocking work off the event loop.

Add FormData close/aclose context managers, update internals docs, and expand multipart tests (including len semantics and stricter conformance expectations).
2026-01-28 18:41:03 -08:00

565 lines
18 KiB
Python

import json
from typing import Optional
from datasette.utils import MultiParams, calculate_etag
from datasette.utils.multipart import (
parse_form_data,
MultipartParseError,
FormData,
DEFAULT_MAX_FILE_SIZE,
DEFAULT_MAX_REQUEST_SIZE,
DEFAULT_MAX_FIELDS,
DEFAULT_MAX_FILES,
DEFAULT_MAX_PARTS,
DEFAULT_MAX_FIELD_SIZE,
DEFAULT_MAX_MEMORY_FILE_SIZE,
DEFAULT_MAX_PART_HEADER_BYTES,
DEFAULT_MAX_PART_HEADER_LINES,
DEFAULT_MIN_FREE_DISK_BYTES,
)
from mimetypes import guess_type
from urllib.parse import parse_qs, urlunparse, parse_qsl
from pathlib import Path
from http.cookies import SimpleCookie, Morsel
import aiofiles
import aiofiles.os
import re
# Python releases before 3.8 do not list "samesite" among the reserved cookie
# attributes, so Morsel would reject the "samesite" key that
# Response.set_cookie() assigns. Registering it here works around that.
# Thanks, Starlette:
# https://github.com/encode/starlette/blob/519f575/starlette/responses.py#L17
Morsel._reserved.update({"samesite": "SameSite"})
class Base400(Exception):
    """Base exception carrying an HTTP ``status`` code (400 unless overridden)."""

    status = 400
class NotFound(Base400):
    """Base400 subclass whose ``status`` is 404 (not found)."""

    status = 404
class DatabaseNotFound(NotFound):
    """NotFound raised with the name of the missing database attached."""

    def __init__(self, database_name):
        super().__init__("Database not found")
        self.database_name = database_name
class TableNotFound(NotFound):
    """NotFound raised with the database name and missing table attached."""

    def __init__(self, database_name, table):
        self.database_name = database_name
        self.table = table
        super().__init__("Table not found")
class RowNotFound(NotFound):
    """NotFound raised with the database, table and primary key values attached.

    The table is stored as ``table_name`` (not ``table``).
    """

    def __init__(self, database_name, table, pk_values):
        self.database_name, self.table_name, self.pk_values = (
            database_name,
            table,
            pk_values,
        )
        super().__init__("Row not found")
class Forbidden(Base400):
    """Base400 subclass whose ``status`` is 403 (forbidden)."""

    status = 403
class BadRequest(Base400):
    """Base400 subclass with an explicit ``status`` of 400 (bad request)."""

    status = 400
# Values accepted for the ``samesite`` argument of Response.set_cookie().
SAMESITE_VALUES = (
    "strict",
    "lax",
    "none",
)
class Request:
    """Wrapper around an ASGI HTTP ``scope`` and its ``receive`` callable.

    Properties derive the method, URL parts, headers, cookies, query string
    arguments and actor from ``scope``; the async methods read the request
    body through ``receive``.
    """

    def __init__(self, scope, receive):
        self.scope = scope
        self.receive = receive

    def __repr__(self):
        return '<asgi.Request method="{}" url="{}">'.format(self.method, self.url)

    @property
    def method(self):
        """The HTTP method taken from the scope."""
        return self.scope["method"]

    @property
    def url(self):
        """URL rebuilt from scheme, host, path and query string."""
        return urlunparse(
            (self.scheme, self.host, self.path, None, self.query_string, None)
        )

    @property
    def url_vars(self):
        """Keyword arguments captured by URL routing, or ``{}`` if there are none."""
        return (self.scope.get("url_route") or {}).get("kwargs") or {}

    @property
    def scheme(self):
        """URL scheme from the scope, defaulting to ``"http"``."""
        return self.scope.get("scheme") or "http"

    @property
    def headers(self):
        """Headers as a dict of lower-cased names to latin-1 decoded values.

        When a header name is repeated, the last value wins.
        """
        return {
            k.decode("latin-1").lower(): v.decode("latin-1")
            for k, v in self.scope.get("headers") or []
        }

    @property
    def host(self):
        """The Host header, or ``"localhost"`` when it is missing or empty."""
        return self.headers.get("host") or "localhost"

    @property
    def cookies(self):
        """Cookies parsed from the Cookie header, as a dict of name to value."""
        cookies = SimpleCookie()
        cookies.load(self.headers.get("cookie", ""))
        return {key: value.value for key, value in cookies.items()}

    @property
    def path(self):
        """Request path without any query string.

        Uses ``raw_path`` when present (latin-1 decoded, cut at ``?``);
        otherwise falls back to ``path``, decoding bytes as UTF-8.
        """
        if self.scope.get("raw_path") is not None:
            return self.scope["raw_path"].decode("latin-1").partition("?")[0]
        else:
            path = self.scope["path"]
            if isinstance(path, str):
                return path
            else:
                return path.decode("utf-8")

    @property
    def query_string(self):
        """The raw query string from the scope, decoded as latin-1."""
        return (self.scope.get("query_string") or b"").decode("latin-1")

    @property
    def full_path(self):
        """The path, followed by ``?`` and the query string when one is present."""
        qs = self.query_string
        return "{}{}".format(self.path, ("?" + qs) if qs else "")

    @property
    def args(self):
        """Query string arguments as MultiParams, keeping blank values."""
        return MultiParams(parse_qs(qs=self.query_string, keep_blank_values=True))

    @property
    def actor(self):
        """The ``actor`` stored in the scope, or None."""
        return self.scope.get("actor", None)

    async def post_body(self):
        """Read the whole request body and return it as bytes.

        Consumes messages from ``receive`` until one arrives without
        ``more_body``. An AssertionError is raised if a message whose type
        is not ``http.request`` is received.
        """
        # Collect the chunks and join them once. Repeated ``bytes +=`` would
        # copy the accumulated body for every message (quadratic in size).
        chunks = []
        more_body = True
        while more_body:
            message = await self.receive()
            assert message["type"] == "http.request", message
            chunks.append(message.get("body", b""))
            more_body = message.get("more_body", False)
        return b"".join(chunks)

    async def post_vars(self):
        """Parse a UTF-8 urlencoded body into a dict.

        Blank values are kept. For a repeated key, the last value wins.
        """
        body = await self.post_body()
        return dict(parse_qsl(body.decode("utf-8"), keep_blank_values=True))

    async def form(
        self,
        files: bool = False,
        max_file_size: int = DEFAULT_MAX_FILE_SIZE,
        max_request_size: int = DEFAULT_MAX_REQUEST_SIZE,
        max_fields: int = DEFAULT_MAX_FIELDS,
        max_files: int = DEFAULT_MAX_FILES,
        max_parts: Optional[int] = DEFAULT_MAX_PARTS,
        max_field_size: int = DEFAULT_MAX_FIELD_SIZE,
        max_memory_file_size: int = DEFAULT_MAX_MEMORY_FILE_SIZE,
        max_part_header_bytes: int = DEFAULT_MAX_PART_HEADER_BYTES,
        max_part_header_lines: int = DEFAULT_MAX_PART_HEADER_LINES,
        min_free_disk_bytes: int = DEFAULT_MIN_FREE_DISK_BYTES,
    ) -> FormData:
        """
        Parse form data from the request body.
        Supports both application/x-www-form-urlencoded and multipart/form-data.
        Args:
            files: If True, store file uploads; if False (default), discard them
            max_file_size: Maximum size per file in bytes (default 50MB)
            max_request_size: Maximum total request size in bytes (default 100MB)
            max_fields: Maximum number of form fields (default 1000)
            max_files: Maximum number of file uploads (default 100)
            max_parts: Maximum number of multipart parts (default max_fields + max_files)
            max_field_size: Maximum size of a text field value in bytes (default 100KB)
            max_memory_file_size: Threshold before files spill to disk (default 1MB)
            max_part_header_bytes: Maximum bytes allowed in part headers (default 16KB)
            max_part_header_lines: Maximum header lines per part (default 100)
            min_free_disk_bytes: Minimum free bytes required in temp dir (default 50MB)
        Returns:
            FormData object with dict-like access to fields and files.
            Use form["key"] for first value, form.getlist("key") for all values.
        Raises:
            BadRequest: If content-type is missing, unsupported, or parsing fails
        """
        content_type = self.headers.get("content-type", "")
        if not content_type:
            raise BadRequest(
                "Missing Content-Type header; expected application/x-www-form-urlencoded "
                "or multipart/form-data"
            )
        try:
            return await parse_form_data(
                receive=self.receive,
                content_type=content_type,
                files=files,
                max_file_size=max_file_size,
                max_request_size=max_request_size,
                max_fields=max_fields,
                max_files=max_files,
                max_parts=max_parts,
                max_field_size=max_field_size,
                max_memory_file_size=max_memory_file_size,
                max_part_header_bytes=max_part_header_bytes,
                max_part_header_lines=max_part_header_lines,
                min_free_disk_bytes=min_free_disk_bytes,
            )
        except MultipartParseError as e:
            raise BadRequest(str(e))

    @classmethod
    def fake(cls, path_with_query_string, method="GET", scheme="http", url_vars=None):
        """Useful for constructing Request objects for tests"""
        path, _, query_string = path_with_query_string.partition("?")
        scope = {
            "http_version": "1.1",
            "method": method,
            "path": path,
            "raw_path": path_with_query_string.encode("latin-1"),
            "query_string": query_string.encode("latin-1"),
            "scheme": scheme,
            "type": "http",
        }
        if url_vars:
            scope["url_route"] = {"kwargs": url_vars}
        return cls(scope, None)
class AsgiLifespan:
    """ASGI wrapper that answers ``lifespan`` messages itself.

    ``on_startup`` / ``on_shutdown`` may each be a single async callable or a
    list of them. They are awaited in order when the matching lifespan message
    arrives. Every other scope type is passed straight to ``app``.
    """

    def __init__(self, app, on_startup=None, on_shutdown=None):
        self.app = app
        self.on_startup = self._as_list(on_startup)
        self.on_shutdown = self._as_list(on_shutdown)

    @staticmethod
    def _as_list(callbacks):
        # Falsy -> new empty list; a list is kept as-is; anything else is wrapped.
        if not callbacks:
            return []
        if isinstance(callbacks, list):
            return callbacks
        return [callbacks]

    async def __call__(self, scope, receive, send):
        if scope["type"] != "lifespan":
            await self.app(scope, receive, send)
            return
        while True:
            message = await receive()
            kind = message["type"]
            if kind == "lifespan.startup":
                for callback in self.on_startup:
                    await callback()
                await send({"type": "lifespan.startup.complete"})
            elif kind == "lifespan.shutdown":
                for callback in self.on_shutdown:
                    await callback()
                await send({"type": "lifespan.shutdown.complete"})
                return
class AsgiStream:
    """Streaming response: send the headers, let ``stream_fn`` write, then finish.

    ``stream_fn`` is awaited with an AsgiWriter bound to ``send``. After it
    returns, an empty ``http.response.body`` message is sent.
    """

    def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"):
        self.stream_fn = stream_fn
        self.status = status
        self.headers = headers or {}
        self.content_type = content_type

    async def asgi_send(self, send):
        # self.content_type replaces any Content-Type header, whatever its casing.
        merged = {}
        for name, value in self.headers.items():
            if name.lower() == "content-type":
                continue
            merged[name] = value
        merged["content-type"] = self.content_type
        raw_headers = [
            [name.encode("utf-8"), value.encode("utf-8")]
            for name, value in merged.items()
        ]
        await send(
            {
                "type": "http.response.start",
                "status": self.status,
                "headers": raw_headers,
            }
        )
        await self.stream_fn(AsgiWriter(send))
        await send({"type": "http.response.body", "body": b""})
class AsgiWriter:
    """Writer passed to AsgiStream callbacks; each write() sends one body chunk."""

    def __init__(self, send):
        self.send = send

    async def write(self, chunk):
        """Send ``chunk`` (a str, UTF-8 encoded) as a body message with more_body=True."""
        payload = chunk.encode("utf-8")
        message = {
            "type": "http.response.body",
            "body": payload,
            "more_body": True,
        }
        await self.send(message)
async def asgi_send_json(send, info, status=200, headers=None):
    """Serialize ``info`` with json.dumps and send it as a UTF-8 JSON response."""
    payload = json.dumps(info)
    await asgi_send(
        send,
        payload,
        status=status,
        headers=headers or {},
        content_type="application/json; charset=utf-8",
    )
async def asgi_send_html(send, html, status=200, headers=None):
    """Send ``html`` as a complete ``text/html; charset=utf-8`` response."""
    await asgi_send(
        send,
        html,
        status=status,
        headers=headers or {},
        content_type="text/html; charset=utf-8",
    )
async def asgi_send_redirect(send, location, status=302):
    """Send a redirect to ``location`` with an empty HTML body.

    Args:
        send: ASGI send callable.
        location: Target URL for the Location header.
        status: Redirect status code (302 by default).

    A leading run of slashes/backslashes is collapsed to a single "/", so the
    Location cannot be read as a protocol-relative URL on another host.
    Absolute URLs (e.g. "https://...") are passed through unchanged.
    """
    # Prevent open redirect vulnerability: "//example.com" is a
    # protocol-relative URL. Browsers also treat "\" like "/" and drop
    # tab/CR/LF inside URLs, and HTTP parsers strip leading whitespace from
    # header values. Each of "/\example.com", "/\t/example.com" and
    # " //example.com" would therefore still escape to another origin, so
    # the whole prefix is collapsed to one "/".
    location = re.sub(r"^[\x00-\x20]*[/\\][/\\\t\n\r]*", "/", location)
    await asgi_send(
        send,
        "",
        status=status,
        headers={"Location": location},
        content_type="text/html; charset=utf-8",
    )
async def asgi_send(send, content, status, headers=None, content_type="text/plain"):
    """Send a complete response: start message, then ``content`` UTF-8 encoded as the body."""
    await asgi_start(send, status, headers, content_type)
    payload = content.encode("utf-8")
    await send({"type": "http.response.body", "body": payload})
async def asgi_start(send, status, headers=None, content_type="text/plain"):
    """Send the ``http.response.start`` message.

    Any Content-Type in ``headers`` (any casing) is dropped and replaced by
    ``content_type``, which is appended last. Names and values are encoded as
    latin-1. The caller's dict is not modified.
    """
    merged = {
        name: value
        for name, value in (headers or {}).items()
        if name.lower() != "content-type"
    }
    merged["content-type"] = content_type
    raw_headers = []
    for name, value in merged.items():
        raw_headers.append([name.encode("latin1"), value.encode("latin1")])
    await send(
        {"type": "http.response.start", "status": status, "headers": raw_headers}
    )
async def asgi_send_file(
    send, filepath, filename=None, content_type=None, chunk_size=4096, headers=None
):
    """Stream the file at ``filepath`` as a 200 response.

    Args:
        send: ASGI send callable.
        filepath: Path (or str) of the file to send.
        filename: If given, adds ``Content-Disposition: attachment`` with this name.
        content_type: Content-Type to use. If omitted, it is guessed from the
            path, falling back to ``text/plain``.
        chunk_size: Number of bytes read and sent per body message.
        headers: Extra response headers. The caller's dict is not modified.

    Raises:
        ValueError: If ``chunk_size`` is 0; the read loop would never end.
        Errors from stat/open (e.g. FileNotFoundError) propagate to the caller.
    """
    if chunk_size == 0:
        raise ValueError("chunk_size must be non-zero")
    # Copy so that adding content-disposition/content-length does not mutate
    # the caller's dict (e.g. AsgiFileDownload.headers).
    headers = dict(headers or {})
    if filename:
        headers["content-disposition"] = f'attachment; filename="{filename}"'
    headers["content-length"] = str((await aiofiles.os.stat(str(filepath))).st_size)
    async with aiofiles.open(str(filepath), mode="rb") as fp:
        await asgi_start(
            send,
            200,
            headers,
            content_type or guess_type(str(filepath))[0] or "text/plain",
        )
        more_body = True
        while more_body:
            chunk = await fp.read(chunk_size)
            # A short read is treated as end of file. A file whose size is an
            # exact multiple of chunk_size ends with one empty final message.
            more_body = len(chunk) == chunk_size
            await send(
                {"type": "http.response.body", "body": chunk, "more_body": more_body}
            )
def asgi_static(root_path, chunk_size=4096, headers=None, content_type=None):
    """Build an async handler ``inner_static(request, send)`` serving files under ``root_path``.

    The file is named by the ``path`` URL route kwarg. Responses are as follows:
    - A path resolving outside ``root_path`` gets a 404.
    - A directory gets a 403.
    - A missing file gets a 404.
    - A served file carries an ETag. If the If-None-Match header exactly
      equals that ETag, a 304 is sent instead.

    Args:
        root_path: Directory to serve files from.
        chunk_size: Passed to calculate_etag and asgi_send_file.
        headers: Extra headers added to every response (copied per request).
        content_type: Content-Type for served files. If None, it is guessed
            per file by asgi_send_file.
    """
    root_path = Path(root_path)
    static_headers = headers.copy() if headers else {}

    async def inner_static(request, send):
        path = request.scope["url_route"]["kwargs"]["path"]
        headers = static_headers.copy()
        try:
            full_path = (root_path / path).resolve().absolute()
        except FileNotFoundError:
            await asgi_send_html(send, "404: Directory not found", 404)
            return
        # Ensure full_path is within root_path to avoid weird "../" tricks.
        # Run this before the directory check, so a path outside the root
        # always gets the same 404. Otherwise the 403 would reveal which
        # outside directories exist.
        try:
            full_path.relative_to(root_path.resolve())
        except ValueError:
            await asgi_send_html(send, "404: Path not inside root path", 404)
            return
        if full_path.is_dir():
            await asgi_send_html(send, "403: Directory listing is not allowed", 403)
            return
        try:
            # Calculate ETag for filepath
            etag = await calculate_etag(full_path, chunk_size=chunk_size)
            headers["ETag"] = etag
            if_none_match = request.headers.get("if-none-match")
            if if_none_match and if_none_match == etag:
                # A 304 must repeat the ETag and caching headers that a 200
                # would have carried (RFC 9110 section 15.4.5).
                return await asgi_send(send, "", 304, headers=headers)
            await asgi_send_file(
                send,
                full_path,
                content_type=content_type,
                chunk_size=chunk_size,
                headers=headers,
            )
        except FileNotFoundError:
            await asgi_send_html(send, "404: File not found", 404)
            return

    return inner_static
class Response:
    """A complete (non-streaming) HTTP response that can be sent over ASGI.

    Args:
        body: str or bytes body. str is UTF-8 encoded; None sends an empty body.
        status: HTTP status code.
        headers: Extra response headers. Any Content-Type entry is replaced
            by ``content_type``.
        content_type: Value sent as the Content-Type header.
    """

    def __init__(self, body=None, status=200, headers=None, content_type="text/plain"):
        self.body = body
        self.status = status
        self.headers = headers or {}
        self._set_cookie_headers = []
        self.content_type = content_type

    async def asgi_send(self, send):
        """Send the start message (headers and cookies), then the whole body."""
        # Drop any Content-Type in self.headers regardless of casing, so only
        # self.content_type is emitted (matching asgi_start and AsgiStream).
        # Otherwise a key like "Content-Type" would produce a duplicate header.
        headers = {
            key: value
            for key, value in self.headers.items()
            if key.lower() != "content-type"
        }
        headers["content-type"] = self.content_type
        raw_headers = [
            [key.encode("utf-8"), value.encode("utf-8")]
            for key, value in headers.items()
        ]
        for set_cookie in self._set_cookie_headers:
            raw_headers.append([b"set-cookie", set_cookie.encode("utf-8")])
        await send(
            {
                "type": "http.response.start",
                "status": self.status,
                "headers": raw_headers,
            }
        )
        body = self.body
        if body is None:
            # body defaults to None in __init__; send an empty body instead
            # of failing on None.encode().
            body = b""
        elif not isinstance(body, bytes):
            body = body.encode("utf-8")
        await send({"type": "http.response.body", "body": body})

    def set_cookie(
        self,
        key,
        value="",
        max_age=None,
        expires=None,
        path="/",
        domain=None,
        secure=False,
        httponly=False,
        samesite="lax",
    ):
        """Queue a Set-Cookie header to be sent with this response.

        Attributes passed as None are omitted. ``secure`` and ``httponly`` are
        added only when truthy. ``samesite`` must be one of SAMESITE_VALUES
        (checked with an assert).
        """
        assert samesite in SAMESITE_VALUES, "samesite should be one of {}".format(
            SAMESITE_VALUES
        )
        cookie = SimpleCookie()
        cookie[key] = value
        for prop_name, prop_value in (
            ("max_age", max_age),
            ("expires", expires),
            ("path", path),
            ("domain", domain),
            ("samesite", samesite),
        ):
            if prop_value is not None:
                cookie[key][prop_name.replace("_", "-")] = prop_value
        for prop_name, prop_value in (("secure", secure), ("httponly", httponly)):
            if prop_value:
                cookie[key][prop_name] = True
        self._set_cookie_headers.append(cookie.output(header="").strip())

    @classmethod
    def html(cls, body, status=200, headers=None):
        """Response with ``text/html; charset=utf-8`` content type."""
        return cls(
            body,
            status=status,
            headers=headers,
            content_type="text/html; charset=utf-8",
        )

    @classmethod
    def text(cls, body, status=200, headers=None):
        """Response with ``str(body)`` as a ``text/plain; charset=utf-8`` body."""
        return cls(
            str(body),
            status=status,
            headers=headers,
            content_type="text/plain; charset=utf-8",
        )

    @classmethod
    def json(cls, body, status=200, headers=None, default=None):
        """Response whose body is ``json.dumps(body, default=default)``."""
        return cls(
            json.dumps(body, default=default),
            status=status,
            headers=headers,
            content_type="application/json; charset=utf-8",
        )

    @classmethod
    def redirect(cls, path, status=302, headers=None):
        """Empty-bodied response with a Location header pointing at ``path``."""
        # Copy so adding Location does not mutate the caller's dict.
        headers = dict(headers or {})
        headers["Location"] = path
        return cls("", status=status, headers=headers)
class AsgiFileDownload:
    """Response object that sends a file from disk via asgi_send_file.

    When ``filename`` is set, asgi_send_file adds an attachment
    Content-Disposition header using that name.
    """

    def __init__(
        self,
        filepath,
        filename=None,
        content_type="application/octet-stream",
        headers=None,
    ):
        self.filepath = filepath
        self.filename = filename
        self.content_type = content_type
        self.headers = headers or {}

    async def asgi_send(self, send):
        """Delegate to asgi_send_file with this download's settings."""
        options = dict(
            filename=self.filename,
            content_type=self.content_type,
            headers=self.headers,
        )
        return await asgi_send_file(send, self.filepath, **options)
class AsgiRunOnFirstRequest:
    """ASGI wrapper that awaits ``on_startup`` hooks once, on the first call.

    NOTE(review): the started flag is set before the hooks are awaited. A
    request arriving while the hooks are still running is therefore passed
    straight through, without waiting for them to finish.
    """

    def __init__(self, asgi, on_startup):
        assert isinstance(on_startup, list)
        self.asgi = asgi
        self.on_startup = on_startup
        self._started = False

    async def __call__(self, scope, receive, send):
        if self._started:
            return await self.asgi(scope, receive, send)
        # Mark as started first so the hooks run at most once, even if one raises.
        self._started = True
        for startup_hook in self.on_startup:
            await startup_hook()
        return await self.asgi(scope, receive, send)