mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
Respect --cors for error pages, closes #453
This commit is contained in:
parent
9683aeb239
commit
45c83b4c35
2 changed files with 21 additions and 2 deletions
|
|
@ -905,11 +905,14 @@ class Datasette:
|
||||||
{"ok": False, "error": message, "status": status, "title": title}
|
{"ok": False, "error": message, "status": status, "title": title}
|
||||||
)
|
)
|
||||||
if request is not None and request.path.split("?")[0].endswith(".json"):
|
if request is not None and request.path.split("?")[0].endswith(".json"):
|
||||||
return response.json(info, status=status)
|
r = response.json(info, status=status)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
template = self.jinja_env.select_template(templates)
|
template = self.jinja_env.select_template(templates)
|
||||||
return response.html(template.render(info), status=status)
|
r = response.html(template.render(info), status=status)
|
||||||
|
if self.cors:
|
||||||
|
r.headers["Access-Control-Allow-Origin"] = "*"
|
||||||
|
return r
|
||||||
|
|
||||||
# First time server starts up, calculate table counts for immutable databases
|
# First time server starts up, calculate table counts for immutable databases
|
||||||
@app.listener("before_server_start")
|
@app.listener("before_server_start")
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from .fixtures import ( # noqa
|
||||||
app_client_shorter_time_limit,
|
app_client_shorter_time_limit,
|
||||||
app_client_larger_cache_size,
|
app_client_larger_cache_size,
|
||||||
app_client_returned_rows_matches_page_size,
|
app_client_returned_rows_matches_page_size,
|
||||||
|
app_client_with_cors,
|
||||||
app_client_with_dot,
|
app_client_with_dot,
|
||||||
generate_compound_rows,
|
generate_compound_rows,
|
||||||
generate_sortable_rows,
|
generate_sortable_rows,
|
||||||
|
|
@ -1474,3 +1475,18 @@ def test_trace(app_client):
|
||||||
assert isinstance(traces["num_traces"], int)
|
assert isinstance(traces["num_traces"], int)
|
||||||
assert isinstance(traces["traces"], dict)
|
assert isinstance(traces["traces"], dict)
|
||||||
assert len(traces["traces"]["queries"]) == traces["num_traces"]
|
assert len(traces["traces"]["queries"]) == traces["num_traces"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"path,status_code",
|
||||||
|
[
|
||||||
|
("/fixtures.json", 200),
|
||||||
|
("/fixtures/no_primary_key.json", 200),
|
||||||
|
# A 400 invalid SQL query should still have the header:
|
||||||
|
("/fixtures.json?sql=select+blah", 400),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_cors(app_client_with_cors, path, status_code):
|
||||||
|
response = app_client_with_cors.get(path)
|
||||||
|
assert response.status == status_code
|
||||||
|
assert "*" == response.headers["Access-Control-Allow-Origin"]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue