diff --git a/datasette/app.py b/datasette/app.py index 1f69c2b3..52c5e629 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -46,6 +46,7 @@ from .database import Database, QueryInterrupted from .utils import ( PrefixedUrlString, StartupError, + add_cors_headers, async_call_with_supported_arguments, await_me_maybe, call_with_supported_arguments, @@ -1321,7 +1322,7 @@ class DatasetteRouter: ) headers = {} if self.ds.cors: - headers["Access-Control-Allow-Origin"] = "*" + add_cors_headers(headers) if request.path.split("?")[0].endswith(".json"): await asgi_send_json(send, info, status=status, headers=headers) else: diff --git a/datasette/utils/__init__.py b/datasette/utils/__init__.py index 70ac8976..c339113c 100644 --- a/datasette/utils/__init__.py +++ b/datasette/utils/__init__.py @@ -1089,3 +1089,8 @@ async def derive_named_parameters(db, sql): return [row["p4"].lstrip(":") for row in results if row["opcode"] == "Variable"] except sqlite3.DatabaseError: return possible_params + + +def add_cors_headers(headers): + headers["Access-Control-Allow-Origin"] = "*" + headers["Access-Control-Allow-Headers"] = "Authorization" diff --git a/datasette/views/base.py b/datasette/views/base.py index 3333781c..01e90220 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -11,6 +11,7 @@ import pint from datasette import __version__ from datasette.database import QueryInterrupted from datasette.utils import ( + add_cors_headers, await_me_maybe, EscapeHtmlWriter, InvalidSql, @@ -163,7 +164,7 @@ class DataView(BaseView): async def options(self, request, *args, **kwargs): r = Response.text("ok") if self.ds.cors: - r.headers["Access-Control-Allow-Origin"] = "*" + add_cors_headers(r.headers) return r def redirect(self, request, path, forward_querystring=True, remove_args=None): @@ -174,7 +175,7 @@ class DataView(BaseView): r = Response.redirect(path) r.headers["Link"] = f"<{path}>; rel=preload" if self.ds.cors: - r.headers["Access-Control-Allow-Origin"] = "*" + add_cors_headers(r.headers) return r async def data(self, request, database, hash, **kwargs): @@ -417,7 +418,7 @@ class DataView(BaseView): headers = {} if self.ds.cors: - headers["Access-Control-Allow-Origin"] = "*" + add_cors_headers(headers) if request.args.get("_dl", None): if not trace: content_type = "text/csv; charset=utf-8" @@ -643,5 +644,5 @@ class DataView(BaseView): response.headers["Cache-Control"] = ttl_header response.headers["Referrer-Policy"] = "no-referrer" if self.ds.cors: - response.headers["Access-Control-Allow-Origin"] = "*" + add_cors_headers(response.headers) return response diff --git a/datasette/views/database.py b/datasette/views/database.py index e3070ce6..affded9b 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -8,6 +8,7 @@ from urllib.parse import parse_qsl, urlencode import markupsafe from datasette.utils import ( + add_cors_headers, await_me_maybe, check_visibility, derive_named_parameters, @@ -176,7 +177,7 @@ class DatabaseDownload(DataView): filepath = db.path headers = {} if self.ds.cors: - headers["Access-Control-Allow-Origin"] = "*" + add_cors_headers(headers) headers["Transfer-Encoding"] = "chunked" return AsgiFileDownload( filepath, diff --git a/datasette/views/index.py b/datasette/views/index.py index e37643f9..18454759 100644 --- a/datasette/views/index.py +++ b/datasette/views/index.py @@ -1,7 +1,7 @@ import hashlib import json -from datasette.utils import check_visibility, CustomJSONEncoder +from datasette.utils import add_cors_headers, check_visibility, CustomJSONEncoder from datasette.utils.asgi import Response from datasette.version import __version__ @@ -129,7 +129,7 @@ class IndexView(BaseView): if as_format: headers = {} if self.ds.cors: - headers["Access-Control-Allow-Origin"] = "*" + add_cors_headers(headers) return Response( json.dumps({db["name"]: db for db in databases}, cls=CustomJSONEncoder), content_type="application/json; charset=utf-8", diff --git a/datasette/views/special.py b/datasette/views/special.py index 9750dd06..3cb626a5 100644 --- a/datasette/views/special.py +++ b/datasette/views/special.py @@ -1,6 +1,6 @@ import json from datasette.utils.asgi import Response, Forbidden -from datasette.utils import actor_matches_allow +from datasette.utils import actor_matches_allow, add_cors_headers from .base import BaseView import secrets @@ -23,7 +23,7 @@ class JsonDataView(BaseView): if as_format: headers = {} if self.ds.cors: - headers["Access-Control-Allow-Origin"] = "*" + add_cors_headers(headers) return Response( json.dumps(data), content_type="application/json; charset=utf-8", diff --git a/tests/test_api.py b/tests/test_api.py index 38d1ba08..311ae464 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1955,7 +1955,8 @@ def test_trace(trace_debug): def test_cors(app_client_with_cors, path, status_code): response = app_client_with_cors.get(path) assert response.status == status_code - assert "*" == response.headers["Access-Control-Allow-Origin"] + assert response.headers["Access-Control-Allow-Origin"] == "*" + assert response.headers["Access-Control-Allow-Headers"] == "Authorization" @pytest.mark.parametrize(