From e2864fc895603c1fdf03d3c55812a0a6e56779ff Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 26 May 2026 15:21:09 -0700 Subject: [PATCH] test_stored_queries.py --- tests/test_stored_queries.py | 473 +++++++++++++++++++++++++++++++++++ 1 file changed, 473 insertions(+) create mode 100644 tests/test_stored_queries.py diff --git a/tests/test_stored_queries.py b/tests/test_stored_queries.py new file mode 100644 index 00000000..2c648d5f --- /dev/null +++ b/tests/test_stored_queries.py @@ -0,0 +1,473 @@ +from bs4 import BeautifulSoup as Soup +from asgiref.sync import async_to_sync +import json +import pytest +import re +from .fixtures import make_app_client + + +def update_query(client, name, **kwargs): + async_to_sync(client.ds.invoke_startup)() + async_to_sync(client.ds.update_query)("data", name, **kwargs) + + +@pytest.fixture +def stored_write_client(tmpdir): + template_dir = tmpdir / "stored_write_templates" + template_dir.mkdir() + (template_dir / "query-data-update_name.html").write_text( + """ + {% extends "query.html" %} + {% block content %}!!!CUSTOM_UPDATE_NAME_TEMPLATE!!!{{ super() }}{% endblock %} + """, + "utf-8", + ) + with make_app_client( + extra_databases={"data.db": "create table names (name text)"}, + template_dir=str(template_dir), + config={ + "databases": { + "data": { + "queries": { + "stored_read": {"sql": "select * from names"}, + "add_name": { + "sql": "insert into names (name) values (:name)", + "write": True, + "on_success_redirect": "/data/add_name?success", + }, + "add_name_specify_id": { + "sql": "insert into names (rowid, name) values (:rowid, :name)", + "on_success_message_sql": "select 'Name added: ' || :name || ' with rowid ' || :rowid", + "write": True, + "on_error_redirect": "/data/add_name_specify_id?error", + }, + "add_name_specify_id_with_error_in_on_success_message_sql": { + "sql": "insert into names (rowid, name) values (:rowid, :name)", + "on_success_message_sql": "select this is bad SQL", + "write": True, + }, + "delete_name": { + "sql": "delete from names where rowid = :rowid", + "write": True, + "on_success_message": "Name deleted", + "allow": {"id": "root"}, + }, + "update_name": { + "sql": "update names set name = :name where rowid = :rowid", + "params": ["rowid", "name", "extra"], + "write": True, + }, + } + } + } + }, + ) as client: + yield client + + +@pytest.fixture +def stored_write_immutable_client(): + with make_app_client( + is_immutable=True, + config={ + "databases": { + "fixtures": { + "queries": { + "add": { + "sql": "insert into sortable (text) values (:text)", + "write": True, + }, + } + } + } + }, + ) as client: + yield client + + +@pytest.mark.asyncio +async def test_stored_query_with_named_parameter(ds_client): + response = await ds_client.get( + "/fixtures/neighborhood_search.json?text=town&_shape=arrays" + ) + assert response.json()["rows"] == [ + ["Corktown", "Detroit", "MI"], + ["Downtown", "Los Angeles", "CA"], + ["Downtown", "Detroit", "MI"], + ["Greektown", "Detroit", "MI"], + ["Koreatown", "Los Angeles", "CA"], + ["Mexicantown", "Detroit", "MI"], + ] + + +def test_insert(stored_write_client): + response = stored_write_client.post( + "/data/add_name", + {"name": "Hello"}, + csrftoken_from=True, + cookies={"foo": "bar"}, + ) + messages = stored_write_client.ds.unsign( + response.cookies["ds_messages"], "messages" + ) + assert messages == [["Query executed, 1 row affected", 1]] + assert response.status == 302 + assert response.headers["Location"] == "/data/add_name?success" + + +def test_insert_blocked_cross_site(stored_write_client): + # A cross-site POST (browser-originated) must be blocked + response = stored_write_client.post( + "/data/add_name", + {"name": "Hello"}, + headers={"sec-fetch-site": "cross-site"}, + ) + assert 403 == response.status + + +def test_insert_no_cookies_no_csrf(stored_write_client): + response = stored_write_client.post("/data/add_name", {"name": "Hello"}) + assert 302 == response.status + assert "/data/add_name?success" == response.headers["Location"] + + +def test_custom_success_message(stored_write_client): + response = stored_write_client.post( + "/data/delete_name", + {"rowid": 1}, + cookies={"ds_actor": stored_write_client.actor_cookie({"id": "root"})}, + csrftoken_from=True, + ) + assert 302 == response.status + messages = stored_write_client.ds.unsign( + response.cookies["ds_messages"], "messages" + ) + assert [["Name deleted", 1]] == messages + + +def test_insert_error(stored_write_client): + stored_write_client.post("/data/add_name", {"name": "Hello"}, csrftoken_from=True) + response = stored_write_client.post( + "/data/add_name_specify_id", + {"rowid": 1, "name": "Should fail"}, + csrftoken_from=True, + ) + assert 302 == response.status + assert "/data/add_name_specify_id?error" == response.headers["Location"] + messages = stored_write_client.ds.unsign( + response.cookies["ds_messages"], "messages" + ) + assert [["UNIQUE constraint failed: names.rowid", 3]] == messages + # How about with a custom error message? + update_query(stored_write_client, "add_name_specify_id", on_error_message="ERROR") + response = stored_write_client.post( + "/data/add_name_specify_id", + {"rowid": 1, "name": "Should fail"}, + csrftoken_from=True, + ) + assert [["ERROR", 3]] == stored_write_client.ds.unsign( + response.cookies["ds_messages"], "messages" + ) + + +def test_on_success_message_sql(stored_write_client): + response = stored_write_client.post( + "/data/add_name_specify_id", + {"rowid": 5, "name": "Should be OK"}, + csrftoken_from=True, + ) + assert response.status == 302 + assert response.headers["Location"] == "/data/add_name_specify_id" + messages = stored_write_client.ds.unsign( + response.cookies["ds_messages"], "messages" + ) + assert messages == [["Name added: Should be OK with rowid 5", 1]] + + +def test_error_in_on_success_message_sql(stored_write_client): + response = stored_write_client.post( + "/data/add_name_specify_id_with_error_in_on_success_message_sql", + {"rowid": 1, "name": "Should fail"}, + csrftoken_from=True, + ) + messages = stored_write_client.ds.unsign( + response.cookies["ds_messages"], "messages" + ) + assert messages == [ + ["Error running on_success_message_sql: no such column: bad", 3] + ] + + +def test_custom_params(stored_write_client): + response = stored_write_client.get("/data/update_name?extra=foo") + assert ( + '' + in response.text + ) + + +def test_stored_query_pages_no_vary_header(stored_write_client): + # These pages no longer embed per-cookie CSRF tokens, so they must not + # set Vary: Cookie - they should be cacheable across users. + assert "vary" not in stored_write_client.get("/data").headers + assert "vary" not in stored_write_client.get("/data/update_name").headers + + +def test_json_post_body(stored_write_client): + response = stored_write_client.post( + "/data/add_name", + body=json.dumps({"name": ["Hello", "there"]}), + ) + assert 302 == response.status + assert "/data/add_name?success" == response.headers["Location"] + rows = stored_write_client.get("/data/names.json?_shape=array").json + assert rows == [{"rowid": 1, "name": "['Hello', 'there']"}] + + +@pytest.mark.parametrize( + "headers,body,querystring", + ( + (None, "name=NameGoesHere", "?_json=1"), + ({"Accept": "application/json"}, "name=NameGoesHere", None), + (None, "name=NameGoesHere&_json=1", None), + (None, '{"name": "NameGoesHere", "_json": 1}', None), + ), +) +def test_json_response(stored_write_client, headers, body, querystring): + response = stored_write_client.post( + "/data/add_name" + (querystring or ""), + body=body, + headers=headers, + ) + assert 200 == response.status + assert response.headers["content-type"] == "application/json; charset=utf-8" + assert response.json == { + "ok": True, + "message": "Query executed, 1 row affected", + "redirect": "/data/add_name?success", + } + rows = stored_write_client.get("/data/names.json?_shape=array").json + assert rows == [{"rowid": 1, "name": "NameGoesHere"}] + + +def test_stored_query_permissions_on_database_page(stored_write_client): + # Without auth shows the five public queries + anon_response = stored_write_client.get("/data.json") + query_names = {q["name"] for q in anon_response.json["queries"]} + assert query_names == { + "add_name_specify_id_with_error_in_on_success_message_sql", + "update_name", + "add_name_specify_id", + "stored_read", + "add_name", + } + assert anon_response.json["queries_more"] is False + + # With auth the database page preview shows the first five queries + response = stored_write_client.get( + "/data.json", + cookies={"ds_actor": stored_write_client.actor_cookie({"id": "root"})}, + ) + assert response.status == 200 + query_names_and_private = sorted( + [ + {"name": q["name"], "private": q["private"]} + for q in response.json["queries"] + ], + key=lambda q: q["name"], + ) + assert query_names_and_private == [ + {"name": "add_name", "private": False}, + {"name": "add_name_specify_id", "private": False}, + { + "name": "add_name_specify_id_with_error_in_on_success_message_sql", + "private": False, + }, + {"name": "delete_name", "private": True}, + {"name": "stored_read", "private": False}, + ] + assert response.json["queries_more"] is True + + # The full query list endpoint includes the remaining query + response = stored_write_client.get( + "/data/-/queries.json?_size=10", + cookies={"ds_actor": stored_write_client.actor_cookie({"id": "root"})}, + ) + assert response.status == 200 + query_names_and_private = sorted( + [ + {"name": q["name"], "private": q["private"]} + for q in response.json["queries"] + ], + key=lambda q: q["name"], + ) + assert query_names_and_private == [ + {"name": "add_name", "private": False}, + {"name": "add_name_specify_id", "private": False}, + { + "name": "add_name_specify_id_with_error_in_on_success_message_sql", + "private": False, + }, + {"name": "delete_name", "private": True}, + {"name": "stored_read", "private": False}, + {"name": "update_name", "private": False}, + ] + + +def test_stored_query_permissions(stored_write_client): + assert 403 == stored_write_client.get("/data/delete_name").status + assert 200 == stored_write_client.get("/data/update_name").status + cookies = {"ds_actor": stored_write_client.actor_cookie({"id": "root"})} + assert 200 == stored_write_client.get("/data/delete_name", cookies=cookies).status + assert 200 == stored_write_client.get("/data/update_name", cookies=cookies).status + + +@pytest.fixture(scope="session") +def magic_parameters_client(): + with make_app_client( + extra_databases={"data.db": "create table logs (line text)"}, + config={ + "databases": { + "data": { + "queries": { + "runme_post": {"sql": "", "write": True}, + "runme_get": {"sql": ""}, + } + } + } + }, + ) as client: + yield client + + +@pytest.mark.parametrize( + "magic_parameter,expected_re", + [ + ("_actor_id", "root"), + ("_header_host", "localhost"), + ("_header_not_a_thing", ""), + ("_cookie_foo", "bar"), + ("_now_epoch", r"^\d+$"), + ("_now_date_utc", r"^\d{4}-\d{2}-\d{2}$"), + ("_now_datetime_utc", r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z$"), + ("_random_chars_1", r"^\w$"), + ("_random_chars_10", r"^\w{10}$"), + ], +) +def test_magic_parameters(magic_parameters_client, magic_parameter, expected_re): + update_query( + magic_parameters_client, + "runme_post", + sql=f"insert into logs (line) values (:{magic_parameter})", + ) + update_query( + magic_parameters_client, + "runme_get", + sql=f"select :{magic_parameter} as result", + ) + cookies = { + "ds_actor": magic_parameters_client.actor_cookie({"id": "root"}), + "foo": "bar", + } + # Test the GET version + get_response = magic_parameters_client.get( + "/data/runme_get.json?_shape=array", cookies=cookies + ) + get_actual = get_response.json[0]["result"] + assert re.match(expected_re, str(get_actual)) + # Test the form + form_response = magic_parameters_client.get("/data/runme_post") + soup = Soup(form_response.body, "html.parser") + # The magic parameter should not be represented as a form field + assert None is soup.find("input", {"name": magic_parameter}) + # Submit the form to create a log line + response = magic_parameters_client.post( + "/data/runme_post?_json=1", {}, csrftoken_from=True, cookies=cookies + ) + assert response.json == { + "ok": True, + "message": "Query executed, 1 row affected", + "redirect": None, + } + post_actual = magic_parameters_client.get( + "/data/logs.json?_sort_desc=rowid&_shape=array" + ).json[0]["line"] + assert re.match(expected_re, post_actual) + + +@pytest.mark.parametrize("use_csrf", [True, False]) +@pytest.mark.parametrize("return_json", [True, False]) +def test_magic_parameters_csrf_json(magic_parameters_client, use_csrf, return_json): + update_query( + magic_parameters_client, + "runme_post", + sql="insert into logs (line) values (:_header_host)", + ) + qs = "" + if return_json: + qs = "?_json=1" + response = magic_parameters_client.post( + f"/data/runme_post{qs}", + {}, + csrftoken_from=use_csrf or None, + ) + if return_json: + assert response.status == 200 + assert response.json["ok"], response.json + else: + assert response.status == 302 + messages = magic_parameters_client.ds.unsign( + response.cookies["ds_messages"], "messages" + ) + assert [["Query executed, 1 row affected", 1]] == messages + post_actual = magic_parameters_client.get( + "/data/logs.json?_sort_desc=rowid&_shape=array" + ).json[0]["line"] + assert post_actual == "localhost" + + +def test_magic_parameters_cannot_be_used_in_arbitrary_queries(magic_parameters_client): + response = magic_parameters_client.get( + "/data/-/query.json?sql=select+:_header_host&_shape=array" + ) + assert 400 == response.status + assert response.json["error"].startswith("You did not supply a value for binding") + + +def test_stored_write_custom_template(stored_write_client): + response = stored_write_client.get("/data/update_name") + assert response.status == 200 + assert "!!!CUSTOM_UPDATE_NAME_TEMPLATE!!!" in response.text + assert ( + "" + in response.text + ) + # And test for link rel=alternate while we're here: + assert ( + '' + in response.text + ) + assert ( + response.headers["link"] + == '; rel="alternate"; type="application/json+datasette"' + ) + + +def test_stored_write_query_disabled_for_immutable_database( + stored_write_immutable_client, +): + response = stored_write_immutable_client.get("/fixtures/add") + assert response.status == 200 + assert ( + "This query cannot be executed because the database is immutable." + in response.text + ) + assert '' in response.text + # Submitting form should get a forbidden error + response = stored_write_immutable_client.post( + "/fixtures/add", + {"text": "text"}, + csrftoken_from=True, + ) + assert response.status == 403 + assert "Database is immutable" in response.text