From b077e63dc6255d154ede16df1a507b09ba6075b1 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 15 Dec 2022 13:44:48 -0800 Subject: [PATCH] Ported test_api.py app_client test to ds_client, refs #1959 --- datasette/app.py | 2 +- pytest.ini | 1 + tests/conftest.py | 39 ++++++++++ tests/test_api.py | 190 +++++++++++++++++++++++++++------------------- 4 files changed, 154 insertions(+), 78 deletions(-) diff --git a/datasette/app.py b/datasette/app.py index f3cb8876..b770b469 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -281,7 +281,7 @@ class Datasette: raise self.crossdb = crossdb self.nolock = nolock - if memory or crossdb or not self.files: + if memory or crossdb or (not self.files and memory is not False): self.add_database( Database(self, is_mutable=False, is_memory=True), name="_memory" ) diff --git a/pytest.ini b/pytest.ini index 559e518c..0bcb0d1e 100644 --- a/pytest.ini +++ b/pytest.ini @@ -8,4 +8,5 @@ filterwarnings= ignore:.*current_task.*:PendingDeprecationWarning markers = serial: tests to avoid using with pytest-xdist + ds_client: tests using the ds_client fixture asyncio_mode = strict diff --git a/tests/conftest.py b/tests/conftest.py index cd735e12..1306c407 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,9 @@ +import asyncio import httpx import os import pathlib import pytest +import pytest_asyncio import re import subprocess import tempfile @@ -23,6 +25,43 @@ UNDOCUMENTED_PERMISSIONS = { } +@pytest.fixture(scope="session") +def event_loop(): + return asyncio.get_event_loop() + + +@pytest_asyncio.fixture(scope="session") +async def ds_client(): + from datasette.app import Datasette + from .fixtures import METADATA, PLUGINS_DIR + + ds = Datasette( + memory=False, + metadata=METADATA, + plugins_dir=PLUGINS_DIR, + settings={ + "default_page_size": 50, + "max_returned_rows": 100, + "sql_time_limit_ms": 200, + # Default is 3 but this results in "too many open files" + # errors when running the full test suite: + "num_sql_threads": 1, + }, + ) + from .fixtures import TABLES, TABLE_PARAMETERIZED_SQL + + db = ds.add_memory_database("fixtures") + + def prepare(conn): + conn.executescript(TABLES) + for sql, params in TABLE_PARAMETERIZED_SQL: + with conn: + conn.execute(sql, params) + + await db.execute_write_fn(prepare) + return ds.client + + def pytest_report_header(config): return "SQLite: {}".format( sqlite3.connect(":memory:").execute("select sqlite_version()").fetchone()[0] diff --git a/tests/test_api.py b/tests/test_api.py index 5f2a6ea6..d0ffb05e 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -23,12 +23,15 @@ import sys import urllib -def test_homepage(app_client): - response = app_client.get("/.json") - assert response.status == 200 +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_homepage(ds_client): + response = await ds_client.get("/.json") + assert response.status_code == 200 assert "application/json; charset=utf-8" == response.headers["content-type"] - assert response.json.keys() == {"fixtures": 0}.keys() - d = response.json["fixtures"] + data = response.json() + assert data.keys() == {"fixtures": 0}.keys() + d = data["fixtures"] assert d["name"] == "fixtures" assert d["tables_count"] == 24 assert len(d["tables_and_views_truncated"]) == 5 @@ -36,15 +39,17 @@ def test_homepage(app_client): # 4 hidden FTS tables + no_primary_key (hidden in metadata) assert d["hidden_tables_count"] == 6 # 201 in no_primary_key, plus 6 in other hidden tables: - assert d["hidden_table_rows_sum"] == 207, response.json + assert d["hidden_table_rows_sum"] == 207, data assert d["views_count"] == 4 -def test_homepage_sort_by_relationships(app_client): - response = app_client.get("/.json?_sort=relationships") - assert response.status == 200 +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_homepage_sort_by_relationships(ds_client): + response = await ds_client.get("/.json?_sort=relationships") + assert response.status_code == 200 tables = [ - t["name"] for t in response.json["fixtures"]["tables_and_views_truncated"] + t["name"] for t in response.json()["fixtures"]["tables_and_views_truncated"] ] assert tables == [ "simple_primary_key", @@ -55,10 +60,12 @@ def test_homepage_sort_by_relationships(app_client): ] -def test_database_page(app_client): - response = app_client.get("/fixtures.json") - assert response.status == 200 - data = response.json +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_database_page(ds_client): + response = await ds_client.get("/fixtures.json") + assert response.status_code == 200 + data = response.json() assert data["database"] == "fixtures" assert data["tables"] == [ { @@ -633,11 +640,13 @@ def test_database_page_for_database_with_dot_in_name(app_client_with_dot): assert response.status == 200 -def test_custom_sql(app_client): - response = app_client.get( +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_custom_sql(ds_client): + response = await ds_client.get( "/fixtures.json?sql=select+content+from+simple_primary_key&_shape=objects" ) - data = response.json + data = response.json() assert {"sql": "select content from simple_primary_key", "params": {}} == data[ "query" ] @@ -673,41 +682,51 @@ def test_sql_time_limit(app_client_shorter_time_limit): } -def test_custom_sql_time_limit(app_client): - response = app_client.get("/fixtures.json?sql=select+sleep(0.01)") - assert 200 == response.status - response = app_client.get("/fixtures.json?sql=select+sleep(0.01)&_timelimit=5") - assert 400 == response.status - assert "SQL Interrupted" == response.json["title"] +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_custom_sql_time_limit(ds_client): + response = await ds_client.get("/fixtures.json?sql=select+sleep(0.01)") + assert response.status_code == 200 + response = await ds_client.get("/fixtures.json?sql=select+sleep(0.01)&_timelimit=5") + assert response.status_code == 400 + assert response.json()["title"] == "SQL Interrupted" -def test_invalid_custom_sql(app_client): - response = app_client.get("/fixtures.json?sql=.schema") - assert response.status == 400 - assert response.json["ok"] is False - assert "Statement must be a SELECT" == response.json["error"] +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_invalid_custom_sql(ds_client): + response = await ds_client.get("/fixtures.json?sql=.schema") + assert response.status_code == 400 + assert response.json()["ok"] is False + assert "Statement must be a SELECT" == response.json()["error"] -def test_row(app_client): - response = app_client.get("/fixtures/simple_primary_key/1.json?_shape=objects") - assert response.status == 200 - assert [{"id": "1", "content": "hello"}] == response.json["rows"] +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_row(ds_client): + response = await ds_client.get("/fixtures/simple_primary_key/1.json?_shape=objects") + assert response.status_code == 200 + assert response.json()["rows"] == [{"id": "1", "content": "hello"}] -def test_row_strange_table_name(app_client): - response = app_client.get( +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_row_strange_table_name(ds_client): + response = await ds_client.get( "/fixtures/table~2Fwith~2Fslashes~2Ecsv/3.json?_shape=objects" ) - assert response.status == 200 - assert [{"pk": "3", "content": "hey"}] == response.json["rows"] + assert response.status_code == 200 + assert response.json()["rows"] == [{"pk": "3", "content": "hey"}] -def test_row_foreign_key_tables(app_client): - response = app_client.get( +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_row_foreign_key_tables(ds_client): + response = await ds_client.get( "/fixtures/simple_primary_key/1.json?_extras=foreign_key_tables" ) - assert response.status == 200 - assert response.json["foreign_key_tables"] == [ + assert response.status_code == 200 + assert response.json()["foreign_key_tables"] == [ { "other_table": "foreign_key_references", "column": "id", @@ -762,47 +781,58 @@ def test_databases_json(app_client_two_attached_databases_one_immutable): assert False == fixtures_database["is_memory"] -def test_metadata_json(app_client): - response = app_client.get("/-/metadata.json") - assert METADATA == response.json +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_metadata_json(ds_client): + response = await ds_client.get("/-/metadata.json") + assert response.json() == METADATA -def test_threads_json(app_client): - response = app_client.get("/-/threads.json") +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_threads_json(ds_client): + response = await ds_client.get("/-/threads.json") expected_keys = {"threads", "num_threads"} if sys.version_info >= (3, 7, 0): expected_keys.update({"tasks", "num_tasks"}) - assert expected_keys == set(response.json.keys()) + assert set(response.json().keys()) == expected_keys -def test_plugins_json(app_client): - response = app_client.get("/-/plugins.json") - assert EXPECTED_PLUGINS == sorted(response.json, key=lambda p: p["name"]) +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_plugins_json(ds_client): + response = await ds_client.get("/-/plugins.json") + assert EXPECTED_PLUGINS == sorted(response.json(), key=lambda p: p["name"]) # Try with ?all=1 - response = app_client.get("/-/plugins.json?all=1") - names = {p["name"] for p in response.json} + response = await ds_client.get("/-/plugins.json?all=1") + names = {p["name"] for p in response.json()} assert names.issuperset(p["name"] for p in EXPECTED_PLUGINS) assert names.issuperset(DEFAULT_PLUGINS) -def test_versions_json(app_client): - response = app_client.get("/-/versions.json") - assert "python" in response.json - assert "3.0" == response.json.get("asgi") - assert "version" in response.json["python"] - assert "full" in response.json["python"] - assert "datasette" in response.json - assert "version" in response.json["datasette"] - assert response.json["datasette"]["version"] == __version__ - assert "sqlite" in response.json - assert "version" in response.json["sqlite"] - assert "fts_versions" in response.json["sqlite"] - assert "compile_options" in response.json["sqlite"] +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_versions_json(ds_client): + response = await ds_client.get("/-/versions.json") + data = response.json() + assert "python" in data + assert "3.0" == data.get("asgi") + assert "version" in data["python"] + assert "full" in data["python"] + assert "datasette" in data + assert "version" in data["datasette"] + assert data["datasette"]["version"] == __version__ + assert "sqlite" in data + assert "version" in data["sqlite"] + assert "fts_versions" in data["sqlite"] + assert "compile_options" in data["sqlite"] -def test_settings_json(app_client): - response = app_client.get("/-/settings.json") - assert { +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_settings_json(ds_client): + response = await ds_client.get("/-/settings.json") + assert response.json() == { "default_page_size": 50, "default_facet_size": 30, "facet_suggest_time_limit_ms": 50, @@ -825,9 +855,11 @@ def test_settings_json(app_client): "template_debug": False, "trace_debug": False, "base_url": "/", - } == response.json + } +@pytest.mark.ds_client +@pytest.mark.asyncio @pytest.mark.parametrize( "path,expected_redirect", ( @@ -835,9 +867,9 @@ def test_settings_json(app_client): ("/-/config", "/-/settings"), ), ) -def test_config_redirects_to_settings(app_client, path, expected_redirect): - response = app_client.get(path) - assert response.status == 301 +async def test_config_redirects_to_settings(ds_client, path, expected_redirect): + response = await ds_client.get(path) + assert response.status_code == 301 assert response.headers["Location"] == expected_redirect @@ -846,6 +878,8 @@ test_json_columns_default_expected = [ ] +@pytest.mark.ds_client +@pytest.mark.asyncio @pytest.mark.parametrize( "extra_args,expected", [ @@ -859,15 +893,15 @@ test_json_columns_default_expected = [ ), ], ) -def test_json_columns(app_client, extra_args, expected): +async def test_json_columns(ds_client, extra_args, expected): sql = """ select 1 as intval, "s" as strval, 0.5 as floatval, '{"foo": "bar"}' as jsonval """ path = "/fixtures.json?" + urllib.parse.urlencode({"sql": sql, "_shape": "array"}) path += extra_args - response = app_client.get(path) - assert expected == response.json + response = await ds_client.get(path) + assert response.json() == expected def test_config_cache_size(app_client_larger_cache_size): @@ -966,9 +1000,11 @@ def test_inspect_file_used_for_count(app_client_immutable_and_inspect_file): assert response.json["filtered_table_rows_count"] == 100 -def test_http_options_request(app_client): - response = app_client.request("/fixtures", method="OPTIONS") - assert response.status == 200 +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_http_options_request(ds_client): + response = await ds_client.options("/fixtures") + assert response.status_code == 200 assert response.text == "ok"