From 6e1e815c7881abe836d573b18ed2c5bb3e5b699e Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 14 Dec 2022 18:41:30 -0800 Subject: [PATCH 001/603] It's an update-or-insert --- docs/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index df4b2cb6..aec13e27 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -11,7 +11,7 @@ Changelog The third Datasette 1.0 alpha release adds upsert support to the JSON API, plus the ability to specify finely grained permissions when creating an API token. -- New ``/db/table/-/upsert`` API, :ref:`documented here `. upsert is an update-or-replace: existing rows will have specified keys updated, but if no row matches the incoming primary key a brand new row will be inserted instead. (:issue:`1878`) +- New ``/db/table/-/upsert`` API, :ref:`documented here `. upsert is an update-or-insert: existing rows will have specified keys updated, but if no row matches the incoming primary key a brand new row will be inserted instead. (:issue:`1878`) - New :ref:`plugin_register_permissions` plugin hook. Plugins can now register named permissions, which will then be listed in various interfaces that show available permissions. (:issue:`1940`) - The ``/db/-/create`` API for :ref:`creating a table ` now accepts ``"ignore": true`` and ``"replace": true`` options when called with the ``"rows"`` property that creates a new table based on an example set of rows. This means the API can be called multiple times with different rows, setting rules for what should happen if a primary key collides with an existing row. (:issue:`1927`) - Arbitrary permissions can now be configured at the instance, database and resource (table, SQL view or canned query) level in Datasette's :ref:`metadata` JSON and YAML files. The new ``"permissions"`` key can be used to specify which actors should have which permissions. See :ref:`authentication_permissions_other` for details. (:issue:`1636`) From e054704fb64d1f23154ec43b81b6c9481ff8202f Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 14 Dec 2022 21:38:20 -0800 Subject: [PATCH 002/603] Added missing rST label --- docs/internals.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/internals.rst b/docs/internals.rst index 4b7a440c..4b82e11c 100644 --- a/docs/internals.rst +++ b/docs/internals.rst @@ -419,6 +419,8 @@ The following example runs three checks in a row, similar to :ref:`datasette_ens ], ) +.. 
_datasette_create_token: + .create_token(actor_id, expires_after=None, restrict_all=None, restrict_database=None, restrict_resource=None) -------------------------------------------------------------------------------------------------------------- From dc18f62089e5672d03176f217d7840cdafa5c447 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 15 Dec 2022 09:34:07 -0800 Subject: [PATCH 003/603] Replace AsgiLifespan with AsgiRunOnFirstRequest, refs #1955 --- datasette/app.py | 20 ++--------- datasette/utils/asgi.py | 44 ++++++++---------------- docs/plugin_hooks.rst | 5 +-- docs/testing_plugins.rst | 2 +- tests/test_internals_datasette_client.py | 1 - 5 files changed, 22 insertions(+), 50 deletions(-) diff --git a/datasette/app.py b/datasette/app.py index f3cb8876..7e682498 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -69,8 +69,6 @@ from .utils import ( row_sql_params_pks, ) from .utils.asgi import ( - AsgiLifespan, - Base400, Forbidden, NotFound, DatabaseNotFound, @@ -78,11 +76,10 @@ from .utils.asgi import ( RowNotFound, Request, Response, + AsgiRunOnFirstRequest, asgi_static, asgi_send, asgi_send_file, - asgi_send_html, - asgi_send_json, asgi_send_redirect, ) from .utils.internal_db import init_internal_db, populate_schema_tables @@ -1420,7 +1417,7 @@ class Datasette: async def setup_db(): # First time server starts up, calculate table counts for immutable databases - for dbname, database in self.databases.items(): + for database in self.databases.values(): if not database.is_mutable: await database.table_counts(limit=60 * 60 * 1000) @@ -1434,10 +1431,7 @@ class Datasette: ) if self.setting("trace_debug"): asgi = AsgiTracer(asgi) - asgi = AsgiLifespan( - asgi, - on_startup=setup_db, - ) + asgi = AsgiRunOnFirstRequest(asgi, on_startup=[setup_db, self.invoke_startup]) for wrapper in pm.hook.asgi_wrapper(datasette=self): asgi = wrapper(asgi) return asgi @@ -1726,42 +1720,34 @@ class DatasetteClient: return path async def get(self, path, **kwargs): - await self.ds.invoke_startup() async with httpx.AsyncClient(app=self.app) as client: return await client.get(self._fix(path), **kwargs) async def options(self, path, **kwargs): - await self.ds.invoke_startup() async with httpx.AsyncClient(app=self.app) as client: return await client.options(self._fix(path), **kwargs) async def head(self, path, **kwargs): - await self.ds.invoke_startup() async with httpx.AsyncClient(app=self.app) as client: return await client.head(self._fix(path), **kwargs) async def post(self, path, **kwargs): - await self.ds.invoke_startup() async with httpx.AsyncClient(app=self.app) as client: return await client.post(self._fix(path), **kwargs) async def put(self, path, **kwargs): - await self.ds.invoke_startup() async with httpx.AsyncClient(app=self.app) as client: return await client.put(self._fix(path), **kwargs) async def patch(self, path, **kwargs): - await self.ds.invoke_startup() async with httpx.AsyncClient(app=self.app) as client: return await client.patch(self._fix(path), **kwargs) async def delete(self, path, **kwargs): - await self.ds.invoke_startup() async with httpx.AsyncClient(app=self.app) as client: return await client.delete(self._fix(path), **kwargs) async def request(self, method, path, **kwargs): - await self.ds.invoke_startup() avoid_path_rewrites = kwargs.pop("avoid_path_rewrites", None) async with httpx.AsyncClient(app=self.app) as client: return await client.request( diff --git a/datasette/utils/asgi.py b/datasette/utils/asgi.py index f080df91..56690251 100644 --- 
a/datasette/utils/asgi.py +++ b/datasette/utils/asgi.py @@ -156,35 +156,6 @@ class Request: return cls(scope, None) -class AsgiLifespan: - def __init__(self, app, on_startup=None, on_shutdown=None): - self.app = app - on_startup = on_startup or [] - on_shutdown = on_shutdown or [] - if not isinstance(on_startup or [], list): - on_startup = [on_startup] - if not isinstance(on_shutdown or [], list): - on_shutdown = [on_shutdown] - self.on_startup = on_startup - self.on_shutdown = on_shutdown - - async def __call__(self, scope, receive, send): - if scope["type"] == "lifespan": - while True: - message = await receive() - if message["type"] == "lifespan.startup": - for fn in self.on_startup: - await fn() - await send({"type": "lifespan.startup.complete"}) - elif message["type"] == "lifespan.shutdown": - for fn in self.on_shutdown: - await fn() - await send({"type": "lifespan.shutdown.complete"}) - return - else: - await self.app(scope, receive, send) - - class AsgiStream: def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"): self.stream_fn = stream_fn @@ -449,3 +420,18 @@ class AsgiFileDownload: content_type=self.content_type, headers=self.headers, ) + + +class AsgiRunOnFirstRequest: + def __init__(self, asgi, on_startup): + assert isinstance(on_startup, list) + self.asgi = asgi + self.on_startup = on_startup + self._started = False + + async def __call__(self, scope, receive, send): + if not self._started: + self._started = True + for hook in self.on_startup: + await hook() + return await self.asgi(scope, receive, send) diff --git a/docs/plugin_hooks.rst b/docs/plugin_hooks.rst index f41ca876..cdc73f00 100644 --- a/docs/plugin_hooks.rst +++ b/docs/plugin_hooks.rst @@ -902,13 +902,14 @@ Potential use-cases: .. note:: - If you are writing :ref:`unit tests ` for a plugin that uses this hook you will need to explicitly call ``await ds.invoke_startup()`` in your tests. An example: + If you are writing :ref:`unit tests ` for a plugin that uses this hook and doesn't exercise Datasette by sending + any simulated requests through it you will need to explicitly call ``await ds.invoke_startup()`` in your tests. An example: .. code-block:: python @pytest.mark.asyncio async def test_my_plugin(): - ds = Datasette([], metadata={}) + ds = Datasette() await ds.invoke_startup() # Rest of test goes here diff --git a/docs/testing_plugins.rst b/docs/testing_plugins.rst index 41f50e56..6d2097ad 100644 --- a/docs/testing_plugins.rst +++ b/docs/testing_plugins.rst @@ -80,7 +80,7 @@ Creating a ``Datasette()`` instance like this as useful shortcut in tests, but t This method registers any :ref:`plugin_hook_startup` or :ref:`plugin_hook_prepare_jinja2_environment` plugins that might themselves need to make async calls. -If you are using ``await datasette.client.get()`` and similar methods then you don't need to worry about this - those method calls ensure that ``.invoke_startup()`` has been called for you. +If you are using ``await datasette.client.get()`` and similar methods then you don't need to worry about this - Datasette automatically calls ``invoke_startup()`` the first time it handles a request. .. 
_testing_plugins_pdb: diff --git a/tests/test_internals_datasette_client.py b/tests/test_internals_datasette_client.py index cbbfa3c3..7a95ed6e 100644 --- a/tests/test_internals_datasette_client.py +++ b/tests/test_internals_datasette_client.py @@ -6,7 +6,6 @@ import pytest_asyncio @pytest_asyncio.fixture async def datasette(app_client): - await app_client.ds.invoke_startup() return app_client.ds From 51ee8caa4a697fa3f4120e93b1c205b714a6cdc7 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 15 Dec 2022 12:51:18 -0800 Subject: [PATCH 004/603] Try running every test at once, refs #1955 --- .github/workflows/test.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 886f649a..c4032656 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -33,8 +33,9 @@ jobs: pip freeze - name: Run tests run: | - pytest -n auto -m "not serial" - pytest -m "serial" + # pytest -n auto -m "not serial" + # pytest -m "serial" + pytest -n auto - name: Check if cog needs to be run run: | cog --check docs/*.rst From 38d28dd958c41e5e7fde3788ba3fdaf2e09eff70 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 15 Dec 2022 13:05:33 -0800 Subject: [PATCH 005/603] Revert "Try running every test at once, refs #1955" This reverts commit 51ee8caa4a697fa3f4120e93b1c205b714a6cdc7. --- .github/workflows/test.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c4032656..886f649a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -33,9 +33,8 @@ jobs: pip freeze - name: Run tests run: | - # pytest -n auto -m "not serial" - # pytest -m "serial" - pytest -n auto + pytest -n auto -m "not serial" + pytest -m "serial" - name: Check if cog needs to be run run: | cog --check docs/*.rst From 0b68996cc511b3a801f0cd0157bd66332d75f46f Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 15 Dec 2022 13:06:45 -0800 Subject: [PATCH 006/603] Revert "Replace AsgiLifespan with AsgiRunOnFirstRequest, refs #1955" This reverts commit dc18f62089e5672d03176f217d7840cdafa5c447. 
--- datasette/app.py | 20 +++++++++-- datasette/utils/asgi.py | 44 ++++++++++++++++-------- docs/plugin_hooks.rst | 5 ++- docs/testing_plugins.rst | 2 +- tests/test_internals_datasette_client.py | 1 + 5 files changed, 50 insertions(+), 22 deletions(-) diff --git a/datasette/app.py b/datasette/app.py index 7e682498..f3cb8876 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -69,6 +69,8 @@ from .utils import ( row_sql_params_pks, ) from .utils.asgi import ( + AsgiLifespan, + Base400, Forbidden, NotFound, DatabaseNotFound, @@ -76,10 +78,11 @@ from .utils.asgi import ( RowNotFound, Request, Response, - AsgiRunOnFirstRequest, asgi_static, asgi_send, asgi_send_file, + asgi_send_html, + asgi_send_json, asgi_send_redirect, ) from .utils.internal_db import init_internal_db, populate_schema_tables @@ -1417,7 +1420,7 @@ class Datasette: async def setup_db(): # First time server starts up, calculate table counts for immutable databases - for database in self.databases.values(): + for dbname, database in self.databases.items(): if not database.is_mutable: await database.table_counts(limit=60 * 60 * 1000) @@ -1431,7 +1434,10 @@ class Datasette: ) if self.setting("trace_debug"): asgi = AsgiTracer(asgi) - asgi = AsgiRunOnFirstRequest(asgi, on_startup=[setup_db, self.invoke_startup]) + asgi = AsgiLifespan( + asgi, + on_startup=setup_db, + ) for wrapper in pm.hook.asgi_wrapper(datasette=self): asgi = wrapper(asgi) return asgi @@ -1720,34 +1726,42 @@ class DatasetteClient: return path async def get(self, path, **kwargs): + await self.ds.invoke_startup() async with httpx.AsyncClient(app=self.app) as client: return await client.get(self._fix(path), **kwargs) async def options(self, path, **kwargs): + await self.ds.invoke_startup() async with httpx.AsyncClient(app=self.app) as client: return await client.options(self._fix(path), **kwargs) async def head(self, path, **kwargs): + await self.ds.invoke_startup() async with httpx.AsyncClient(app=self.app) as client: return await client.head(self._fix(path), **kwargs) async def post(self, path, **kwargs): + await self.ds.invoke_startup() async with httpx.AsyncClient(app=self.app) as client: return await client.post(self._fix(path), **kwargs) async def put(self, path, **kwargs): + await self.ds.invoke_startup() async with httpx.AsyncClient(app=self.app) as client: return await client.put(self._fix(path), **kwargs) async def patch(self, path, **kwargs): + await self.ds.invoke_startup() async with httpx.AsyncClient(app=self.app) as client: return await client.patch(self._fix(path), **kwargs) async def delete(self, path, **kwargs): + await self.ds.invoke_startup() async with httpx.AsyncClient(app=self.app) as client: return await client.delete(self._fix(path), **kwargs) async def request(self, method, path, **kwargs): + await self.ds.invoke_startup() avoid_path_rewrites = kwargs.pop("avoid_path_rewrites", None) async with httpx.AsyncClient(app=self.app) as client: return await client.request( diff --git a/datasette/utils/asgi.py b/datasette/utils/asgi.py index 56690251..f080df91 100644 --- a/datasette/utils/asgi.py +++ b/datasette/utils/asgi.py @@ -156,6 +156,35 @@ class Request: return cls(scope, None) +class AsgiLifespan: + def __init__(self, app, on_startup=None, on_shutdown=None): + self.app = app + on_startup = on_startup or [] + on_shutdown = on_shutdown or [] + if not isinstance(on_startup or [], list): + on_startup = [on_startup] + if not isinstance(on_shutdown or [], list): + on_shutdown = [on_shutdown] + self.on_startup = on_startup + self.on_shutdown = 
on_shutdown + + async def __call__(self, scope, receive, send): + if scope["type"] == "lifespan": + while True: + message = await receive() + if message["type"] == "lifespan.startup": + for fn in self.on_startup: + await fn() + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + for fn in self.on_shutdown: + await fn() + await send({"type": "lifespan.shutdown.complete"}) + return + else: + await self.app(scope, receive, send) + + class AsgiStream: def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"): self.stream_fn = stream_fn @@ -420,18 +449,3 @@ class AsgiFileDownload: content_type=self.content_type, headers=self.headers, ) - - -class AsgiRunOnFirstRequest: - def __init__(self, asgi, on_startup): - assert isinstance(on_startup, list) - self.asgi = asgi - self.on_startup = on_startup - self._started = False - - async def __call__(self, scope, receive, send): - if not self._started: - self._started = True - for hook in self.on_startup: - await hook() - return await self.asgi(scope, receive, send) diff --git a/docs/plugin_hooks.rst b/docs/plugin_hooks.rst index cdc73f00..f41ca876 100644 --- a/docs/plugin_hooks.rst +++ b/docs/plugin_hooks.rst @@ -902,14 +902,13 @@ Potential use-cases: .. note:: - If you are writing :ref:`unit tests ` for a plugin that uses this hook and doesn't exercise Datasette by sending - any simulated requests through it you will need to explicitly call ``await ds.invoke_startup()`` in your tests. An example: + If you are writing :ref:`unit tests ` for a plugin that uses this hook you will need to explicitly call ``await ds.invoke_startup()`` in your tests. An example: .. code-block:: python @pytest.mark.asyncio async def test_my_plugin(): - ds = Datasette() + ds = Datasette([], metadata={}) await ds.invoke_startup() # Rest of test goes here diff --git a/docs/testing_plugins.rst b/docs/testing_plugins.rst index 6d2097ad..41f50e56 100644 --- a/docs/testing_plugins.rst +++ b/docs/testing_plugins.rst @@ -80,7 +80,7 @@ Creating a ``Datasette()`` instance like this as useful shortcut in tests, but t This method registers any :ref:`plugin_hook_startup` or :ref:`plugin_hook_prepare_jinja2_environment` plugins that might themselves need to make async calls. -If you are using ``await datasette.client.get()`` and similar methods then you don't need to worry about this - Datasette automatically calls ``invoke_startup()`` the first time it handles a request. +If you are using ``await datasette.client.get()`` and similar methods then you don't need to worry about this - those method calls ensure that ``.invoke_startup()`` has been called for you. .. _testing_plugins_pdb: diff --git a/tests/test_internals_datasette_client.py b/tests/test_internals_datasette_client.py index 7a95ed6e..cbbfa3c3 100644 --- a/tests/test_internals_datasette_client.py +++ b/tests/test_internals_datasette_client.py @@ -6,6 +6,7 @@ import pytest_asyncio @pytest_asyncio.fixture async def datasette(app_client): + await app_client.ds.invoke_startup() return app_client.ds From 013496862f4d4b441ab61255242b838b24287607 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 15 Dec 2022 16:55:17 -0800 Subject: [PATCH 007/603] Try click.echo() instead This ensures the URL is output correctly when running under Docker. 
Closes #1958 --- datasette/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasette/cli.py b/datasette/cli.py index b3ae643a..d197925b 100644 --- a/datasette/cli.py +++ b/datasette/cli.py @@ -618,7 +618,7 @@ def serve( url = "http://{}:{}{}?token={}".format( host, port, ds.urls.path("-/auth-token"), ds._root_token ) - print(url) + click.echo(url) if open_browser: if url is None: # Figure out most convenient URL - to table, database or homepage From 5ee954e34b6eb762ccecbdb2be0791d0166fd19c Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 15 Dec 2022 17:03:37 -0800 Subject: [PATCH 008/603] Link to annotated release notes for 1.0a2 --- docs/changelog.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index aec13e27..23eab873 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -11,6 +11,8 @@ Changelog The third Datasette 1.0 alpha release adds upsert support to the JSON API, plus the ability to specify finely grained permissions when creating an API token. +See `Datasette 1.0a2: Upserts and finely grained permissions `__ for an extended, annotated version of these release notes. + - New ``/db/table/-/upsert`` API, :ref:`documented here `. upsert is an update-or-insert: existing rows will have specified keys updated, but if no row matches the incoming primary key a brand new row will be inserted instead. (:issue:`1878`) - New :ref:`plugin_register_permissions` plugin hook. Plugins can now register named permissions, which will then be listed in various interfaces that show available permissions. (:issue:`1940`) - The ``/db/-/create`` API for :ref:`creating a table ` now accepts ``"ignore": true`` and ``"replace": true`` options when called with the ``"rows"`` property that creates a new table based on an example set of rows. This means the API can be called multiple times with different rows, setting rules for what should happen if a primary key collides with an existing row. 
(:issue:`1927`) From b077e63dc6255d154ede16df1a507b09ba6075b1 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 15 Dec 2022 13:44:48 -0800 Subject: [PATCH 009/603] Ported test_api.py app_client test to ds_client, refs #1959 --- datasette/app.py | 2 +- pytest.ini | 1 + tests/conftest.py | 39 ++++++++++ tests/test_api.py | 190 +++++++++++++++++++++++++++------------------- 4 files changed, 154 insertions(+), 78 deletions(-) diff --git a/datasette/app.py b/datasette/app.py index f3cb8876..b770b469 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -281,7 +281,7 @@ class Datasette: raise self.crossdb = crossdb self.nolock = nolock - if memory or crossdb or not self.files: + if memory or crossdb or (not self.files and memory is not False): self.add_database( Database(self, is_mutable=False, is_memory=True), name="_memory" ) diff --git a/pytest.ini b/pytest.ini index 559e518c..0bcb0d1e 100644 --- a/pytest.ini +++ b/pytest.ini @@ -8,4 +8,5 @@ filterwarnings= ignore:.*current_task.*:PendingDeprecationWarning markers = serial: tests to avoid using with pytest-xdist + ds_client: tests using the ds_client fixture asyncio_mode = strict diff --git a/tests/conftest.py b/tests/conftest.py index cd735e12..1306c407 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,9 @@ +import asyncio import httpx import os import pathlib import pytest +import pytest_asyncio import re import subprocess import tempfile @@ -23,6 +25,43 @@ UNDOCUMENTED_PERMISSIONS = { } +@pytest.fixture(scope="session") +def event_loop(): + return asyncio.get_event_loop() + + +@pytest_asyncio.fixture(scope="session") +async def ds_client(): + from datasette.app import Datasette + from .fixtures import METADATA, PLUGINS_DIR + + ds = Datasette( + memory=False, + metadata=METADATA, + plugins_dir=PLUGINS_DIR, + settings={ + "default_page_size": 50, + "max_returned_rows": 100, + "sql_time_limit_ms": 200, + # Default is 3 but this results in "too many open files" + # errors when running the full test suite: + "num_sql_threads": 1, + }, + ) + from .fixtures import TABLES, TABLE_PARAMETERIZED_SQL + + db = ds.add_memory_database("fixtures") + + def prepare(conn): + conn.executescript(TABLES) + for sql, params in TABLE_PARAMETERIZED_SQL: + with conn: + conn.execute(sql, params) + + await db.execute_write_fn(prepare) + return ds.client + + def pytest_report_header(config): return "SQLite: {}".format( sqlite3.connect(":memory:").execute("select sqlite_version()").fetchone()[0] diff --git a/tests/test_api.py b/tests/test_api.py index 5f2a6ea6..d0ffb05e 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -23,12 +23,15 @@ import sys import urllib -def test_homepage(app_client): - response = app_client.get("/.json") - assert response.status == 200 +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_homepage(ds_client): + response = await ds_client.get("/.json") + assert response.status_code == 200 assert "application/json; charset=utf-8" == response.headers["content-type"] - assert response.json.keys() == {"fixtures": 0}.keys() - d = response.json["fixtures"] + data = response.json() + assert data.keys() == {"fixtures": 0}.keys() + d = data["fixtures"] assert d["name"] == "fixtures" assert d["tables_count"] == 24 assert len(d["tables_and_views_truncated"]) == 5 @@ -36,15 +39,17 @@ def test_homepage(app_client): # 4 hidden FTS tables + no_primary_key (hidden in metadata) assert d["hidden_tables_count"] == 6 # 201 in no_primary_key, plus 6 in other hidden tables: - assert d["hidden_table_rows_sum"] == 207, 
response.json + assert d["hidden_table_rows_sum"] == 207, data assert d["views_count"] == 4 -def test_homepage_sort_by_relationships(app_client): - response = app_client.get("/.json?_sort=relationships") - assert response.status == 200 +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_homepage_sort_by_relationships(ds_client): + response = await ds_client.get("/.json?_sort=relationships") + assert response.status_code == 200 tables = [ - t["name"] for t in response.json["fixtures"]["tables_and_views_truncated"] + t["name"] for t in response.json()["fixtures"]["tables_and_views_truncated"] ] assert tables == [ "simple_primary_key", @@ -55,10 +60,12 @@ def test_homepage_sort_by_relationships(app_client): ] -def test_database_page(app_client): - response = app_client.get("/fixtures.json") - assert response.status == 200 - data = response.json +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_database_page(ds_client): + response = await ds_client.get("/fixtures.json") + assert response.status_code == 200 + data = response.json() assert data["database"] == "fixtures" assert data["tables"] == [ { @@ -633,11 +640,13 @@ def test_database_page_for_database_with_dot_in_name(app_client_with_dot): assert response.status == 200 -def test_custom_sql(app_client): - response = app_client.get( +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_custom_sql(ds_client): + response = await ds_client.get( "/fixtures.json?sql=select+content+from+simple_primary_key&_shape=objects" ) - data = response.json + data = response.json() assert {"sql": "select content from simple_primary_key", "params": {}} == data[ "query" ] @@ -673,41 +682,51 @@ def test_sql_time_limit(app_client_shorter_time_limit): } -def test_custom_sql_time_limit(app_client): - response = app_client.get("/fixtures.json?sql=select+sleep(0.01)") - assert 200 == response.status - response = app_client.get("/fixtures.json?sql=select+sleep(0.01)&_timelimit=5") - assert 400 == response.status - assert "SQL Interrupted" == response.json["title"] +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_custom_sql_time_limit(ds_client): + response = await ds_client.get("/fixtures.json?sql=select+sleep(0.01)") + assert response.status_code == 200 + response = await ds_client.get("/fixtures.json?sql=select+sleep(0.01)&_timelimit=5") + assert response.status_code == 400 + assert response.json()["title"] == "SQL Interrupted" -def test_invalid_custom_sql(app_client): - response = app_client.get("/fixtures.json?sql=.schema") - assert response.status == 400 - assert response.json["ok"] is False - assert "Statement must be a SELECT" == response.json["error"] +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_invalid_custom_sql(ds_client): + response = await ds_client.get("/fixtures.json?sql=.schema") + assert response.status_code == 400 + assert response.json()["ok"] is False + assert "Statement must be a SELECT" == response.json()["error"] -def test_row(app_client): - response = app_client.get("/fixtures/simple_primary_key/1.json?_shape=objects") - assert response.status == 200 - assert [{"id": "1", "content": "hello"}] == response.json["rows"] +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_row(ds_client): + response = await ds_client.get("/fixtures/simple_primary_key/1.json?_shape=objects") + assert response.status_code == 200 + assert response.json()["rows"] == [{"id": "1", "content": "hello"}] -def test_row_strange_table_name(app_client): - response = app_client.get( +@pytest.mark.ds_client 
+@pytest.mark.asyncio +async def test_row_strange_table_name(ds_client): + response = await ds_client.get( "/fixtures/table~2Fwith~2Fslashes~2Ecsv/3.json?_shape=objects" ) - assert response.status == 200 - assert [{"pk": "3", "content": "hey"}] == response.json["rows"] + assert response.status_code == 200 + assert response.json()["rows"] == [{"pk": "3", "content": "hey"}] -def test_row_foreign_key_tables(app_client): - response = app_client.get( +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_row_foreign_key_tables(ds_client): + response = await ds_client.get( "/fixtures/simple_primary_key/1.json?_extras=foreign_key_tables" ) - assert response.status == 200 - assert response.json["foreign_key_tables"] == [ + assert response.status_code == 200 + assert response.json()["foreign_key_tables"] == [ { "other_table": "foreign_key_references", "column": "id", @@ -762,47 +781,58 @@ def test_databases_json(app_client_two_attached_databases_one_immutable): assert False == fixtures_database["is_memory"] -def test_metadata_json(app_client): - response = app_client.get("/-/metadata.json") - assert METADATA == response.json +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_metadata_json(ds_client): + response = await ds_client.get("/-/metadata.json") + assert response.json() == METADATA -def test_threads_json(app_client): - response = app_client.get("/-/threads.json") +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_threads_json(ds_client): + response = await ds_client.get("/-/threads.json") expected_keys = {"threads", "num_threads"} if sys.version_info >= (3, 7, 0): expected_keys.update({"tasks", "num_tasks"}) - assert expected_keys == set(response.json.keys()) + assert set(response.json().keys()) == expected_keys -def test_plugins_json(app_client): - response = app_client.get("/-/plugins.json") - assert EXPECTED_PLUGINS == sorted(response.json, key=lambda p: p["name"]) +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_plugins_json(ds_client): + response = await ds_client.get("/-/plugins.json") + assert EXPECTED_PLUGINS == sorted(response.json(), key=lambda p: p["name"]) # Try with ?all=1 - response = app_client.get("/-/plugins.json?all=1") - names = {p["name"] for p in response.json} + response = await ds_client.get("/-/plugins.json?all=1") + names = {p["name"] for p in response.json()} assert names.issuperset(p["name"] for p in EXPECTED_PLUGINS) assert names.issuperset(DEFAULT_PLUGINS) -def test_versions_json(app_client): - response = app_client.get("/-/versions.json") - assert "python" in response.json - assert "3.0" == response.json.get("asgi") - assert "version" in response.json["python"] - assert "full" in response.json["python"] - assert "datasette" in response.json - assert "version" in response.json["datasette"] - assert response.json["datasette"]["version"] == __version__ - assert "sqlite" in response.json - assert "version" in response.json["sqlite"] - assert "fts_versions" in response.json["sqlite"] - assert "compile_options" in response.json["sqlite"] +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_versions_json(ds_client): + response = await ds_client.get("/-/versions.json") + data = response.json() + assert "python" in data + assert "3.0" == data.get("asgi") + assert "version" in data["python"] + assert "full" in data["python"] + assert "datasette" in data + assert "version" in data["datasette"] + assert data["datasette"]["version"] == __version__ + assert "sqlite" in data + assert "version" in data["sqlite"] + assert 
"fts_versions" in data["sqlite"] + assert "compile_options" in data["sqlite"] -def test_settings_json(app_client): - response = app_client.get("/-/settings.json") - assert { +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_settings_json(ds_client): + response = await ds_client.get("/-/settings.json") + assert response.json() == { "default_page_size": 50, "default_facet_size": 30, "facet_suggest_time_limit_ms": 50, @@ -825,9 +855,11 @@ def test_settings_json(app_client): "template_debug": False, "trace_debug": False, "base_url": "/", - } == response.json + } +@pytest.mark.ds_client +@pytest.mark.asyncio @pytest.mark.parametrize( "path,expected_redirect", ( @@ -835,9 +867,9 @@ def test_settings_json(app_client): ("/-/config", "/-/settings"), ), ) -def test_config_redirects_to_settings(app_client, path, expected_redirect): - response = app_client.get(path) - assert response.status == 301 +async def test_config_redirects_to_settings(ds_client, path, expected_redirect): + response = await ds_client.get(path) + assert response.status_code == 301 assert response.headers["Location"] == expected_redirect @@ -846,6 +878,8 @@ test_json_columns_default_expected = [ ] +@pytest.mark.ds_client +@pytest.mark.asyncio @pytest.mark.parametrize( "extra_args,expected", [ @@ -859,15 +893,15 @@ test_json_columns_default_expected = [ ), ], ) -def test_json_columns(app_client, extra_args, expected): +async def test_json_columns(ds_client, extra_args, expected): sql = """ select 1 as intval, "s" as strval, 0.5 as floatval, '{"foo": "bar"}' as jsonval """ path = "/fixtures.json?" + urllib.parse.urlencode({"sql": sql, "_shape": "array"}) path += extra_args - response = app_client.get(path) - assert expected == response.json + response = await ds_client.get(path) + assert response.json() == expected def test_config_cache_size(app_client_larger_cache_size): @@ -966,9 +1000,11 @@ def test_inspect_file_used_for_count(app_client_immutable_and_inspect_file): assert response.json["filtered_table_rows_count"] == 100 -def test_http_options_request(app_client): - response = app_client.request("/fixtures", method="OPTIONS") - assert response.status == 200 +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_http_options_request(ds_client): + response = await ds_client.options("/fixtures") + assert response.status_code == 200 assert response.text == "ok" From 425ac4357ffb722a6ca86d08faba02ee38ad8689 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 15 Dec 2022 14:18:40 -0800 Subject: [PATCH 010/603] Ported app_client to ds_client where possible in test_auth.py, refs #1959 --- datasette/app.py | 4 ++ datasette/utils/testing.py | 7 -- tests/plugins/my_plugin_2.py | 4 +- tests/test_auth.py | 132 ++++++++++++++++++++--------------- tests/test_messages.py | 3 +- tests/utils.py | 8 +++ 6 files changed, 91 insertions(+), 67 deletions(-) diff --git a/datasette/app.py b/datasette/app.py index b770b469..04e26a46 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1718,6 +1718,10 @@ class DatasetteClient: self.ds = ds self.app = ds.app() + def actor_cookie(self, actor): + # Utility method, mainly for tests + return self.ds.sign({"a": actor}, "actor") + def _fix(self, path, avoid_path_rewrites=False): if not isinstance(path, PrefixedUrlString) and not avoid_path_rewrites: path = self.ds.urls.path(path) diff --git a/datasette/utils/testing.py b/datasette/utils/testing.py index 4f76a799..cabe2e5c 100644 --- a/datasette/utils/testing.py +++ b/datasette/utils/testing.py @@ -28,13 +28,6 @@ class TestResponse: 
def cookies(self): return dict(self.httpx_response.cookies) - def cookie_was_deleted(self, cookie): - return any( - h - for h in self.httpx_response.headers.get_list("set-cookie") - if h.startswith(f'{cookie}="";') - ) - @property def json(self): return json.loads(self.text) diff --git a/tests/plugins/my_plugin_2.py b/tests/plugins/my_plugin_2.py index cee80703..4f7bf08c 100644 --- a/tests/plugins/my_plugin_2.py +++ b/tests/plugins/my_plugin_2.py @@ -135,7 +135,9 @@ def prepare_jinja2_environment(env, datasette): @hookimpl def startup(datasette): async def inner(): - result = await datasette.get_database().execute("select 1 + 1") + # Run against _internal so tests that use the ds_client fixture + # (which has no databases yet on startup) do not fail: + result = await datasette.get_database("_internal").execute("select 1 + 1") datasette._startup_hook_calculation = result.first()[0] return inner diff --git a/tests/test_auth.py b/tests/test_auth.py index dd1b61e3..bc5c6a2b 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,5 +1,6 @@ from bs4 import BeautifulSoup as Soup from .fixtures import app_client +from .utils import cookie_was_deleted from click.testing import CliRunner from datasette.utils import baseconv from datasette.cli import cli @@ -7,46 +8,47 @@ import pytest import time -def test_auth_token(app_client): +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_auth_token(ds_client): """The /-/auth-token endpoint sets the correct cookie""" - assert app_client.ds._root_token is not None - path = f"/-/auth-token?token={app_client.ds._root_token}" - response = app_client.get( - path, - ) - assert 302 == response.status + assert ds_client.ds._root_token is not None + path = f"/-/auth-token?token={ds_client.ds._root_token}" + response = await ds_client.get(path) + assert response.status_code == 302 assert "/" == response.headers["Location"] - assert {"a": {"id": "root"}} == app_client.ds.unsign( + assert {"a": {"id": "root"}} == ds_client.ds.unsign( response.cookies["ds_actor"], "actor" ) # Check that a second with same token fails - assert app_client.ds._root_token is None - assert ( - 403 - == app_client.get( - path, - ).status - ) + assert ds_client.ds._root_token is None + assert (await ds_client.get(path)).status_code == 403 -def test_actor_cookie(app_client): +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_actor_cookie(ds_client): """A valid actor cookie sets request.scope['actor']""" - cookie = app_client.actor_cookie({"id": "test"}) - response = app_client.get("/", cookies={"ds_actor": cookie}) - assert {"id": "test"} == app_client.ds._last_request.scope["actor"] + cookie = ds_client.actor_cookie({"id": "test"}) + await ds_client.get("/", cookies={"ds_actor": cookie}) + assert ds_client.ds._last_request.scope["actor"] == {"id": "test"} -def test_actor_cookie_invalid(app_client): - cookie = app_client.actor_cookie({"id": "test"}) +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_actor_cookie_invalid(ds_client): + cookie = ds_client.actor_cookie({"id": "test"}) # Break the signature - response = app_client.get("/", cookies={"ds_actor": cookie[:-1] + "."}) - assert None == app_client.ds._last_request.scope["actor"] + await ds_client.get("/", cookies={"ds_actor": cookie[:-1] + "."}) + assert ds_client.ds._last_request.scope["actor"] is None # Break the cookie format - cookie = app_client.ds.sign({"b": {"id": "test"}}, "actor") - response = app_client.get("/", cookies={"ds_actor": cookie}) - assert None == 
app_client.ds._last_request.scope["actor"] + cookie = ds_client.ds.sign({"b": {"id": "test"}}, "actor") + await ds_client.get("/", cookies={"ds_actor": cookie}) + assert ds_client.ds._last_request.scope["actor"] is None +@pytest.mark.ds_client +@pytest.mark.asyncio @pytest.mark.parametrize( "offset,expected", [ @@ -54,16 +56,17 @@ def test_actor_cookie_invalid(app_client): (-(24 * 60 * 60), None), ], ) -def test_actor_cookie_that_expires(app_client, offset, expected): +async def test_actor_cookie_that_expires(ds_client, offset, expected): expires_at = int(time.time()) + offset - cookie = app_client.ds.sign( + cookie = ds_client.ds.sign( {"a": {"id": "test"}, "e": baseconv.base62.encode(expires_at)}, "actor" ) - response = app_client.get("/", cookies={"ds_actor": cookie}) - assert expected == app_client.ds._last_request.scope["actor"] + response = await ds_client.get("/", cookies={"ds_actor": cookie}) + assert ds_client.ds._last_request.scope["actor"] == expected def test_logout(app_client): + # Keeping app_client for the moment because of csrftoken_from response = app_client.get( "/-/logout", cookies={"ds_actor": app_client.actor_cookie({"id": "test"})} ) @@ -88,18 +91,20 @@ def test_logout(app_client): cookies={"ds_actor": app_client.actor_cookie({"id": "test"})}, ) # The ds_actor cookie should have been unset - assert response4.cookie_was_deleted("ds_actor") + assert cookie_was_deleted(response4, "ds_actor") # Should also have set a message messages = app_client.ds.unsign(response4.cookies["ds_messages"], "messages") assert [["You are now logged out", 2]] == messages +@pytest.mark.ds_client +@pytest.mark.asyncio @pytest.mark.parametrize("path", ["/", "/fixtures", "/fixtures/facetable"]) -def test_logout_button_in_navigation(app_client, path): - response = app_client.get( - path, cookies={"ds_actor": app_client.actor_cookie({"id": "test"})} +async def test_logout_button_in_navigation(ds_client, path): + response = await ds_client.get( + path, cookies={"ds_actor": ds_client.actor_cookie({"id": "test"})} ) - anon_response = app_client.get(path) + anon_response = await ds_client.get(path) for fragment in ( "test", '
', @@ -108,9 +113,11 @@ def test_logout_button_in_navigation(app_client, path): assert fragment not in anon_response.text +@pytest.mark.ds_client +@pytest.mark.asyncio @pytest.mark.parametrize("path", ["/", "/fixtures", "/fixtures/facetable"]) -def test_no_logout_button_in_navigation_if_no_ds_actor_cookie(app_client, path): - response = app_client.get(path + "?_bot=1") +async def test_no_logout_button_in_navigation_if_no_ds_actor_cookie(ds_client, path): + response = await ds_client.get(path + "?_bot=1") assert "bot" in response.text assert '' not in response.text @@ -205,25 +212,33 @@ def test_auth_create_token( assert response3.json["actor"]["id"] == "test" -def test_auth_create_token_not_allowed_for_tokens(app_client): - ds_tok = app_client.ds.sign({"a": "test", "token": "dstok"}, "token") - response = app_client.get( +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_auth_create_token_not_allowed_for_tokens(ds_client): + ds_tok = ds_client.ds.sign({"a": "test", "token": "dstok"}, "token") + response = await ds_client.get( "/-/create-token", headers={"Authorization": "Bearer dstok_{}".format(ds_tok)}, ) - assert response.status == 403 + assert response.status_code == 403 -def test_auth_create_token_not_allowed_if_allow_signed_tokens_off(app_client): - app_client.ds._settings["allow_signed_tokens"] = False +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_auth_create_token_not_allowed_if_allow_signed_tokens_off(ds_client): + ds_client.ds._settings["allow_signed_tokens"] = False try: - ds_actor = app_client.actor_cookie({"id": "test"}) - response = app_client.get("/-/create-token", cookies={"ds_actor": ds_actor}) - assert response.status == 403 + ds_actor = ds_client.actor_cookie({"id": "test"}) + response = await ds_client.get( + "/-/create-token", cookies={"ds_actor": ds_actor} + ) + assert response.status_code == 403 finally: - app_client.ds._settings["allow_signed_tokens"] = True + ds_client.ds._settings["allow_signed_tokens"] = True +@pytest.mark.ds_client +@pytest.mark.asyncio @pytest.mark.parametrize( "scenario,should_work", ( @@ -236,31 +251,32 @@ def test_auth_create_token_not_allowed_if_allow_signed_tokens_off(app_client): ("valid_expiring_token", True), ), ) -def test_auth_with_dstok_token(app_client, scenario, should_work): +async def test_auth_with_dstok_token(ds_client, scenario, should_work): token = None _time = int(time.time()) if scenario in ("valid_unlimited_token", "allow_signed_tokens_off"): - token = app_client.ds.sign({"a": "test", "t": _time}, "token") + token = ds_client.ds.sign({"a": "test", "t": _time}, "token") elif scenario == "valid_expiring_token": - token = app_client.ds.sign({"a": "test", "t": _time - 50, "d": 1000}, "token") + token = ds_client.ds.sign({"a": "test", "t": _time - 50, "d": 1000}, "token") elif scenario == "expired_token": - token = app_client.ds.sign({"a": "test", "t": _time - 2000, "d": 1000}, "token") + token = ds_client.ds.sign({"a": "test", "t": _time - 2000, "d": 1000}, "token") elif scenario == "no_timestamp": - token = app_client.ds.sign({"a": "test"}, "token") + token = ds_client.ds.sign({"a": "test"}, "token") elif scenario == "invalid_token": token = "invalid" if token: token = "dstok_{}".format(token) if scenario == "allow_signed_tokens_off": - app_client.ds._settings["allow_signed_tokens"] = False + ds_client.ds._settings["allow_signed_tokens"] = False headers = {} if token: headers["Authorization"] = "Bearer {}".format(token) - response = app_client.get("/-/actor.json", headers=headers) + response = await 
ds_client.get("/-/actor.json", headers=headers) try: if should_work: - assert response.json.keys() == {"actor"} - actor = response.json["actor"] + data = response.json() + assert data.keys() == {"actor"} + actor = data["actor"] expected_keys = {"id", "token"} if scenario != "valid_unlimited_token": expected_keys.add("token_expires") @@ -270,9 +286,9 @@ def test_auth_with_dstok_token(app_client, scenario, should_work): if scenario != "valid_unlimited_token": assert isinstance(actor["token_expires"], int) else: - assert response.json == {"actor": None} + assert response.json() == {"actor": None} finally: - app_client.ds._settings["allow_signed_tokens"] = True + ds_client.ds._settings["allow_signed_tokens"] = True @pytest.mark.parametrize("expires", (None, 1000, -1000)) diff --git a/tests/test_messages.py b/tests/test_messages.py index 3af5439a..6fbff066 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -1,4 +1,5 @@ from .fixtures import app_client +from .utils import cookie_was_deleted import pytest @@ -25,4 +26,4 @@ def test_messages_are_displayed_and_cleared(app_client): # Messages should be in that HTML assert "xmessagex" in response.text # Cookie should have been set that clears messages - assert response.cookie_was_deleted("ds_messages") + assert cookie_was_deleted(response, "ds_messages") diff --git a/tests/utils.py b/tests/utils.py index 191ead9b..84d5b1df 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -30,3 +30,11 @@ def inner_html(soup): def has_load_extension(): conn = sqlite3.connect(":memory:") return hasattr(conn, "enable_load_extension") + + +def cookie_was_deleted(response, cookie): + return any( + h + for h in response.headers.get_list("set-cookie") + if h.startswith(f'{cookie}="";') + ) From 3001eec66a7ec2ba91f5c0acd7cf3c0b79ab2c00 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 15 Dec 2022 14:24:39 -0800 Subject: [PATCH 011/603] ds_client for test_csv.py and test_canned_queries.py, refs #1959 --- tests/test_canned_queries.py | 10 ++-- tests/test_csv.py | 91 +++++++++++++++++++++++------------- 2 files changed, 65 insertions(+), 36 deletions(-) diff --git a/tests/test_canned_queries.py b/tests/test_canned_queries.py index 976aa0db..9e85e0d7 100644 --- a/tests/test_canned_queries.py +++ b/tests/test_canned_queries.py @@ -73,16 +73,18 @@ def canned_write_immutable_client(): yield client -def test_canned_query_with_named_parameter(app_client): - response = app_client.get("/fixtures/neighborhood_search.json?text=town") - assert [ +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_canned_query_with_named_parameter(ds_client): + response = await ds_client.get("/fixtures/neighborhood_search.json?text=town") + assert response.json()["rows"] == [ ["Corktown", "Detroit", "MI"], ["Downtown", "Los Angeles", "CA"], ["Downtown", "Detroit", "MI"], ["Greektown", "Detroit", "MI"], ["Koreatown", "Los Angeles", "CA"], ["Mexicantown", "Detroit", "MI"], - ] == response.json["rows"] + ] def test_insert(canned_write_client): diff --git a/tests/test_csv.py b/tests/test_csv.py index 7fc25a09..63184126 100644 --- a/tests/test_csv.py +++ b/tests/test_csv.py @@ -1,4 +1,5 @@ from bs4 import BeautifulSoup as Soup +import pytest from .fixtures import ( # noqa app_client, app_client_csv_max_mb_one, @@ -53,9 +54,11 @@ pk,foreign_key_with_label,foreign_key_with_label_label,foreign_key_with_blank_la ) -def test_table_csv(app_client): - response = app_client.get("/fixtures/simple_primary_key.csv?_oh=1") - assert response.status == 200 +@pytest.mark.ds_client 
+@pytest.mark.asyncio +async def test_table_csv(ds_client): + response = await ds_client.get("/fixtures/simple_primary_key.csv?_oh=1") + assert response.status_code == 200 assert not response.headers.get("Access-Control-Allow-Origin") assert response.headers["content-type"] == "text/plain; charset=utf-8" assert response.text == EXPECTED_TABLE_CSV @@ -67,31 +70,39 @@ def test_table_csv_cors_headers(app_client_with_cors): assert response.headers["Access-Control-Allow-Origin"] == "*" -def test_table_csv_no_header(app_client): - response = app_client.get("/fixtures/simple_primary_key.csv?_header=off") - assert response.status == 200 +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_table_csv_no_header(ds_client): + response = await ds_client.get("/fixtures/simple_primary_key.csv?_header=off") + assert response.status_code == 200 assert not response.headers.get("Access-Control-Allow-Origin") assert response.headers["content-type"] == "text/plain; charset=utf-8" assert response.text == EXPECTED_TABLE_CSV.split("\r\n", 1)[1] -def test_table_csv_with_labels(app_client): - response = app_client.get("/fixtures/facetable.csv?_labels=1") - assert response.status == 200 +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_table_csv_with_labels(ds_client): + response = await ds_client.get("/fixtures/facetable.csv?_labels=1") + assert response.status_code == 200 assert response.headers["content-type"] == "text/plain; charset=utf-8" assert response.text == EXPECTED_TABLE_WITH_LABELS_CSV -def test_table_csv_with_nullable_labels(app_client): - response = app_client.get("/fixtures/foreign_key_references.csv?_labels=1") - assert response.status == 200 +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_table_csv_with_nullable_labels(ds_client): + response = await ds_client.get("/fixtures/foreign_key_references.csv?_labels=1") + assert response.status_code == 200 assert response.headers["content-type"] == "text/plain; charset=utf-8" assert response.text == EXPECTED_TABLE_WITH_NULLABLE_LABELS_CSV -def test_table_csv_blob_columns(app_client): - response = app_client.get("/fixtures/binary_data.csv") - assert response.status == 200 +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_table_csv_blob_columns(ds_client): + response = await ds_client.get("/fixtures/binary_data.csv") + assert response.status_code == 200 assert response.headers["content-type"] == "text/plain; charset=utf-8" assert response.text == ( "rowid,data\r\n" @@ -101,9 +112,13 @@ def test_table_csv_blob_columns(app_client): ) -def test_custom_sql_csv_blob_columns(app_client): - response = app_client.get("/fixtures.csv?sql=select+rowid,+data+from+binary_data") - assert response.status == 200 +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_custom_sql_csv_blob_columns(ds_client): + response = await ds_client.get( + "/fixtures.csv?sql=select+rowid,+data+from+binary_data" + ) + assert response.status_code == 200 assert response.headers["content-type"] == "text/plain; charset=utf-8" assert response.text == ( "rowid,data\r\n" @@ -113,18 +128,22 @@ def test_custom_sql_csv_blob_columns(app_client): ) -def test_custom_sql_csv(app_client): - response = app_client.get( +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_custom_sql_csv(ds_client): + response = await ds_client.get( "/fixtures.csv?sql=select+content+from+simple_primary_key+limit+2" ) - assert response.status == 200 + assert response.status_code == 200 assert response.headers["content-type"] == "text/plain; charset=utf-8" assert 
response.text == EXPECTED_CUSTOM_CSV -def test_table_csv_download(app_client): - response = app_client.get("/fixtures/simple_primary_key.csv?_dl=1") - assert response.status == 200 +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_table_csv_download(ds_client): + response = await ds_client.get("/fixtures/simple_primary_key.csv?_dl=1") + assert response.status_code == 200 assert response.headers["content-type"] == "text/csv; charset=utf-8" assert ( response.headers["content-disposition"] @@ -132,11 +151,13 @@ def test_table_csv_download(app_client): ) -def test_csv_with_non_ascii_characters(app_client): - response = app_client.get( +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_csv_with_non_ascii_characters(ds_client): + response = await ds_client.get( "/fixtures.csv?sql=select%0D%0A++%27%F0%9D%90%9C%F0%9D%90%A2%F0%9D%90%AD%F0%9D%90%A2%F0%9D%90%9E%F0%9D%90%AC%27+as+text%2C%0D%0A++1+as+number%0D%0Aunion%0D%0Aselect%0D%0A++%27bob%27+as+text%2C%0D%0A++2+as+number%0D%0Aorder+by%0D%0A++number" ) - assert response.status == 200 + assert response.status_code == 200 assert response.headers["content-type"] == "text/plain; charset=utf-8" assert response.text == "text,number\r\n𝐜𝐢𝐭𝐢𝐞𝐬,1\r\nbob,2\r\n" @@ -155,13 +176,19 @@ def test_max_csv_mb(app_client_csv_max_mb_one): assert last_line.startswith(b"CSV contains more than") -def test_table_csv_stream(app_client): +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_table_csv_stream(ds_client): # Without _stream should return header + 100 rows: - response = app_client.get("/fixtures/compound_three_primary_keys.csv?_size=max") - assert len([b for b in response.body.split(b"\r\n") if b]) == 101 + response = await ds_client.get( + "/fixtures/compound_three_primary_keys.csv?_size=max" + ) + assert len([b for b in response.content.split(b"\r\n") if b]) == 101 # With _stream=1 should return header + 1001 rows - response = app_client.get("/fixtures/compound_three_primary_keys.csv?_stream=1") - assert len([b for b in response.body.split(b"\r\n") if b]) == 1002 + response = await ds_client.get( + "/fixtures/compound_three_primary_keys.csv?_stream=1" + ) + assert len([b for b in response.content.split(b"\r\n") if b]) == 1002 def test_csv_trace(app_client_with_trace): From 95900b9d02c01de46dc510693ab3b316988bf64c Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 15 Dec 2022 14:44:30 -0800 Subject: [PATCH 012/603] Port app_client to ds_client for most of test_html.py, refs #1959 --- datasette/utils/testing.py | 9 + tests/test_html.py | 405 ++++++++++++++++++++++--------------- 2 files changed, 252 insertions(+), 162 deletions(-) diff --git a/datasette/utils/testing.py b/datasette/utils/testing.py index cabe2e5c..d4990784 100644 --- a/datasette/utils/testing.py +++ b/datasette/utils/testing.py @@ -16,6 +16,11 @@ class TestResponse: def status(self): return self.httpx_response.status_code + # Supports both for test-writing convenience + @property + def status_code(self): + return self.status + @property def headers(self): return self.httpx_response.headers @@ -24,6 +29,10 @@ class TestResponse: def body(self): return self.httpx_response.content + @property + def content(self): + return self.body + @property def cookies(self): return dict(self.httpx_response.cookies) diff --git a/tests/test_html.py b/tests/test_html.py index 83086cd4..3d4f2ed7 100644 --- a/tests/test_html.py +++ b/tests/test_html.py @@ -18,9 +18,9 @@ import urllib.parse def test_homepage(app_client_two_attached_databases): response = 
app_client_two_attached_databases.get("/") - assert response.status == 200 + assert response.status_code == 200 assert "text/html; charset=utf-8" == response.headers["content-type"] - soup = Soup(response.body, "html.parser") + soup = Soup(response.content, "html.parser") assert "Datasette Fixtures" == soup.find("h1").text assert ( "An example SQLite database demonstrating Datasette. Sign in as root user" @@ -48,30 +48,38 @@ def test_homepage(app_client_two_attached_databases): ] == table_links -def test_http_head(app_client): - response = app_client.get("/", method="HEAD") - assert response.status == 200 +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_http_head(ds_client): + response = await ds_client.head("/") + assert response.status_code == 200 -def test_homepage_options(app_client): - response = app_client.get("/", method="OPTIONS") - assert response.status == 200 +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_homepage_options(ds_client): + response = await ds_client.options("/") + assert response.status_code == 200 assert response.text == "ok" -def test_favicon(app_client): - response = app_client.get("/favicon.ico") - assert response.status == 200 +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_favicon(ds_client): + response = await ds_client.get("/favicon.ico") + assert response.status_code == 200 assert response.headers["cache-control"] == "max-age=3600, immutable, public" assert int(response.headers["content-length"]) > 100 assert response.headers["content-type"] == "image/png" -def test_static(app_client): - response = app_client.get("/-/static/app2.css") - assert response.status == 404 - response = app_client.get("/-/static/app.css") - assert response.status == 200 +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_static(ds_client): + response = await ds_client.get("/-/static/app2.css") + assert response.status_code == 404 + response = await ds_client.get("/-/static/app.css") + assert response.status_code == 200 assert "text/css" == response.headers["content-type"] @@ -80,29 +88,31 @@ def test_static_mounts(): static_mounts=[("custom-static", str(pathlib.Path(__file__).parent))] ) as client: response = client.get("/custom-static/test_html.py") - assert response.status == 200 + assert response.status_code == 200 response = client.get("/custom-static/not_exists.py") - assert response.status == 404 + assert response.status_code == 404 response = client.get("/custom-static/../LICENSE") - assert response.status == 404 + assert response.status_code == 404 def test_memory_database_page(): with make_app_client(memory=True) as client: response = client.get("/_memory") - assert response.status == 200 + assert response.status_code == 200 def test_not_allowed_methods(): with make_app_client(memory=True) as client: for method in ("post", "put", "patch", "delete"): response = client.request(path="/_memory", method=method.upper()) - assert response.status == 405 + assert response.status_code == 405 -def test_database_page(app_client): - response = app_client.get("/fixtures") - soup = Soup(response.body, "html.parser") +@pytest.mark.ds_client +@pytest.mark.asyncio +async def test_database_page(ds_client): + response = await ds_client.get("/fixtures") + soup = Soup(response.text, "html.parser") # Should have a " - - async def stream_fn(r): - nonlocal data, trace - limited_writer = LimitedWriter(r, self.ds.setting("max_csv_mb")) - if trace: - await limited_writer.write(preamble) - writer = csv.writer(EscapeHtmlWriter(limited_writer)) - else: 
- writer = csv.writer(limited_writer) - first = True - next = None - while first or (next and stream): - try: - kwargs = {} - if next: - kwargs["_next"] = next - if not first: - data, _, _ = await self.data(request, **kwargs) - if first: - if request.args.get("_header") != "off": - await writer.writerow(headings) - first = False - next = data.get("next") - for row in data["rows"]: - if any(isinstance(r, bytes) for r in row): - new_row = [] - for column, cell in zip(headings, row): - if isinstance(cell, bytes): - # If this is a table page, use .urls.row_blob() - if data.get("table"): - pks = data.get("primary_keys") or [] - cell = self.ds.absolute_url( - request, - self.ds.urls.row_blob( - database, - data["table"], - path_from_row_pks(row, pks, not pks), - column, - ), - ) - else: - # Otherwise generate URL for this query - url = self.ds.absolute_url( - request, - path_with_format( - request=request, - format="blob", - extra_qs={ - "_blob_column": column, - "_blob_hash": hashlib.sha256( - cell - ).hexdigest(), - }, - replace_format="csv", - ), - ) - cell = url.replace("&_nocount=1", "").replace( - "&_nofacet=1", "" - ) - new_row.append(cell) - row = new_row - if not expanded_columns: - # Simple path - await writer.writerow(row) - else: - # Look for {"value": "label": } dicts and expand - new_row = [] - for heading, cell in zip(data["columns"], row): - if heading in expanded_columns: - if cell is None: - new_row.extend(("", "")) - else: - assert isinstance(cell, dict) - new_row.append(cell["value"]) - new_row.append(cell["label"]) - else: - new_row.append(cell) - await writer.writerow(new_row) - except Exception as e: - sys.stderr.write("Caught this error: {}\n".format(e)) - sys.stderr.flush() - await r.write(str(e)) - return - await limited_writer.write(postamble) - - headers = {} - if self.ds.cors: - add_cors_headers(headers) - if request.args.get("_dl", None): - if not trace: - content_type = "text/csv; charset=utf-8" - disposition = 'attachment; filename="{}.csv"'.format( - request.url_vars.get("table", database) - ) - headers["content-disposition"] = disposition - - return AsgiStream(stream_fn, headers=headers, content_type=content_type) + return await stream_csv(self.ds, self.data, request, database) async def get(self, request): db = await self.ds.resolve_database(request) @@ -518,7 +350,7 @@ class DataView(BaseView): }, } if "metadata" not in context: - context["metadata"] = self.ds.metadata + context["metadata"] = self.ds.metadata() r = await self.render(templates, request=request, context=context) if status_code is not None: r.status = status_code @@ -546,3 +378,169 @@ class DataView(BaseView): def _error(messages, status=400): return Response.json({"ok": False, "errors": messages}, status=status) + + +async def stream_csv(datasette, fetch_data, request, database): + kwargs = {} + stream = request.args.get("_stream") + # Do not calculate facets or counts: + extra_parameters = [ + "{}=1".format(key) + for key in ("_nofacet", "_nocount") + if not request.args.get(key) + ] + if extra_parameters: + # Replace request object with a new one with modified scope + if not request.query_string: + new_query_string = "&".join(extra_parameters) + else: + new_query_string = request.query_string + "&" + "&".join(extra_parameters) + new_scope = dict(request.scope, query_string=new_query_string.encode("latin-1")) + receive = request.receive + request = Request(new_scope, receive) + if stream: + # Some quick soundness checks + if not datasette.setting("allow_csv_stream"): + raise BadRequest("CSV 
streaming is disabled") + if request.args.get("_next"): + raise BadRequest("_next not allowed for CSV streaming") + kwargs["_size"] = "max" + # Fetch the first page + try: + response_or_template_contexts = await fetch_data(request) + if isinstance(response_or_template_contexts, Response): + return response_or_template_contexts + elif len(response_or_template_contexts) == 4: + data, _, _, _ = response_or_template_contexts + else: + data, _, _ = response_or_template_contexts + except (sqlite3.OperationalError, InvalidSql) as e: + raise DatasetteError(str(e), title="Invalid SQL", status=400) + + except sqlite3.OperationalError as e: + raise DatasetteError(str(e)) + + except DatasetteError: + raise + + # Convert rows and columns to CSV + headings = data["columns"] + # if there are expanded_columns we need to add additional headings + expanded_columns = set(data.get("expanded_columns") or []) + if expanded_columns: + headings = [] + for column in data["columns"]: + headings.append(column) + if column in expanded_columns: + headings.append(f"{column}_label") + + content_type = "text/plain; charset=utf-8" + preamble = "" + postamble = "" + + trace = request.args.get("_trace") + if trace: + content_type = "text/html; charset=utf-8" + preamble = ( + "CSV debug" + '" + + async def stream_fn(r): + nonlocal data, trace + print("max_csv_mb", datasette.setting("max_csv_mb")) + limited_writer = LimitedWriter(r, datasette.setting("max_csv_mb")) + if trace: + await limited_writer.write(preamble) + writer = csv.writer(EscapeHtmlWriter(limited_writer)) + else: + writer = csv.writer(limited_writer) + first = True + next = None + while first or (next and stream): + try: + kwargs = {} + if next: + kwargs["_next"] = next + if not first: + data, _, _ = await fetch_data(request, **kwargs) + if first: + if request.args.get("_header") != "off": + await writer.writerow(headings) + first = False + next = data.get("next") + for row in data["rows"]: + if any(isinstance(r, bytes) for r in row): + new_row = [] + for column, cell in zip(headings, row): + if isinstance(cell, bytes): + # If this is a table page, use .urls.row_blob() + if data.get("table"): + pks = data.get("primary_keys") or [] + cell = datasette.absolute_url( + request, + datasette.urls.row_blob( + database, + data["table"], + path_from_row_pks(row, pks, not pks), + column, + ), + ) + else: + # Otherwise generate URL for this query + url = datasette.absolute_url( + request, + path_with_format( + request=request, + format="blob", + extra_qs={ + "_blob_column": column, + "_blob_hash": hashlib.sha256( + cell + ).hexdigest(), + }, + replace_format="csv", + ), + ) + cell = url.replace("&_nocount=1", "").replace( + "&_nofacet=1", "" + ) + new_row.append(cell) + row = new_row + if not expanded_columns: + # Simple path + await writer.writerow(row) + else: + # Look for {"value": "label": } dicts and expand + new_row = [] + for heading, cell in zip(data["columns"], row): + if heading in expanded_columns: + if cell is None: + new_row.extend(("", "")) + else: + assert isinstance(cell, dict) + new_row.append(cell["value"]) + new_row.append(cell["label"]) + else: + new_row.append(cell) + await writer.writerow(new_row) + except Exception as e: + sys.stderr.write("Caught this error: {}\n".format(e)) + sys.stderr.flush() + await r.write(str(e)) + return + await limited_writer.write(postamble) + + headers = {} + if datasette.cors: + add_cors_headers(headers) + if request.args.get("_dl", None): + if not trace: + content_type = "text/csv; charset=utf-8" + disposition = 
'attachment; filename="{}.csv"'.format( + request.url_vars.get("table", database) + ) + headers["content-disposition"] = disposition + + return AsgiStream(stream_fn, headers=headers, content_type=content_type) diff --git a/datasette/views/database.py b/datasette/views/database.py index 8d289105..dda82510 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -223,6 +223,7 @@ class QueryView(DataView): _size=None, named_parameters=None, write=False, + default_labels=None, ): db = await self.ds.resolve_database(request) database = db.name diff --git a/datasette/views/table.py b/datasette/views/table.py index 49f6052a..0a6203f2 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -1,19 +1,23 @@ import asyncio import itertools import json +import urllib +from asyncinject import Registry import markupsafe from datasette.plugins import pm from datasette.database import QueryInterrupted from datasette import tracer +from datasette.renderer import json_renderer from datasette.utils import ( + add_cors_headers, await_me_maybe, + call_with_supported_arguments, CustomRow, append_querystring, compound_keys_after_sql, format_bytes, - tilde_decode, tilde_encode, escape_sqlite, filters_should_redirect, @@ -21,17 +25,20 @@ from datasette.utils import ( is_url, path_from_row_pks, path_with_added_args, + path_with_format, path_with_removed_args, path_with_replaced_args, to_css_class, truncate_url, urlsafe_components, value_as_boolean, + InvalidSql, + sqlite3, ) from datasette.utils.asgi import BadRequest, Forbidden, NotFound, Response from datasette.filters import Filters import sqlite_utils -from .base import BaseView, DataView, DatasetteError, ureg, _error +from .base import BaseView, DataView, DatasetteError, ureg, _error, stream_csv from .database import QueryView LINK_WITH_LABEL = ( @@ -69,812 +76,56 @@ class Row: return json.dumps(d, default=repr, indent=2) -class TableView(DataView): - name = "table" +async def _gather_parallel(*args): + return await asyncio.gather(*args) - async def sortable_columns_for_table(self, database_name, table_name, use_rowid): - db = self.ds.databases[database_name] - table_metadata = self.ds.table_metadata(database_name, table_name) - if "sortable_columns" in table_metadata: - sortable_columns = set(table_metadata["sortable_columns"]) - else: - sortable_columns = set(await db.table_columns(table_name)) - if use_rowid: - sortable_columns.add("rowid") - return sortable_columns - async def expandable_columns(self, database_name, table_name): - # Returns list of (fk_dict, label_column-or-None) pairs for that table - expandables = [] - db = self.ds.databases[database_name] - for fk in await db.foreign_keys_for_table(table_name): - label_column = await db.label_column_for_table(fk["other_table"]) - expandables.append((fk, label_column)) - return expandables +async def _gather_sequential(*args): + results = [] + for fn in args: + results.append(await fn) + return results - async def post(self, request): - from datasette.app import TableNotFound - try: - resolved = await self.ds.resolve_table(request) - except TableNotFound as e: - # Was this actually a canned query? - canned_query = await self.ds.get_canned_query( - e.database_name, e.table, request.actor - ) - if canned_query: - # Handle POST to a canned query - return await QueryView(self.ds).data( +def _redirect(datasette, request, path, forward_querystring=True, remove_args=None): + if request.query_string and "?" 
not in path and forward_querystring: + path = f"{path}?{request.query_string}" + if remove_args: + path = path_with_removed_args(request, remove_args, path=path) + r = Response.redirect(path) + r.headers["Link"] = f"<{path}>; rel=preload" + if datasette.cors: + add_cors_headers(r.headers) + return r + + +async def _redirect_if_needed(datasette, request, resolved): + # Handle ?_filter_column + redirect_params = filters_should_redirect(request.args) + if redirect_params: + return _redirect( + datasette, + request, + datasette.urls.path(path_with_added_args(request, redirect_params)), + forward_querystring=False, + ) + + # If ?_sort_by_desc=on (from checkbox) redirect to _sort_desc=(_sort) + if "_sort_by_desc" in request.args: + return _redirect( + datasette, + request, + datasette.urls.path( + path_with_added_args( request, - canned_query["sql"], - metadata=canned_query, - editable=False, - canned_query=e.table, - named_parameters=canned_query.get("params"), - write=bool(canned_query.get("write")), + { + "_sort_desc": request.args.get("_sort"), + "_sort_by_desc": None, + "_sort": None, + }, ) - - # Handle POST to a table - return await self.table_post( - request, resolved.db, resolved.db.name, resolved.table - ) - - async def table_post(self, request, db, database_name, table_name): - # Must have insert-row permission - if not await self.ds.permission_allowed( - request.actor, "insert-row", resource=(database_name, table_name) - ): - raise Forbidden("Permission denied") - if request.headers.get("content-type") != "application/json": - # TODO: handle form-encoded data - raise BadRequest("Must send JSON data") - data = json.loads(await request.post_body()) - if "insert" not in data: - raise BadRequest('Must send a "insert" key containing a dictionary') - row = data["insert"] - if not isinstance(row, dict): - raise BadRequest("insert must be a dictionary") - # Verify all columns exist - columns = await db.table_columns(table_name) - pks = await db.primary_keys(table_name) - for key in row: - if key not in columns: - raise BadRequest("Column not found: {}".format(key)) - if key in pks: - raise BadRequest( - "Cannot insert into primary key column: {}".format(key) - ) - # Perform the insert - sql = "INSERT INTO [{table}] ({columns}) VALUES ({values})".format( - table=escape_sqlite(table_name), - columns=", ".join(escape_sqlite(c) for c in row), - values=", ".join("?" 
for c in row), - ) - cursor = await db.execute_write(sql, list(row.values())) - # Return the new row - rowid = cursor.lastrowid - new_row = ( - await db.execute( - "SELECT * FROM [{table}] WHERE rowid = ?".format( - table=escape_sqlite(table_name) - ), - [rowid], - ) - ).first() - return Response.json( - { - "inserted_row": dict(new_row), - }, - status=201, - ) - - async def columns_to_select(self, table_columns, pks, request): - columns = list(table_columns) - if "_col" in request.args: - columns = list(pks) - _cols = request.args.getlist("_col") - bad_columns = [column for column in _cols if column not in table_columns] - if bad_columns: - raise DatasetteError( - "_col={} - invalid columns".format(", ".join(bad_columns)), - status=400, - ) - # De-duplicate maintaining order: - columns.extend(dict.fromkeys(_cols)) - if "_nocol" in request.args: - # Return all columns EXCEPT these - bad_columns = [ - column - for column in request.args.getlist("_nocol") - if (column not in table_columns) or (column in pks) - ] - if bad_columns: - raise DatasetteError( - "_nocol={} - invalid columns".format(", ".join(bad_columns)), - status=400, - ) - tmp_columns = [ - column - for column in columns - if column not in request.args.getlist("_nocol") - ] - columns = tmp_columns - return columns - - async def data( - self, - request, - default_labels=False, - _next=None, - _size=None, - ): - with tracer.trace_child_tasks(): - return await self._data_traced(request, default_labels, _next, _size) - - async def _data_traced( - self, - request, - default_labels=False, - _next=None, - _size=None, - ): - from datasette.app import TableNotFound - - try: - resolved = await self.ds.resolve_table(request) - except TableNotFound as e: - # Was this actually a canned query? - canned_query = await self.ds.get_canned_query( - e.database_name, e.table, request.actor - ) - # If this is a canned query, not a table, then dispatch to QueryView instead - if canned_query: - return await QueryView(self.ds).data( - request, - canned_query["sql"], - metadata=canned_query, - editable=False, - canned_query=e.table, - named_parameters=canned_query.get("params"), - write=bool(canned_query.get("write")), - ) - else: - raise - - table_name = resolved.table - db = resolved.db - database_name = db.name - - # For performance profiling purposes, ?_noparallel=1 turns off asyncio.gather - async def _gather_parallel(*args): - return await asyncio.gather(*args) - - async def _gather_sequential(*args): - results = [] - for fn in args: - results.append(await fn) - return results - - gather = ( - _gather_sequential if request.args.get("_noparallel") else _gather_parallel - ) - - is_view, table_exists = map( - bool, - await gather( - db.get_view_definition(table_name), db.table_exists(table_name) - ), - ) - - # If table or view not found, return 404 - if not is_view and not table_exists: - raise NotFound(f"Table not found: {table_name}") - - # Ensure user has permission to view this table - visible, private = await self.ds.check_visibility( - request.actor, - permissions=[ - ("view-table", (database_name, table_name)), - ("view-database", database_name), - "view-instance", - ], - ) - if not visible: - raise Forbidden("You do not have permission to view this table") - - # Handle ?_filter_column and redirect, if present - redirect_params = filters_should_redirect(request.args) - if redirect_params: - return self.redirect( - request, - self.ds.urls.path(path_with_added_args(request, redirect_params)), - forward_querystring=False, - ) - - # If 
?_sort_by_desc=on (from checkbox) redirect to _sort_desc=(_sort) - if "_sort_by_desc" in request.args: - return self.redirect( - request, - self.ds.urls.path( - path_with_added_args( - request, - { - "_sort_desc": request.args.get("_sort"), - "_sort_by_desc": None, - "_sort": None, - }, - ) - ), - forward_querystring=False, - ) - - # Introspect columns and primary keys for table - pks = await db.primary_keys(table_name) - table_columns = await db.table_columns(table_name) - - # Take ?_col= and ?_nocol= into account - specified_columns = await self.columns_to_select(table_columns, pks, request) - select_specified_columns = ", ".join( - escape_sqlite(t) for t in specified_columns - ) - select_all_columns = ", ".join(escape_sqlite(t) for t in table_columns) - - # rowid tables (no specified primary key) need a different SELECT - use_rowid = not pks and not is_view - if use_rowid: - select_specified_columns = f"rowid, {select_specified_columns}" - select_all_columns = f"rowid, {select_all_columns}" - order_by = "rowid" - order_by_pks = "rowid" - else: - order_by_pks = ", ".join([escape_sqlite(pk) for pk in pks]) - order_by = order_by_pks - - if is_view: - order_by = "" - - nocount = request.args.get("_nocount") - nofacet = request.args.get("_nofacet") - nosuggest = request.args.get("_nosuggest") - - if request.args.get("_shape") in ("array", "object"): - nocount = True - nofacet = True - - table_metadata = self.ds.table_metadata(database_name, table_name) - units = table_metadata.get("units", {}) - - # Arguments that start with _ and don't contain a __ are - # special - things like ?_search= - and should not be - # treated as filters. - filter_args = [] - for key in request.args: - if not (key.startswith("_") and "__" not in key): - for v in request.args.getlist(key): - filter_args.append((key, v)) - - # Build where clauses from query string arguments - filters = Filters(sorted(filter_args), units, ureg) - where_clauses, params = filters.build_where_clauses(table_name) - - # Execute filters_from_request plugin hooks - including the default - # ones that live in datasette/filters.py - extra_context_from_filters = {} - extra_human_descriptions = [] - - for hook in pm.hook.filters_from_request( - request=request, - table=table_name, - database=database_name, - datasette=self.ds, - ): - filter_arguments = await await_me_maybe(hook) - if filter_arguments: - where_clauses.extend(filter_arguments.where_clauses) - params.update(filter_arguments.params) - extra_human_descriptions.extend(filter_arguments.human_descriptions) - extra_context_from_filters.update(filter_arguments.extra_context) - - # Deal with custom sort orders - sortable_columns = await self.sortable_columns_for_table( - database_name, table_name, use_rowid - ) - sort = request.args.get("_sort") - sort_desc = request.args.get("_sort_desc") - - if not sort and not sort_desc: - sort = table_metadata.get("sort") - sort_desc = table_metadata.get("sort_desc") - - if sort and sort_desc: - raise DatasetteError( - "Cannot use _sort and _sort_desc at the same time", status=400 - ) - - if sort: - if sort not in sortable_columns: - raise DatasetteError(f"Cannot sort table by {sort}", status=400) - - order_by = escape_sqlite(sort) - - if sort_desc: - if sort_desc not in sortable_columns: - raise DatasetteError(f"Cannot sort table by {sort_desc}", status=400) - - order_by = f"{escape_sqlite(sort_desc)} desc" - - from_sql = "from {table_name} {where}".format( - table_name=escape_sqlite(table_name), - where=("where {} ".format(" and 
".join(where_clauses))) - if where_clauses - else "", - ) - # Copy of params so we can mutate them later: - from_sql_params = dict(**params) - - count_sql = f"select count(*) {from_sql}" - - # Handle pagination driven by ?_next= - _next = _next or request.args.get("_next") - offset = "" - if _next: - sort_value = None - if is_view: - # _next is an offset - offset = f" offset {int(_next)}" - else: - components = urlsafe_components(_next) - # If a sort order is applied and there are multiple components, - # the first of these is the sort value - if (sort or sort_desc) and (len(components) > 1): - sort_value = components[0] - # Special case for if non-urlencoded first token was $null - if _next.split(",")[0] == "$null": - sort_value = None - components = components[1:] - - # Figure out the SQL for next-based-on-primary-key first - next_by_pk_clauses = [] - if use_rowid: - next_by_pk_clauses.append(f"rowid > :p{len(params)}") - params[f"p{len(params)}"] = components[0] - else: - # Apply the tie-breaker based on primary keys - if len(components) == len(pks): - param_len = len(params) - next_by_pk_clauses.append( - compound_keys_after_sql(pks, param_len) - ) - for i, pk_value in enumerate(components): - params[f"p{param_len + i}"] = pk_value - - # Now add the sort SQL, which may incorporate next_by_pk_clauses - if sort or sort_desc: - if sort_value is None: - if sort_desc: - # Just items where column is null ordered by pk - where_clauses.append( - "({column} is null and {next_clauses})".format( - column=escape_sqlite(sort_desc), - next_clauses=" and ".join(next_by_pk_clauses), - ) - ) - else: - where_clauses.append( - "({column} is not null or ({column} is null and {next_clauses}))".format( - column=escape_sqlite(sort), - next_clauses=" and ".join(next_by_pk_clauses), - ) - ) - else: - where_clauses.append( - "({column} {op} :p{p}{extra_desc_only} or ({column} = :p{p} and {next_clauses}))".format( - column=escape_sqlite(sort or sort_desc), - op=">" if sort else "<", - p=len(params), - extra_desc_only="" - if sort - else " or {column2} is null".format( - column2=escape_sqlite(sort or sort_desc) - ), - next_clauses=" and ".join(next_by_pk_clauses), - ) - ) - params[f"p{len(params)}"] = sort_value - order_by = f"{order_by}, {order_by_pks}" - else: - where_clauses.extend(next_by_pk_clauses) - - where_clause = "" - if where_clauses: - where_clause = f"where {' and '.join(where_clauses)} " - - if order_by: - order_by = f"order by {order_by}" - - extra_args = {} - # Handle ?_size=500 - page_size = _size or request.args.get("_size") or table_metadata.get("size") - if page_size: - if page_size == "max": - page_size = self.ds.max_returned_rows - try: - page_size = int(page_size) - if page_size < 0: - raise ValueError - - except ValueError: - raise BadRequest("_size must be a positive integer") - - if page_size > self.ds.max_returned_rows: - raise BadRequest(f"_size must be <= {self.ds.max_returned_rows}") - - extra_args["page_size"] = page_size - else: - page_size = self.ds.page_size - - # Facets are calculated against SQL without order by or limit - sql_no_order_no_limit = ( - "select {select_all_columns} from {table_name} {where}".format( - select_all_columns=select_all_columns, - table_name=escape_sqlite(table_name), - where=where_clause, - ) - ) - - # This is the SQL that populates the main table on the page - sql = "select {select_specified_columns} from {table_name} {where}{order_by} limit {page_size}{offset}".format( - select_specified_columns=select_specified_columns, - 
table_name=escape_sqlite(table_name), - where=where_clause, - order_by=order_by, - page_size=page_size + 1, - offset=offset, - ) - - if request.args.get("_timelimit"): - extra_args["custom_time_limit"] = int(request.args.get("_timelimit")) - - # Execute the main query! - results = await db.execute(sql, params, truncate=True, **extra_args) - - # Calculate the total count for this query - count = None - if ( - not db.is_mutable - and self.ds.inspect_data - and count_sql == f"select count(*) from {table_name} " - ): - # We can use a previously cached table row count - try: - count = self.ds.inspect_data[database_name]["tables"][table_name][ - "count" - ] - except KeyError: - pass - - # Otherwise run a select count(*) ... - if count_sql and count is None and not nocount: - try: - count_rows = list(await db.execute(count_sql, from_sql_params)) - count = count_rows[0][0] - except QueryInterrupted: - pass - - # Faceting - if not self.ds.setting("allow_facet") and any( - arg.startswith("_facet") for arg in request.args - ): - raise BadRequest("_facet= is not allowed") - - # pylint: disable=no-member - facet_classes = list( - itertools.chain.from_iterable(pm.hook.register_facet_classes()) - ) - facet_results = {} - facets_timed_out = [] - facet_instances = [] - for klass in facet_classes: - facet_instances.append( - klass( - self.ds, - request, - database_name, - sql=sql_no_order_no_limit, - params=params, - table=table_name, - metadata=table_metadata, - row_count=count, - ) - ) - - async def execute_facets(): - if not nofacet: - # Run them in parallel - facet_awaitables = [facet.facet_results() for facet in facet_instances] - facet_awaitable_results = await gather(*facet_awaitables) - for ( - instance_facet_results, - instance_facets_timed_out, - ) in facet_awaitable_results: - for facet_info in instance_facet_results: - base_key = facet_info["name"] - key = base_key - i = 1 - while key in facet_results: - i += 1 - key = f"{base_key}_{i}" - facet_results[key] = facet_info - facets_timed_out.extend(instance_facets_timed_out) - - suggested_facets = [] - - async def execute_suggested_facets(): - # Calculate suggested facets - if ( - self.ds.setting("suggest_facets") - and self.ds.setting("allow_facet") - and not _next - and not nofacet - and not nosuggest - ): - # Run them in parallel - facet_suggest_awaitables = [ - facet.suggest() for facet in facet_instances - ] - for suggest_result in await gather(*facet_suggest_awaitables): - suggested_facets.extend(suggest_result) - - await gather(execute_facets(), execute_suggested_facets()) - - # Figure out columns and rows for the query - columns = [r[0] for r in results.description] - rows = list(results.rows) - - # Expand labeled columns if requested - expanded_columns = [] - expandable_columns = await self.expandable_columns(database_name, table_name) - columns_to_expand = None - try: - all_labels = value_as_boolean(request.args.get("_labels", "")) - except ValueError: - all_labels = default_labels - # Check for explicit _label= - if "_label" in request.args: - columns_to_expand = request.args.getlist("_label") - if columns_to_expand is None and all_labels: - # expand all columns with foreign keys - columns_to_expand = [fk["column"] for fk, _ in expandable_columns] - - if columns_to_expand: - expanded_labels = {} - for fk, _ in expandable_columns: - column = fk["column"] - if column not in columns_to_expand: - continue - if column not in columns: - continue - expanded_columns.append(column) - # Gather the values - column_index = columns.index(column) - 
values = [row[column_index] for row in rows] - # Expand them - expanded_labels.update( - await self.ds.expand_foreign_keys( - database_name, table_name, column, values - ) - ) - if expanded_labels: - # Rewrite the rows - new_rows = [] - for row in rows: - new_row = CustomRow(columns) - for column in row.keys(): - value = row[column] - if (column, value) in expanded_labels and value is not None: - new_row[column] = { - "value": value, - "label": expanded_labels[(column, value)], - } - else: - new_row[column] = value - new_rows.append(new_row) - rows = new_rows - - # Pagination next link - next_value = None - next_url = None - if 0 < page_size < len(rows): - if is_view: - next_value = int(_next or 0) + page_size - else: - next_value = path_from_row_pks(rows[-2], pks, use_rowid) - # If there's a sort or sort_desc, add that value as a prefix - if (sort or sort_desc) and not is_view: - try: - prefix = rows[-2][sort or sort_desc] - except IndexError: - # sort/sort_desc column missing from SELECT - look up value by PK instead - prefix_where_clause = " and ".join( - "[{}] = :pk{}".format(pk, i) for i, pk in enumerate(pks) - ) - prefix_lookup_sql = "select [{}] from [{}] where {}".format( - sort or sort_desc, table_name, prefix_where_clause - ) - prefix = ( - await db.execute( - prefix_lookup_sql, - { - **{ - "pk{}".format(i): rows[-2][pk] - for i, pk in enumerate(pks) - } - }, - ) - ).single_value() - if isinstance(prefix, dict) and "value" in prefix: - prefix = prefix["value"] - if prefix is None: - prefix = "$null" - else: - prefix = tilde_encode(str(prefix)) - next_value = f"{prefix},{next_value}" - added_args = {"_next": next_value} - if sort: - added_args["_sort"] = sort - else: - added_args["_sort_desc"] = sort_desc - else: - added_args = {"_next": next_value} - next_url = self.ds.absolute_url( - request, self.ds.urls.path(path_with_replaced_args(request, added_args)) - ) - rows = rows[:page_size] - - # human_description_en combines filters AND search, if provided - human_description_en = filters.human_description_en( - extra=extra_human_descriptions - ) - - if sort or sort_desc: - sorted_by = "sorted by {}{}".format( - (sort or sort_desc), " descending" if sort_desc else "" - ) - human_description_en = " ".join( - [b for b in [human_description_en, sorted_by] if b] - ) - - async def extra_template(): - nonlocal sort - - display_columns, display_rows = await display_columns_and_rows( - self.ds, - database_name, - table_name, - results.description, - rows, - link_column=not is_view, - truncate_cells=self.ds.setting("truncate_cells_html"), - sortable_columns=await self.sortable_columns_for_table( - database_name, table_name, use_rowid=True - ), - request=request, - ) - metadata = ( - (self.ds.metadata("databases") or {}) - .get(database_name, {}) - .get("tables", {}) - .get(table_name, {}) - ) - self.ds.update_with_inherited_metadata(metadata) - - form_hidden_args = [] - for key in request.args: - if ( - key.startswith("_") - and key not in ("_sort", "_sort_desc", "_search", "_next") - and "__" not in key - ): - for value in request.args.getlist(key): - form_hidden_args.append((key, value)) - - # if no sort specified AND table has a single primary key, - # set sort to that so arrow is displayed - if not sort and not sort_desc: - if 1 == len(pks): - sort = pks[0] - elif use_rowid: - sort = "rowid" - - async def table_actions(): - links = [] - for hook in pm.hook.table_actions( - datasette=self.ds, - table=table_name, - database=database_name, - actor=request.actor, - request=request, - ): - 
extra_links = await await_me_maybe(hook) - if extra_links: - links.extend(extra_links) - return links - - # filter_columns combine the columns we know are available - # in the table with any additional columns (such as rowid) - # which are available in the query - filter_columns = list(columns) + [ - table_column - for table_column in table_columns - if table_column not in columns - ] - d = { - "table_actions": table_actions, - "use_rowid": use_rowid, - "filters": filters, - "display_columns": display_columns, - "filter_columns": filter_columns, - "display_rows": display_rows, - "facets_timed_out": facets_timed_out, - "sorted_facet_results": sorted( - facet_results.values(), - key=lambda f: (len(f["results"]), f["name"]), - reverse=True, - ), - "form_hidden_args": form_hidden_args, - "is_sortable": any(c["sortable"] for c in display_columns), - "fix_path": self.ds.urls.path, - "path_with_replaced_args": path_with_replaced_args, - "path_with_removed_args": path_with_removed_args, - "append_querystring": append_querystring, - "request": request, - "sort": sort, - "sort_desc": sort_desc, - "disable_sort": is_view, - "custom_table_templates": [ - f"_table-{to_css_class(database_name)}-{to_css_class(table_name)}.html", - f"_table-table-{to_css_class(database_name)}-{to_css_class(table_name)}.html", - "_table.html", - ], - "metadata": metadata, - "view_definition": await db.get_view_definition(table_name), - "table_definition": await db.get_table_definition(table_name), - "datasette_allow_facet": "true" - if self.ds.setting("allow_facet") - else "false", - } - d.update(extra_context_from_filters) - return d - - return ( - { - "database": database_name, - "table": table_name, - "is_view": is_view, - "human_description_en": human_description_en, - "rows": rows[:page_size], - "truncated": results.truncated, - "count": count, - "expanded_columns": expanded_columns, - "expandable_columns": expandable_columns, - "columns": columns, - "primary_keys": pks, - "units": units, - "query": {"sql": sql, "params": params}, - "facet_results": facet_results, - "suggested_facets": suggested_facets, - "next": next_value and str(next_value) or None, - "next_url": next_url, - "private": private, - "allow_execute_sql": await self.ds.permission_allowed( - request.actor, "execute-sql", database_name - ), - }, - extra_template, - ( - f"table-{to_css_class(database_name)}-{to_css_class(table_name)}.html", - "table.html", ), + forward_querystring=False, ) @@ -1337,3 +588,1161 @@ class TableDropView(BaseView): await db.execute_write_fn(drop_table) return Response.json({"ok": True}, status=200) + + +def _get_extras(request): + extra_bits = request.args.getlist("_extra") + extras = set() + for bit in extra_bits: + extras.update(bit.split(",")) + return extras + + +async def _columns_to_select(table_columns, pks, request): + columns = list(table_columns) + if "_col" in request.args: + columns = list(pks) + _cols = request.args.getlist("_col") + bad_columns = [column for column in _cols if column not in table_columns] + if bad_columns: + raise DatasetteError( + "_col={} - invalid columns".format(", ".join(bad_columns)), + status=400, + ) + # De-duplicate maintaining order: + columns.extend(dict.fromkeys(_cols)) + if "_nocol" in request.args: + # Return all columns EXCEPT these + bad_columns = [ + column + for column in request.args.getlist("_nocol") + if (column not in table_columns) or (column in pks) + ] + if bad_columns: + raise DatasetteError( + "_nocol={} - invalid columns".format(", ".join(bad_columns)), + status=400, + 
) + tmp_columns = [ + column for column in columns if column not in request.args.getlist("_nocol") + ] + columns = tmp_columns + return columns + + +async def _sortable_columns_for_table(datasette, database_name, table_name, use_rowid): + db = datasette.databases[database_name] + table_metadata = datasette.table_metadata(database_name, table_name) + if "sortable_columns" in table_metadata: + sortable_columns = set(table_metadata["sortable_columns"]) + else: + sortable_columns = set(await db.table_columns(table_name)) + if use_rowid: + sortable_columns.add("rowid") + return sortable_columns + + +async def _sort_order(table_metadata, sortable_columns, request, order_by): + sort = request.args.get("_sort") + sort_desc = request.args.get("_sort_desc") + + if not sort and not sort_desc: + sort = table_metadata.get("sort") + sort_desc = table_metadata.get("sort_desc") + + if sort and sort_desc: + raise DatasetteError( + "Cannot use _sort and _sort_desc at the same time", status=400 + ) + + if sort: + if sort not in sortable_columns: + raise DatasetteError(f"Cannot sort table by {sort}", status=400) + + order_by = escape_sqlite(sort) + + if sort_desc: + if sort_desc not in sortable_columns: + raise DatasetteError(f"Cannot sort table by {sort_desc}", status=400) + + order_by = f"{escape_sqlite(sort_desc)} desc" + + return sort, sort_desc, order_by + + +async def table_view(datasette, request): + await datasette.refresh_schemas() + with tracer.trace_child_tasks(): + response = await table_view_traced(datasette, request) + + # CORS + if datasette.cors: + add_cors_headers(response.headers) + + # Cache TTL header + ttl = request.args.get("_ttl", None) + if ttl is None or not ttl.isdigit(): + ttl = datasette.setting("default_cache_ttl") + + if datasette.cache_headers and response.status == 200: + ttl = int(ttl) + if ttl == 0: + ttl_header = "no-cache" + else: + ttl_header = f"max-age={ttl}" + response.headers["Cache-Control"] = ttl_header + + # Referrer policy + response.headers["Referrer-Policy"] = "no-referrer" + + return response + + +class CannedQueryView(DataView): + def __init__(self, datasette): + self.ds = datasette + + async def post(self, request): + from datasette.app import TableNotFound + + try: + await self.ds.resolve_table(request) + except TableNotFound as e: + # Was this actually a canned query? 
+ canned_query = await self.ds.get_canned_query( + e.database_name, e.table, request.actor + ) + if canned_query: + # Handle POST to a canned query + return await QueryView(self.ds).data( + request, + canned_query["sql"], + metadata=canned_query, + editable=False, + canned_query=e.table, + named_parameters=canned_query.get("params"), + write=bool(canned_query.get("write")), + ) + + return Response.text("Method not allowed", status=405) + + async def data(self, request, **kwargs): + from datasette.app import TableNotFound + + try: + await self.ds.resolve_table(request) + except TableNotFound as not_found: + canned_query = await self.ds.get_canned_query( + not_found.database_name, not_found.table, request.actor + ) + if canned_query: + return await QueryView(self.ds).data( + request, + canned_query["sql"], + metadata=canned_query, + editable=False, + canned_query=not_found.table, + named_parameters=canned_query.get("params"), + write=bool(canned_query.get("write")), + ) + else: + raise + + +async def table_view_traced(datasette, request): + from datasette.app import TableNotFound + + try: + resolved = await datasette.resolve_table(request) + except TableNotFound as not_found: + # Was this actually a canned query? + canned_query = await datasette.get_canned_query( + not_found.database_name, not_found.table, request.actor + ) + # If this is a canned query, not a table, then dispatch to QueryView instead + if canned_query: + if request.method == "POST": + return await CannedQueryView(datasette).post(request) + else: + return await CannedQueryView(datasette).get(request) + else: + raise + + if request.method == "POST": + return Response.text("Method not allowed", status=405) + + format_ = request.url_vars.get("format") or "html" + extra_extras = None + context_for_html_hack = False + default_labels = False + if format_ == "html": + extra_extras = {"_html"} + context_for_html_hack = True + default_labels = True + + view_data = await table_view_data( + datasette, + request, + resolved, + extra_extras=extra_extras, + context_for_html_hack=context_for_html_hack, + default_labels=default_labels, + ) + if isinstance(view_data, Response): + return view_data + data, rows, columns, expanded_columns, sql, next_url = view_data + + # Handle formats from plugins + if format_ == "csv": + + async def fetch_data(request, _next=None): + ( + data, + rows, + columns, + expanded_columns, + sql, + next_url, + ) = await table_view_data( + datasette, + request, + resolved, + extra_extras=extra_extras, + context_for_html_hack=context_for_html_hack, + default_labels=default_labels, + _next=_next, + ) + data["rows"] = rows + data["table"] = resolved.table + data["columns"] = columns + data["expanded_columns"] = expanded_columns + return data, None, None + + return await stream_csv(datasette, fetch_data, request, resolved.db.name) + elif format_ in datasette.renderers.keys(): + # Dispatch request to the correct output format renderer + # (CSV is not handled here due to streaming) + result = call_with_supported_arguments( + datasette.renderers[format_][0], + datasette=datasette, + columns=columns, + rows=rows, + sql=sql, + query_name=None, + database=resolved.db.name, + table=resolved.table, + request=request, + view_name="table", + # These will be deprecated in Datasette 1.0: + args=request.args, + data=data, + ) + if asyncio.iscoroutine(result): + result = await result + if result is None: + raise NotFound("No data") + if isinstance(result, dict): + r = Response( + body=result.get("body"), + 
status=result.get("status_code") or 200, + content_type=result.get("content_type", "text/plain"), + headers=result.get("headers"), + ) + elif isinstance(result, Response): + r = result + # if status_code is not None: + # # Over-ride the status code + # r.status = status_code + else: + assert False, f"{result} should be dict or Response" + elif format_ == "html": + headers = {} + templates = [ + f"table-{to_css_class(resolved.db.name)}-{to_css_class(resolved.table)}.html", + "table.html", + ] + template = datasette.jinja_env.select_template(templates) + alternate_url_json = datasette.absolute_url( + request, + datasette.urls.path(path_with_format(request=request, format="json")), + ) + headers.update( + { + "Link": '{}; rel="alternate"; type="application/json+datasette"'.format( + alternate_url_json + ) + } + ) + r = Response.html( + await datasette.render_template( + template, + dict( + data, + append_querystring=append_querystring, + path_with_replaced_args=path_with_replaced_args, + fix_path=datasette.urls.path, + settings=datasette.settings_dict(), + # TODO: review up all of these hacks: + alternate_url_json=alternate_url_json, + datasette_allow_facet=( + "true" if datasette.setting("allow_facet") else "false" + ), + is_sortable=any(c["sortable"] for c in data["display_columns"]), + allow_execute_sql=await datasette.permission_allowed( + request.actor, "execute-sql", resolved.db.name + ), + query_ms=1.2, + select_templates=[ + f"{'*' if template_name == template.name else ''}{template_name}" + for template_name in templates + ], + ), + request=request, + view_name="table", + ), + headers=headers, + ) + else: + assert False, "Invalid format: {}".format(format_) + if next_url: + r.headers["link"] = f'<{next_url}>; rel="next"' + return r + + +async def table_view_data( + datasette, + request, + resolved, + extra_extras=None, + context_for_html_hack=False, + default_labels=False, + _next=None, +): + extra_extras = extra_extras or set() + # We have a table or view + db = resolved.db + database_name = resolved.db.name + table_name = resolved.table + is_view = resolved.is_view + + # Can this user view it? 
+ visible, private = await datasette.check_visibility( + request.actor, + permissions=[ + ("view-table", (database_name, table_name)), + ("view-database", database_name), + "view-instance", + ], + ) + if not visible: + raise Forbidden("You do not have permission to view this table") + + # Redirect based on request.args, if necessary + redirect_response = await _redirect_if_needed(datasette, request, resolved) + if redirect_response: + return redirect_response + + # Introspect columns and primary keys for table + pks = await db.primary_keys(table_name) + table_columns = await db.table_columns(table_name) + + # Take ?_col= and ?_nocol= into account + specified_columns = await _columns_to_select(table_columns, pks, request) + select_specified_columns = ", ".join(escape_sqlite(t) for t in specified_columns) + select_all_columns = ", ".join(escape_sqlite(t) for t in table_columns) + + # rowid tables (no specified primary key) need a different SELECT + use_rowid = not pks and not is_view + order_by = "" + if use_rowid: + select_specified_columns = f"rowid, {select_specified_columns}" + select_all_columns = f"rowid, {select_all_columns}" + order_by = "rowid" + order_by_pks = "rowid" + else: + order_by_pks = ", ".join([escape_sqlite(pk) for pk in pks]) + order_by = order_by_pks + + if is_view: + order_by = "" + + # TODO: This logic should turn into logic about which ?_extras get + # executed instead: + nocount = request.args.get("_nocount") + nofacet = request.args.get("_nofacet") + nosuggest = request.args.get("_nosuggest") + if request.args.get("_shape") in ("array", "object"): + nocount = True + nofacet = True + + table_metadata = datasette.table_metadata(database_name, table_name) + units = table_metadata.get("units", {}) + + # Arguments that start with _ and don't contain a __ are + # special - things like ?_search= - and should not be + # treated as filters. 
+ filter_args = [] + for key in request.args: + if not (key.startswith("_") and "__" not in key): + for v in request.args.getlist(key): + filter_args.append((key, v)) + + # Build where clauses from query string arguments + filters = Filters(sorted(filter_args), units, ureg) + where_clauses, params = filters.build_where_clauses(table_name) + + # Execute filters_from_request plugin hooks - including the default + # ones that live in datasette/filters.py + extra_context_from_filters = {} + extra_human_descriptions = [] + + for hook in pm.hook.filters_from_request( + request=request, + table=table_name, + database=database_name, + datasette=datasette, + ): + filter_arguments = await await_me_maybe(hook) + if filter_arguments: + where_clauses.extend(filter_arguments.where_clauses) + params.update(filter_arguments.params) + extra_human_descriptions.extend(filter_arguments.human_descriptions) + extra_context_from_filters.update(filter_arguments.extra_context) + + # Deal with custom sort orders + sortable_columns = await _sortable_columns_for_table( + datasette, database_name, table_name, use_rowid + ) + + sort, sort_desc, order_by = await _sort_order( + table_metadata, sortable_columns, request, order_by + ) + + from_sql = "from {table_name} {where}".format( + table_name=escape_sqlite(table_name), + where=("where {} ".format(" and ".join(where_clauses))) + if where_clauses + else "", + ) + # Copy of params so we can mutate them later: + from_sql_params = dict(**params) + + count_sql = f"select count(*) {from_sql}" + + # Handle pagination driven by ?_next= + _next = _next or request.args.get("_next") + + offset = "" + if _next: + sort_value = None + if is_view: + # _next is an offset + offset = f" offset {int(_next)}" + else: + components = urlsafe_components(_next) + # If a sort order is applied and there are multiple components, + # the first of these is the sort value + if (sort or sort_desc) and (len(components) > 1): + sort_value = components[0] + # Special case for if non-urlencoded first token was $null + if _next.split(",")[0] == "$null": + sort_value = None + components = components[1:] + + # Figure out the SQL for next-based-on-primary-key first + next_by_pk_clauses = [] + if use_rowid: + next_by_pk_clauses.append(f"rowid > :p{len(params)}") + params[f"p{len(params)}"] = components[0] + else: + # Apply the tie-breaker based on primary keys + if len(components) == len(pks): + param_len = len(params) + next_by_pk_clauses.append(compound_keys_after_sql(pks, param_len)) + for i, pk_value in enumerate(components): + params[f"p{param_len + i}"] = pk_value + + # Now add the sort SQL, which may incorporate next_by_pk_clauses + if sort or sort_desc: + if sort_value is None: + if sort_desc: + # Just items where column is null ordered by pk + where_clauses.append( + "({column} is null and {next_clauses})".format( + column=escape_sqlite(sort_desc), + next_clauses=" and ".join(next_by_pk_clauses), + ) + ) + else: + where_clauses.append( + "({column} is not null or ({column} is null and {next_clauses}))".format( + column=escape_sqlite(sort), + next_clauses=" and ".join(next_by_pk_clauses), + ) + ) + else: + where_clauses.append( + "({column} {op} :p{p}{extra_desc_only} or ({column} = :p{p} and {next_clauses}))".format( + column=escape_sqlite(sort or sort_desc), + op=">" if sort else "<", + p=len(params), + extra_desc_only="" + if sort + else " or {column2} is null".format( + column2=escape_sqlite(sort or sort_desc) + ), + next_clauses=" and ".join(next_by_pk_clauses), + ) + ) + 
params[f"p{len(params)}"] = sort_value + order_by = f"{order_by}, {order_by_pks}" + else: + where_clauses.extend(next_by_pk_clauses) + + where_clause = "" + if where_clauses: + where_clause = f"where {' and '.join(where_clauses)} " + + if order_by: + order_by = f"order by {order_by}" + + extra_args = {} + # Handle ?_size=500 + # TODO: This was: + # page_size = _size or request.args.get("_size") or table_metadata.get("size") + page_size = request.args.get("_size") or table_metadata.get("size") + if page_size: + if page_size == "max": + page_size = datasette.max_returned_rows + try: + page_size = int(page_size) + if page_size < 0: + raise ValueError + + except ValueError: + raise BadRequest("_size must be a positive integer") + + if page_size > datasette.max_returned_rows: + raise BadRequest(f"_size must be <= {datasette.max_returned_rows}") + + extra_args["page_size"] = page_size + else: + page_size = datasette.page_size + + # Facets are calculated against SQL without order by or limit + sql_no_order_no_limit = ( + "select {select_all_columns} from {table_name} {where}".format( + select_all_columns=select_all_columns, + table_name=escape_sqlite(table_name), + where=where_clause, + ) + ) + + # This is the SQL that populates the main table on the page + sql = "select {select_specified_columns} from {table_name} {where}{order_by} limit {page_size}{offset}".format( + select_specified_columns=select_specified_columns, + table_name=escape_sqlite(table_name), + where=where_clause, + order_by=order_by, + page_size=page_size + 1, + offset=offset, + ) + + if request.args.get("_timelimit"): + extra_args["custom_time_limit"] = int(request.args.get("_timelimit")) + + # Execute the main query! + try: + results = await db.execute(sql, params, truncate=True, **extra_args) + except (sqlite3.OperationalError, InvalidSql) as e: + raise DatasetteError(str(e), title="Invalid SQL", status=400) + + except sqlite3.OperationalError as e: + raise DatasetteError(str(e)) + + columns = [r[0] for r in results.description] + rows = list(results.rows) + + # Expand labeled columns if requested + expanded_columns = [] + # List of (fk_dict, label_column-or-None) pairs for that table + expandable_columns = [] + for fk in await db.foreign_keys_for_table(table_name): + label_column = await db.label_column_for_table(fk["other_table"]) + expandable_columns.append((fk, label_column)) + + columns_to_expand = None + try: + all_labels = value_as_boolean(request.args.get("_labels", "")) + except ValueError: + all_labels = default_labels + # Check for explicit _label= + if "_label" in request.args: + columns_to_expand = request.args.getlist("_label") + if columns_to_expand is None and all_labels: + # expand all columns with foreign keys + columns_to_expand = [fk["column"] for fk, _ in expandable_columns] + + if columns_to_expand: + expanded_labels = {} + for fk, _ in expandable_columns: + column = fk["column"] + if column not in columns_to_expand: + continue + if column not in columns: + continue + expanded_columns.append(column) + # Gather the values + column_index = columns.index(column) + values = [row[column_index] for row in rows] + # Expand them + expanded_labels.update( + await datasette.expand_foreign_keys( + database_name, table_name, column, values + ) + ) + if expanded_labels: + # Rewrite the rows + new_rows = [] + for row in rows: + new_row = CustomRow(columns) + for column in row.keys(): + value = row[column] + if (column, value) in expanded_labels and value is not None: + new_row[column] = { + "value": value, + "label": 
expanded_labels[(column, value)], + } + else: + new_row[column] = value + new_rows.append(new_row) + rows = new_rows + + _next = request.args.get("_next") + + # Pagination next link + next_value, next_url = await _next_value_and_url( + datasette, + db, + request, + table_name, + _next, + rows, + pks, + use_rowid, + sort, + sort_desc, + page_size, + is_view, + ) + rows = rows[:page_size] + + # For performance profiling purposes, ?_noparallel=1 turns off asyncio.gather + gather = _gather_sequential if request.args.get("_noparallel") else _gather_parallel + + # Resolve extras + extras = _get_extras(request) + if any(k for k in request.args.keys() if k == "_facet" or k.startswith("_facet_")): + extras.add("facet_results") + if request.args.get("_shape") == "object": + extras.add("primary_keys") + if extra_extras: + extras.update(extra_extras) + + async def extra_count(): + "Total count of rows matching these filters" + # Calculate the total count for this query + count = None + if ( + not db.is_mutable + and datasette.inspect_data + and count_sql == f"select count(*) from {table_name} " + ): + # We can use a previously cached table row count + try: + count = datasette.inspect_data[database_name]["tables"][table_name][ + "count" + ] + except KeyError: + pass + + # Otherwise run a select count(*) ... + if count_sql and count is None and not nocount: + try: + count_rows = list(await db.execute(count_sql, from_sql_params)) + count = count_rows[0][0] + except QueryInterrupted: + pass + return count + + async def facet_instances(extra_count): + facet_instances = [] + facet_classes = list( + itertools.chain.from_iterable(pm.hook.register_facet_classes()) + ) + for facet_class in facet_classes: + facet_instances.append( + facet_class( + datasette, + request, + database_name, + sql=sql_no_order_no_limit, + params=params, + table=table_name, + metadata=table_metadata, + row_count=extra_count, + ) + ) + return facet_instances + + async def extra_facet_results(facet_instances): + "Results of facets calculated against this data" + facet_results = {} + facets_timed_out = [] + + if not nofacet: + # Run them in parallel + facet_awaitables = [facet.facet_results() for facet in facet_instances] + facet_awaitable_results = await gather(*facet_awaitables) + for ( + instance_facet_results, + instance_facets_timed_out, + ) in facet_awaitable_results: + for facet_info in instance_facet_results: + base_key = facet_info["name"] + key = base_key + i = 1 + while key in facet_results: + i += 1 + key = f"{base_key}_{i}" + facet_results[key] = facet_info + facets_timed_out.extend(instance_facets_timed_out) + + return { + "results": facet_results, + "timed_out": facets_timed_out, + } + + async def extra_suggested_facets(facet_instances): + "Suggestions for facets that might return interesting results" + suggested_facets = [] + # Calculate suggested facets + if ( + datasette.setting("suggest_facets") + and datasette.setting("allow_facet") + and not _next + and not nofacet + and not nosuggest + ): + # Run them in parallel + facet_suggest_awaitables = [facet.suggest() for facet in facet_instances] + for suggest_result in await gather(*facet_suggest_awaitables): + suggested_facets.extend(suggest_result) + return suggested_facets + + # Faceting + if not datasette.setting("allow_facet") and any( + arg.startswith("_facet") for arg in request.args + ): + raise BadRequest("_facet= is not allowed") + + # human_description_en combines filters AND search, if provided + async def extra_human_description_en(): + "Human-readable 
description of the filters" + human_description_en = filters.human_description_en( + extra=extra_human_descriptions + ) + if sort or sort_desc: + human_description_en = " ".join( + [b for b in [human_description_en, sorted_by] if b] + ) + return human_description_en + + if sort or sort_desc: + sorted_by = "sorted by {}{}".format( + (sort or sort_desc), " descending" if sort_desc else "" + ) + + async def extra_next_url(): + "Full URL for the next page of results" + return next_url + + async def extra_columns(): + "Column names returned by this query" + return columns + + async def extra_primary_keys(): + "Primary keys for this table" + return pks + + async def extra_table_actions(): + async def table_actions(): + links = [] + for hook in pm.hook.table_actions( + datasette=datasette, + table=table_name, + database=database_name, + actor=request.actor, + request=request, + ): + extra_links = await await_me_maybe(hook) + if extra_links: + links.extend(extra_links) + return links + + return table_actions + + async def extra_is_view(): + return is_view + + async def extra_debug(): + "Extra debug information" + return { + "resolved": repr(resolved), + "url_vars": request.url_vars, + "nofacet": nofacet, + "nosuggest": nosuggest, + } + + async def extra_request(): + "Full information about the request" + return { + "url": request.url, + "path": request.path, + "full_path": request.full_path, + "host": request.host, + "args": request.args._data, + } + + async def run_display_columns_and_rows(): + display_columns, display_rows = await display_columns_and_rows( + datasette, + database_name, + table_name, + results.description, + rows, + link_column=not is_view, + truncate_cells=datasette.setting("truncate_cells_html"), + sortable_columns=sortable_columns, + request=request, + ) + return { + "columns": display_columns, + "rows": display_rows, + } + + async def extra_display_columns(run_display_columns_and_rows): + return run_display_columns_and_rows["columns"] + + async def extra_display_rows(run_display_columns_and_rows): + return run_display_columns_and_rows["rows"] + + async def extra_query(): + "Details of the underlying SQL query" + return { + "sql": sql, + "params": params, + } + + async def extra_metadata(): + "Metadata about the table and database" + metadata = ( + (datasette.metadata("databases") or {}) + .get(database_name, {}) + .get("tables", {}) + .get(table_name, {}) + ) + datasette.update_with_inherited_metadata(metadata) + return metadata + + async def extra_database(): + return database_name + + async def extra_table(): + return table_name + + async def extra_database_color(): + return lambda _: "ff0000" + + async def extra_form_hidden_args(): + form_hidden_args = [] + for key in request.args: + if ( + key.startswith("_") + and key not in ("_sort", "_sort_desc", "_search", "_next") + and "__" not in key + ): + for value in request.args.getlist(key): + form_hidden_args.append((key, value)) + return form_hidden_args + + async def extra_filters(): + return filters + + async def extra_custom_table_templates(): + return [ + f"_table-{to_css_class(database_name)}-{to_css_class(table_name)}.html", + f"_table-table-{to_css_class(database_name)}-{to_css_class(table_name)}.html", + "_table.html", + ] + + async def extra_sorted_facet_results(extra_facet_results): + return sorted( + extra_facet_results["results"].values(), + key=lambda f: (len(f["results"]), f["name"]), + reverse=True, + ) + + async def extra_table_definition(): + return await db.get_table_definition(table_name) + + async def 
extra_view_definition(): + return await db.get_view_definition(table_name) + + async def extra_renderers(extra_expandable_columns, extra_query): + renderers = {} + url_labels_extra = {} + if extra_expandable_columns: + url_labels_extra = {"_labels": "on"} + for key, (_, can_render) in datasette.renderers.items(): + it_can_render = call_with_supported_arguments( + can_render, + datasette=datasette, + columns=columns or [], + rows=rows or [], + sql=extra_query.get("sql", None), + query_name=None, + database=database_name, + table=table_name, + request=request, + view_name="table", + ) + it_can_render = await await_me_maybe(it_can_render) + if it_can_render: + renderers[key] = datasette.urls.path( + path_with_format( + request=request, format=key, extra_qs={**url_labels_extra} + ) + ) + return renderers + + async def extra_private(): + return private + + async def extra_expandable_columns(): + expandables = [] + db = datasette.databases[database_name] + for fk in await db.foreign_keys_for_table(table_name): + label_column = await db.label_column_for_table(fk["other_table"]) + expandables.append((fk, label_column)) + return expandables + + async def extra_extras(): + "Available ?_extra= blocks" + return { + "available": [ + { + "name": key[len("extra_") :], + "doc": fn.__doc__, + } + for key, fn in registry._registry.items() + if key.startswith("extra_") + ], + "selected": list(extras), + } + + async def extra_facets_timed_out(extra_facet_results): + return extra_facet_results["timed_out"] + + bundles = { + "html": [ + "suggested_facets", + "facet_results", + "facets_timed_out", + "count", + "human_description_en", + "next_url", + "metadata", + "query", + "columns", + "display_columns", + "display_rows", + "database", + "table", + "database_color", + "table_actions", + "filters", + "renderers", + "custom_table_templates", + "sorted_facet_results", + "table_definition", + "view_definition", + "is_view", + "private", + "primary_keys", + "expandable_columns", + "form_hidden_args", + ] + } + + for key, values in bundles.items(): + if f"_{key}" in extras: + extras.update(values) + extras.discard(f"_{key}") + + registry = Registry( + extra_count, + extra_facet_results, + extra_facets_timed_out, + extra_suggested_facets, + facet_instances, + extra_human_description_en, + extra_next_url, + extra_columns, + extra_primary_keys, + run_display_columns_and_rows, + extra_display_columns, + extra_display_rows, + extra_debug, + extra_request, + extra_query, + extra_metadata, + extra_extras, + extra_database, + extra_table, + extra_database_color, + extra_table_actions, + extra_filters, + extra_renderers, + extra_custom_table_templates, + extra_sorted_facet_results, + extra_table_definition, + extra_view_definition, + extra_is_view, + extra_private, + extra_expandable_columns, + extra_form_hidden_args, + ) + + results = await registry.resolve_multi( + ["extra_{}".format(extra) for extra in extras] + ) + data = { + "ok": True, + "next": next_value and str(next_value) or None, + } + data.update( + { + key.replace("extra_", ""): value + for key, value in results.items() + if key.startswith("extra_") and key.replace("extra_", "") in extras + } + ) + raw_sqlite_rows = rows[:page_size] + data["rows"] = [dict(r) for r in raw_sqlite_rows] + + if context_for_html_hack: + data.update(extra_context_from_filters) + # filter_columns combine the columns we know are available + # in the table with any additional columns (such as rowid) + # which are available in the query + data["filter_columns"] = list(columns) + [ + 
table_column + for table_column in table_columns + if table_column not in columns + ] + url_labels_extra = {} + if data.get("expandable_columns"): + url_labels_extra = {"_labels": "on"} + url_csv_args = {"_size": "max", **url_labels_extra} + url_csv = datasette.urls.path( + path_with_format(request=request, format="csv", extra_qs=url_csv_args) + ) + url_csv_path = url_csv.split("?")[0] + data.update( + { + "url_csv": url_csv, + "url_csv_path": url_csv_path, + "url_csv_hidden_args": [ + (key, value) + for key, value in urllib.parse.parse_qsl(request.query_string) + if key not in ("_labels", "_facet", "_size") + ] + + [("_size", "max")], + } + ) + # if no sort specified AND table has a single primary key, + # set sort to that so arrow is displayed + if not sort and not sort_desc: + if 1 == len(pks): + sort = pks[0] + elif use_rowid: + sort = "rowid" + data["sort"] = sort + data["sort_desc"] = sort_desc + + return data, rows[:page_size], columns, expanded_columns, sql, next_url + + +async def _next_value_and_url( + datasette, + db, + request, + table_name, + _next, + rows, + pks, + use_rowid, + sort, + sort_desc, + page_size, + is_view, +): + next_value = None + next_url = None + if 0 < page_size < len(rows): + if is_view: + next_value = int(_next or 0) + page_size + else: + next_value = path_from_row_pks(rows[-2], pks, use_rowid) + # If there's a sort or sort_desc, add that value as a prefix + if (sort or sort_desc) and not is_view: + try: + prefix = rows[-2][sort or sort_desc] + except IndexError: + # sort/sort_desc column missing from SELECT - look up value by PK instead + prefix_where_clause = " and ".join( + "[{}] = :pk{}".format(pk, i) for i, pk in enumerate(pks) + ) + prefix_lookup_sql = "select [{}] from [{}] where {}".format( + sort or sort_desc, table_name, prefix_where_clause + ) + prefix = ( + await db.execute( + prefix_lookup_sql, + { + **{ + "pk{}".format(i): rows[-2][pk] + for i, pk in enumerate(pks) + } + }, + ) + ).single_value() + if isinstance(prefix, dict) and "value" in prefix: + prefix = prefix["value"] + if prefix is None: + prefix = "$null" + else: + prefix = tilde_encode(str(prefix)) + next_value = f"{prefix},{next_value}" + added_args = {"_next": next_value} + if sort: + added_args["_sort"] = sort + else: + added_args["_sort_desc"] = sort_desc + else: + added_args = {"_next": next_value} + next_url = datasette.absolute_url( + request, datasette.urls.path(path_with_replaced_args(request, added_args)) + ) + return next_value, next_url \ No newline at end of file diff --git a/setup.py b/setup.py index d424b635..a6f41456 100644 --- a/setup.py +++ b/setup.py @@ -58,6 +58,7 @@ setup( "mergedeep>=1.1.1", "itsdangerous>=1.1", "sqlite-utils>=3.30", + "asyncinject>=0.5", ], entry_points=""" [console_scripts] diff --git a/tests/test_api.py b/tests/test_api.py index 5a751487..780e9fa5 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -896,9 +896,11 @@ def test_config_cache_size(app_client_larger_cache_size): def test_config_force_https_urls(): with make_app_client(settings={"force_https_urls": True}) as client: - response = client.get("/fixtures/facetable.json?_size=3&_facet=state") + response = client.get( + "/fixtures/facetable.json?_size=3&_facet=state&_extra=next_url,suggested_facets" + ) assert response.json["next_url"].startswith("https://") - assert response.json["facet_results"]["state"]["results"][0][ + assert response.json["facet_results"]["results"]["state"]["results"][0][ "toggle_url" ].startswith("https://") assert 
response.json["suggested_facets"][0]["toggle_url"].startswith("https://") @@ -981,7 +983,9 @@ def test_common_prefix_database_names(app_client_conflicting_database_names): def test_inspect_file_used_for_count(app_client_immutable_and_inspect_file): - response = app_client_immutable_and_inspect_file.get("/fixtures/sortable.json") + response = app_client_immutable_and_inspect_file.get( + "/fixtures/sortable.json?_extra=count" + ) assert response.json["count"] == 100 diff --git a/tests/test_facets.py b/tests/test_facets.py index d264f534..48cc0ff2 100644 --- a/tests/test_facets.py +++ b/tests/test_facets.py @@ -419,7 +419,7 @@ async def test_array_facet_handle_duplicate_tags(): ) response = await ds.client.get("/test_array_facet/otters.json?_facet_array=tags") - assert response.json()["facet_results"]["tags"] == { + assert response.json()["facet_results"]["results"]["tags"] == { "name": "tags", "type": "array", "results": [ @@ -517,13 +517,13 @@ async def test_json_array_with_blanks_and_nulls(): await db.execute_write("create table foo(json_column text)") for value in ('["a", "b", "c"]', '["a", "b"]', "", None): await db.execute_write("insert into foo (json_column) values (?)", [value]) - response = await ds.client.get("/test_json_array/foo.json") + response = await ds.client.get("/test_json_array/foo.json?_extra=suggested_facets") data = response.json() assert data["suggested_facets"] == [ { "name": "json_column", "type": "array", - "toggle_url": "http://localhost/test_json_array/foo.json?_facet_array=json_column", + "toggle_url": "http://localhost/test_json_array/foo.json?_extra=suggested_facets&_facet_array=json_column", } ] @@ -539,27 +539,29 @@ async def test_facet_size(): "insert into neighbourhoods (city, neighbourhood) values (?, ?)", ["City {}".format(i), "Neighbourhood {}".format(j)], ) - response = await ds.client.get("/test_facet_size/neighbourhoods.json") + response = await ds.client.get( + "/test_facet_size/neighbourhoods.json?_extra=suggested_facets" + ) data = response.json() assert data["suggested_facets"] == [ { "name": "neighbourhood", - "toggle_url": "http://localhost/test_facet_size/neighbourhoods.json?_facet=neighbourhood", + "toggle_url": "http://localhost/test_facet_size/neighbourhoods.json?_extra=suggested_facets&_facet=neighbourhood", } ] # Bump up _facet_size= to suggest city too response2 = await ds.client.get( - "/test_facet_size/neighbourhoods.json?_facet_size=50" + "/test_facet_size/neighbourhoods.json?_facet_size=50&_extra=suggested_facets" ) data2 = response2.json() assert sorted(data2["suggested_facets"], key=lambda f: f["name"]) == [ { "name": "city", - "toggle_url": "http://localhost/test_facet_size/neighbourhoods.json?_facet_size=50&_facet=city", + "toggle_url": "http://localhost/test_facet_size/neighbourhoods.json?_facet_size=50&_extra=suggested_facets&_facet=city", }, { "name": "neighbourhood", - "toggle_url": "http://localhost/test_facet_size/neighbourhoods.json?_facet_size=50&_facet=neighbourhood", + "toggle_url": "http://localhost/test_facet_size/neighbourhoods.json?_facet_size=50&_extra=suggested_facets&_facet=neighbourhood", }, ] # Facet by city should return expected number of results @@ -567,20 +569,20 @@ async def test_facet_size(): "/test_facet_size/neighbourhoods.json?_facet_size=50&_facet=city" ) data3 = response3.json() - assert len(data3["facet_results"]["city"]["results"]) == 50 + assert len(data3["facet_results"]["results"]["city"]["results"]) == 50 # Reduce max_returned_rows and check that it's respected ds._settings["max_returned_rows"] 
= 20 response4 = await ds.client.get( "/test_facet_size/neighbourhoods.json?_facet_size=50&_facet=city" ) data4 = response4.json() - assert len(data4["facet_results"]["city"]["results"]) == 20 + assert len(data4["facet_results"]["results"]["city"]["results"]) == 20 # Test _facet_size=max response5 = await ds.client.get( "/test_facet_size/neighbourhoods.json?_facet_size=max&_facet=city" ) data5 = response5.json() - assert len(data5["facet_results"]["city"]["results"]) == 20 + assert len(data5["facet_results"]["results"]["city"]["results"]) == 20 # Now try messing with facet_size in the table metadata orig_metadata = ds._metadata_local try: @@ -593,7 +595,7 @@ async def test_facet_size(): "/test_facet_size/neighbourhoods.json?_facet=city" ) data6 = response6.json() - assert len(data6["facet_results"]["city"]["results"]) == 6 + assert len(data6["facet_results"]["results"]["city"]["results"]) == 6 # Setting it to max bumps it up to 50 again ds._metadata_local["databases"]["test_facet_size"]["tables"]["neighbourhoods"][ "facet_size" @@ -601,7 +603,7 @@ async def test_facet_size(): data7 = ( await ds.client.get("/test_facet_size/neighbourhoods.json?_facet=city") ).json() - assert len(data7["facet_results"]["city"]["results"]) == 20 + assert len(data7["facet_results"]["results"]["city"]["results"]) == 20 finally: ds._metadata_local = orig_metadata @@ -635,7 +637,7 @@ async def test_conflicting_facet_names_json(ds_client): "/fixtures/facetable.json?_facet=created&_facet_date=created" "&_facet=tags&_facet_array=tags" ) - assert set(response.json()["facet_results"].keys()) == { + assert set(response.json()["facet_results"]["results"].keys()) == { "created", "tags", "created_2", diff --git a/tests/test_filters.py b/tests/test_filters.py index 01b0ec6f..5b2e9636 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -82,13 +82,11 @@ async def test_through_filters_from_request(ds_client): request = Request.fake( '/?_through={"table":"roadside_attraction_characteristics","column":"characteristic_id","value":"1"}' ) - filter_args = await ( - through_filters( - request=request, - datasette=ds_client.ds, - table="roadside_attractions", - database="fixtures", - ) + filter_args = await through_filters( + request=request, + datasette=ds_client.ds, + table="roadside_attractions", + database="fixtures", )() assert filter_args.where_clauses == [ "pk in (select attraction_id from roadside_attraction_characteristics where characteristic_id = :p0)" @@ -105,13 +103,11 @@ async def test_through_filters_from_request(ds_client): request = Request.fake( '/?_through={"table":"roadside_attraction_characteristics","column":"characteristic_id","value":"1"}' ) - filter_args = await ( - through_filters( - request=request, - datasette=ds_client.ds, - table="roadside_attractions", - database="fixtures", - ) + filter_args = await through_filters( + request=request, + datasette=ds_client.ds, + table="roadside_attractions", + database="fixtures", )() assert filter_args.where_clauses == [ "pk in (select attraction_id from roadside_attraction_characteristics where characteristic_id = :p0)" @@ -127,12 +123,10 @@ async def test_through_filters_from_request(ds_client): async def test_where_filters_from_request(ds_client): await ds_client.ds.invoke_startup() request = Request.fake("/?_where=pk+>+3") - filter_args = await ( - where_filters( - request=request, - datasette=ds_client.ds, - database="fixtures", - ) + filter_args = await where_filters( + request=request, + datasette=ds_client.ds, + database="fixtures", )() assert 
filter_args.where_clauses == ["pk > 3"] assert filter_args.params == {} @@ -145,13 +139,11 @@ async def test_where_filters_from_request(ds_client): @pytest.mark.asyncio async def test_search_filters_from_request(ds_client): request = Request.fake("/?_search=bobcat") - filter_args = await ( - search_filters( - request=request, - datasette=ds_client.ds, - database="fixtures", - table="searchable", - ) + filter_args = await search_filters( + request=request, + datasette=ds_client.ds, + database="fixtures", + table="searchable", )() assert filter_args.where_clauses == [ "rowid in (select rowid from searchable_fts where searchable_fts match escape_fts(:search))" diff --git a/tests/test_load_extensions.py b/tests/test_load_extensions.py index 0e39f566..4007e0be 100644 --- a/tests/test_load_extensions.py +++ b/tests/test_load_extensions.py @@ -8,6 +8,7 @@ from pathlib import Path # this resolves to "./ext", which is enough for SQLite to calculate the rest COMPILED_EXTENSION_PATH = str(Path(__file__).parent / "ext") + # See if ext.c has been compiled, based off the different possible suffixes. def has_compiled_ext(): for ext in ["dylib", "so", "dll"]: @@ -20,7 +21,6 @@ def has_compiled_ext(): @pytest.mark.asyncio @pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c") async def test_load_extension_default_entrypoint(): - # The default entrypoint only loads a() and NOT b() or c(), so those # should fail. ds = Datasette(sqlite_extensions=[COMPILED_EXTENSION_PATH]) @@ -41,7 +41,6 @@ async def test_load_extension_default_entrypoint(): @pytest.mark.asyncio @pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c") async def test_load_extension_multiple_entrypoints(): - # Load in the default entrypoint and the other 2 custom entrypoints, now # all a(), b(), and c() should run successfully. 
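# Illustrative sketch, not part of this patch: the JSON API behaviour the test
# changes in this series are adapting to. Extra blocks are now opt-in via
# ?_extra=, and facet results are nested one level deeper under "results".
# The names "mydb", "mytable" and "state" are hypothetical.
async def example_extras_and_facets(ds):
    response = await ds.client.get(
        "/mydb/mytable.json?_facet=state&_extra=count,suggested_facets"
    )
    data = response.json()
    count = data["count"]  # present because of _extra=count
    suggestions = data["suggested_facets"]  # present because of _extra=suggested_facets
    # _facet=state implies the facet_results extra; per-facet data sits under ["results"]
    state_results = data["facet_results"]["results"]["state"]["results"]
    return count, suggestions, state_results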
ds = Datasette( diff --git a/tests/test_plugins.py b/tests/test_plugins.py index eec02e10..71b710f9 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -595,42 +595,42 @@ def test_hook_publish_subcommand(): @pytest.mark.asyncio async def test_hook_register_facet_classes(ds_client): response = await ds_client.get( - "/fixtures/compound_three_primary_keys.json?_dummy_facet=1" + "/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_extra=suggested_facets" ) - assert [ + assert response.json()["suggested_facets"] == [ { "name": "pk1", - "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_facet_dummy=pk1", + "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_extra=suggested_facets&_facet_dummy=pk1", "type": "dummy", }, { "name": "pk2", - "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_facet_dummy=pk2", + "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_extra=suggested_facets&_facet_dummy=pk2", "type": "dummy", }, { "name": "pk3", - "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_facet_dummy=pk3", + "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_extra=suggested_facets&_facet_dummy=pk3", "type": "dummy", }, { "name": "content", - "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_facet_dummy=content", + "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_extra=suggested_facets&_facet_dummy=content", "type": "dummy", }, { "name": "pk1", - "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_facet=pk1", + "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_extra=suggested_facets&_facet=pk1", }, { "name": "pk2", - "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_facet=pk2", + "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_extra=suggested_facets&_facet=pk2", }, { "name": "pk3", - "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_facet=pk3", + "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_extra=suggested_facets&_facet=pk3", }, - ] == response.json()["suggested_facets"] + ] @pytest.mark.asyncio diff --git a/tests/test_routes.py b/tests/test_routes.py index d467abe1..85945dec 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -11,7 +11,7 @@ def routes(): @pytest.mark.parametrize( - "path,expected_class,expected_matches", + "path,expected_name,expected_matches", ( ("/", "IndexView", {"format": None}), ("/foo", "DatabaseView", {"format": None, "database": "foo"}), @@ -20,17 +20,17 @@ def routes(): ("/foo.humbug", "DatabaseView", {"format": "humbug", "database": "foo"}), ( "/foo/humbug", - "TableView", + "table_view", {"database": "foo", "table": "humbug", "format": None}, ), ( "/foo/humbug.json", - "TableView", + "table_view", {"database": "foo", "table": "humbug", "format": "json"}, ), ( "/foo/humbug.blah", - "TableView", + "table_view", {"database": "foo", "table": "humbug", "format": "blah"}, ), ( @@ -47,12 +47,14 @@ def routes(): ("/-/metadata", "JsonDataView", {"format": None}), ), ) -def test_routes(routes, path, expected_class, expected_matches): +def test_routes(routes, path, expected_name, expected_matches): 
match, view = resolve_routes(routes, path) - if expected_class is None: + if expected_name is None: assert match is None else: - assert view.view_class.__name__ == expected_class + assert ( + view.__name__ == expected_name or view.view_class.__name__ == expected_name + ) assert match.groupdict() == expected_matches diff --git a/tests/test_table_api.py b/tests/test_table_api.py index 9e9578bf..cd664ffb 100644 --- a/tests/test_table_api.py +++ b/tests/test_table_api.py @@ -15,7 +15,7 @@ import urllib @pytest.mark.asyncio async def test_table_json(ds_client): - response = await ds_client.get("/fixtures/simple_primary_key.json?_shape=objects") + response = await ds_client.get("/fixtures/simple_primary_key.json?_extra=query") assert response.status_code == 200 data = response.json() assert ( @@ -198,6 +198,10 @@ async def test_paginate_tables_and_views( fetched = [] count = 0 while path: + if "?" in path: + path += "&_extra=next_url" + else: + path += "?_extra=next_url" response = await ds_client.get(path) assert response.status_code == 200 count += 1 @@ -230,7 +234,9 @@ async def test_validate_page_size(ds_client, path, expected_error): @pytest.mark.asyncio async def test_page_size_zero(ds_client): """For _size=0 we return the counts, empty rows and no continuation token""" - response = await ds_client.get("/fixtures/no_primary_key.json?_size=0") + response = await ds_client.get( + "/fixtures/no_primary_key.json?_size=0&_extra=count,next_url" + ) assert response.status_code == 200 assert [] == response.json()["rows"] assert 201 == response.json()["count"] @@ -241,7 +247,7 @@ async def test_page_size_zero(ds_client): @pytest.mark.asyncio async def test_paginate_compound_keys(ds_client): fetched = [] - path = "/fixtures/compound_three_primary_keys.json?_shape=objects" + path = "/fixtures/compound_three_primary_keys.json?_shape=objects&_extra=next_url" page = 0 while path: page += 1 @@ -262,9 +268,7 @@ async def test_paginate_compound_keys(ds_client): @pytest.mark.asyncio async def test_paginate_compound_keys_with_extra_filters(ds_client): fetched = [] - path = ( - "/fixtures/compound_three_primary_keys.json?content__contains=d&_shape=objects" - ) + path = "/fixtures/compound_three_primary_keys.json?content__contains=d&_shape=objects&_extra=next_url" page = 0 while path: page += 1 @@ -315,7 +319,7 @@ async def test_paginate_compound_keys_with_extra_filters(ds_client): ], ) async def test_sortable(ds_client, query_string, sort_key, human_description_en): - path = f"/fixtures/sortable.json?_shape=objects&{query_string}" + path = f"/fixtures/sortable.json?_shape=objects&_extra=human_description_en,next_url&{query_string}" fetched = [] page = 0 while path: @@ -338,6 +342,7 @@ async def test_sortable_and_filtered(ds_client): path = ( "/fixtures/sortable.json" "?content__contains=d&_sort_desc=sortable&_shape=objects" + "&_extra=human_description_en,count" ) response = await ds_client.get(path) fetched = response.json()["rows"] @@ -660,7 +665,9 @@ def test_table_filter_extra_where_disabled_if_no_sql_allowed(): async def test_table_through(ds_client): # Just the museums: response = await ds_client.get( - '/fixtures/roadside_attractions.json?_shape=arrays&_through={"table":"roadside_attraction_characteristics","column":"characteristic_id","value":"1"}' + "/fixtures/roadside_attractions.json?_shape=arrays" + '&_through={"table":"roadside_attraction_characteristics","column":"characteristic_id","value":"1"}' + "&_extra=human_description_en" ) assert response.json()["rows"] == [ [ @@ -712,6 +719,7 @@ async 
def test_view(ds_client): ] +@pytest.mark.xfail @pytest.mark.asyncio async def test_unit_filters(ds_client): response = await ds_client.get( @@ -731,7 +739,7 @@ def test_page_size_matching_max_returned_rows( app_client_returned_rows_matches_page_size, ): fetched = [] - path = "/fixtures/no_primary_key.json" + path = "/fixtures/no_primary_key.json?_extra=next_url" while path: response = app_client_returned_rows_matches_page_size.get(path) fetched.extend(response.json["rows"]) @@ -911,12 +919,42 @@ async def test_facets(ds_client, path, expected_facet_results): response = await ds_client.get(path) facet_results = response.json()["facet_results"] # We only compare the querystring portion of the taggle_url - for facet_name, facet_info in facet_results.items(): + for facet_name, facet_info in facet_results["results"].items(): assert facet_name == facet_info["name"] assert False is facet_info["truncated"] for facet_value in facet_info["results"]: facet_value["toggle_url"] = facet_value["toggle_url"].split("?")[1] - assert expected_facet_results == facet_results + assert expected_facet_results == facet_results["results"] + + +@pytest.mark.asyncio +@pytest.mark.skipif(not detect_json1(), reason="requires JSON1 extension") +async def test_facets_array(ds_client): + response = await ds_client.get("/fixtures/facetable.json?_facet_array=tags") + facet_results = response.json()["facet_results"] + assert facet_results["results"]["tags"]["results"] == [ + { + "value": "tag1", + "label": "tag1", + "count": 2, + "toggle_url": "http://localhost/fixtures/facetable.json?_facet_array=tags&tags__arraycontains=tag1", + "selected": False, + }, + { + "value": "tag2", + "label": "tag2", + "count": 1, + "toggle_url": "http://localhost/fixtures/facetable.json?_facet_array=tags&tags__arraycontains=tag2", + "selected": False, + }, + { + "value": "tag3", + "label": "tag3", + "count": 1, + "toggle_url": "http://localhost/fixtures/facetable.json?_facet_array=tags&tags__arraycontains=tag3", + "selected": False, + }, + ] @pytest.mark.asyncio @@ -926,58 +964,83 @@ async def test_suggested_facets(ds_client): "name": suggestion["name"], "querystring": suggestion["toggle_url"].split("?")[-1], } - for suggestion in (await ds_client.get("/fixtures/facetable.json")).json()[ - "suggested_facets" - ] + for suggestion in ( + await ds_client.get("/fixtures/facetable.json?_extra=suggested_facets") + ).json()["suggested_facets"] ] expected = [ - {"name": "created", "querystring": "_facet=created"}, - {"name": "planet_int", "querystring": "_facet=planet_int"}, - {"name": "on_earth", "querystring": "_facet=on_earth"}, - {"name": "state", "querystring": "_facet=state"}, - {"name": "_city_id", "querystring": "_facet=_city_id"}, - {"name": "_neighborhood", "querystring": "_facet=_neighborhood"}, - {"name": "tags", "querystring": "_facet=tags"}, - {"name": "complex_array", "querystring": "_facet=complex_array"}, - {"name": "created", "querystring": "_facet_date=created"}, + {"name": "created", "querystring": "_extra=suggested_facets&_facet=created"}, + { + "name": "planet_int", + "querystring": "_extra=suggested_facets&_facet=planet_int", + }, + {"name": "on_earth", "querystring": "_extra=suggested_facets&_facet=on_earth"}, + {"name": "state", "querystring": "_extra=suggested_facets&_facet=state"}, + {"name": "_city_id", "querystring": "_extra=suggested_facets&_facet=_city_id"}, + { + "name": "_neighborhood", + "querystring": "_extra=suggested_facets&_facet=_neighborhood", + }, + {"name": "tags", "querystring": 
"_extra=suggested_facets&_facet=tags"}, + { + "name": "complex_array", + "querystring": "_extra=suggested_facets&_facet=complex_array", + }, + { + "name": "created", + "querystring": "_extra=suggested_facets&_facet_date=created", + }, ] if detect_json1(): - expected.append({"name": "tags", "querystring": "_facet_array=tags"}) + expected.append( + {"name": "tags", "querystring": "_extra=suggested_facets&_facet_array=tags"} + ) assert expected == suggestions def test_allow_facet_off(): with make_app_client(settings={"allow_facet": False}) as client: - assert 400 == client.get("/fixtures/facetable.json?_facet=planet_int").status + assert ( + client.get( + "/fixtures/facetable.json?_facet=planet_int&_extra=suggested_facets" + ).status + == 400 + ) + data = client.get("/fixtures/facetable.json?_extra=suggested_facets").json # Should not suggest any facets either: - assert [] == client.get("/fixtures/facetable.json").json["suggested_facets"] + assert [] == data["suggested_facets"] def test_suggest_facets_off(): with make_app_client(settings={"suggest_facets": False}) as client: # Now suggested_facets should be [] - assert [] == client.get("/fixtures/facetable.json").json["suggested_facets"] + assert ( + [] + == client.get("/fixtures/facetable.json?_extra=suggested_facets").json[ + "suggested_facets" + ] + ) @pytest.mark.asyncio @pytest.mark.parametrize("nofacet", (True, False)) async def test_nofacet(ds_client, nofacet): - path = "/fixtures/facetable.json?_facet=state" + path = "/fixtures/facetable.json?_facet=state&_extra=suggested_facets" if nofacet: path += "&_nofacet=1" response = await ds_client.get(path) if nofacet: assert response.json()["suggested_facets"] == [] - assert response.json()["facet_results"] == {} + assert response.json()["facet_results"]["results"] == {} else: assert response.json()["suggested_facets"] != [] - assert response.json()["facet_results"] != {} + assert response.json()["facet_results"]["results"] != {} @pytest.mark.asyncio @pytest.mark.parametrize("nosuggest", (True, False)) async def test_nosuggest(ds_client, nosuggest): - path = "/fixtures/facetable.json?_facet=state" + path = "/fixtures/facetable.json?_facet=state&_extra=suggested_facets" if nosuggest: path += "&_nosuggest=1" response = await ds_client.get(path) @@ -993,9 +1056,9 @@ async def test_nosuggest(ds_client, nosuggest): @pytest.mark.asyncio @pytest.mark.parametrize("nocount,expected_count", ((True, None), (False, 15))) async def test_nocount(ds_client, nocount, expected_count): - path = "/fixtures/facetable.json" + path = "/fixtures/facetable.json?_extra=count" if nocount: - path += "?_nocount=1" + path += "&_nocount=1" response = await ds_client.get(path) assert response.json()["count"] == expected_count @@ -1280,7 +1343,7 @@ def test_generated_columns_are_visible_in_datasette(): ), ) async def test_col_nocol(ds_client, path, expected_columns): - response = await ds_client.get(path) + response = await ds_client.get(path + "&_extra=columns") assert response.status_code == 200 columns = response.json()["columns"] assert columns == expected_columns diff --git a/tests/test_table_html.py b/tests/test_table_html.py index 857342c3..e1886dab 100644 --- a/tests/test_table_html.py +++ b/tests/test_table_html.py @@ -1160,6 +1160,13 @@ async def test_table_page_title(ds_client, path, expected): assert title == expected +@pytest.mark.asyncio +async def test_table_post_method_not_allowed(ds_client): + response = await ds_client.post("/fixtures/facetable") + assert response.status_code == 405 + assert "Method not 
allowed" in response.text + + @pytest.mark.parametrize("allow_facet", (True, False)) def test_allow_facet_off(allow_facet): with make_app_client(settings={"allow_facet": allow_facet}) as client: From 3feed1f66e2b746f349ee56970a62246a18bb164 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 22 Mar 2023 15:54:35 -0700 Subject: [PATCH 072/603] Re-applied Black --- datasette/views/table.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasette/views/table.py b/datasette/views/table.py index 0a6203f2..8c133c26 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -1745,4 +1745,4 @@ async def _next_value_and_url( next_url = datasette.absolute_url( request, datasette.urls.path(path_with_replaced_args(request, added_args)) ) - return next_value, next_url \ No newline at end of file + return next_value, next_url From 5c1cfa451d78e3935193f5e10eba59bf741241a1 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 26 Mar 2023 16:23:28 -0700 Subject: [PATCH 073/603] Link docs /latest/ to /stable/ again Re-implementing the pattern from https://til.simonwillison.net/readthedocs/link-from-latest-to-stable Refs #1608 --- docs/_templates/base.html | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/docs/_templates/base.html b/docs/_templates/base.html index 969de5ab..faa268ef 100644 --- a/docs/_templates/base.html +++ b/docs/_templates/base.html @@ -3,4 +3,29 @@ {% block site_meta %} {{ super() }} + {% endblock %} From db8cf899e286fbaa0a40f3a9ae8d5aaa1478822e Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 26 Mar 2023 16:27:58 -0700 Subject: [PATCH 074/603] Use block scripts instead, refs #1608 --- docs/_templates/base.html | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/_templates/base.html b/docs/_templates/base.html index faa268ef..eea82453 100644 --- a/docs/_templates/base.html +++ b/docs/_templates/base.html @@ -3,6 +3,10 @@ {% block site_meta %} {{ super() }} +{% endblock %} + +{% block scripts %} +{{ super() }} + """.format( + markupsafe.escape(ex.sql) + ) + ).strip(), + title="SQL Interrupted", + status=400, + message_is_html=True, + ) + except sqlite3.DatabaseError as ex: + query_error = str(ex) + results = None + rows = [] + columns = [] + except (sqlite3.OperationalError, InvalidSql) as ex: + raise DatasetteError(str(ex), title="Invalid SQL", status=400) + except sqlite3.OperationalError as ex: + raise DatasetteError(str(ex)) + except DatasetteError: + raise + + # Handle formats from plugins + if format_ == "csv": + + async def fetch_data_for_csv(request, _next=None): + results = await db.execute(sql, params, truncate=True) + data = {"rows": results.rows, "columns": results.columns} + return data, None, None + + return await stream_csv(datasette, fetch_data_for_csv, request, db.name) + elif format_ in datasette.renderers.keys(): + # Dispatch request to the correct output format renderer + # (CSV is not handled here due to streaming) + result = call_with_supported_arguments( + datasette.renderers[format_][0], + datasette=datasette, + columns=columns, + rows=rows, + sql=sql, + query_name=None, + database=database, + table=None, + request=request, + view_name="table", + truncated=results.truncated if results else False, + error=query_error, + # These will be deprecated in Datasette 1.0: + args=request.args, + data={"rows": rows, "columns": columns}, + ) + if asyncio.iscoroutine(result): + result = await result + if result is None: + raise NotFound("No data") + if isinstance(result, dict): + r = Response( + 
body=result.get("body"), + status=result.get("status_code") or 200, + content_type=result.get("content_type", "text/plain"), + headers=result.get("headers"), + ) + elif isinstance(result, Response): + r = result + # if status_code is not None: + # # Over-ride the status code + # r.status = status_code + else: + assert False, f"{result} should be dict or Response" + elif format_ == "html": + headers = {} + templates = [f"query-{to_css_class(database)}.html", "query.html"] + template = datasette.jinja_env.select_template(templates) + alternate_url_json = datasette.absolute_url( + request, + datasette.urls.path(path_with_format(request=request, format="json")), + ) + data = {} + headers.update( + { + "Link": '{}; rel="alternate"; type="application/json+datasette"'.format( + alternate_url_json + ) + } + ) + metadata = (datasette.metadata("databases") or {}).get(database, {}) + datasette.update_with_inherited_metadata(metadata) + + renderers = {} + for key, (_, can_render) in datasette.renderers.items(): + it_can_render = call_with_supported_arguments( + can_render, + datasette=datasette, + columns=data.get("columns") or [], + rows=data.get("rows") or [], + sql=data.get("query", {}).get("sql", None), + query_name=data.get("query_name"), + database=database, + table=data.get("table"), + request=request, + view_name="database", + ) + it_can_render = await await_me_maybe(it_can_render) + if it_can_render: + renderers[key] = datasette.urls.path( + path_with_format(request=request, format=key) + ) + + allow_execute_sql = await datasette.permission_allowed( + request.actor, "execute-sql", database + ) + + show_hide_hidden = "" + if metadata.get("hide_sql"): + if bool(params.get("_show_sql")): + show_hide_link = path_with_removed_args(request, {"_show_sql"}) + show_hide_text = "hide" + show_hide_hidden = '' + else: + show_hide_link = path_with_added_args(request, {"_show_sql": 1}) + show_hide_text = "show" + else: + if bool(params.get("_hide_sql")): + show_hide_link = path_with_removed_args(request, {"_hide_sql"}) + show_hide_text = "show" + show_hide_hidden = '' + else: + show_hide_link = path_with_added_args(request, {"_hide_sql": 1}) + show_hide_text = "hide" + hide_sql = show_hide_text == "show" + + r = Response.html( + await datasette.render_template( + template, + QueryContext( + database=database, + query={ + "sql": sql, + # TODO: Params? 
+ }, + canned_query=None, + private=private, + canned_write=False, + db_is_immutable=not db.is_mutable, + error=query_error, + hide_sql=hide_sql, + show_hide_link=datasette.urls.path(show_hide_link), + show_hide_text=show_hide_text, + editable=True, # TODO + allow_execute_sql=allow_execute_sql, + tables=await get_tables(datasette, request, db), + named_parameter_values={}, # TODO + edit_sql_url="todo", + display_rows=await display_rows( + datasette, database, request, rows, columns + ), + table_columns=await _table_columns(datasette, database) + if allow_execute_sql + else {}, + columns=columns, + renderers=renderers, + url_csv=datasette.urls.path( + path_with_format( + request=request, format="csv", extra_qs={"_size": "max"} + ) + ), + show_hide_hidden=markupsafe.Markup(show_hide_hidden), + metadata=metadata, + database_color=lambda _: "#ff0000", + alternate_url_json=alternate_url_json, + ), + request=request, + view_name="database", + ), + headers=headers, + ) + else: + assert False, "Invalid format: {}".format(format_) + if datasette.cors: + add_cors_headers(r.headers) + return r + + class QueryView(DataView): async def data( self, @@ -404,7 +752,7 @@ class QueryView(DataView): display_value = plugin_display_value else: if value in ("", None): - display_value = Markup(" ") + display_value = markupsafe.Markup(" ") elif is_url(str(display_value).strip()): display_value = markupsafe.Markup( '{truncated_url}'.format( @@ -755,3 +1103,69 @@ async def _table_columns(datasette, database_name): for view_name in await db.view_names(): table_columns[view_name] = [] return table_columns + + +async def display_rows(datasette, database, request, rows, columns): + display_rows = [] + truncate_cells = datasette.setting("truncate_cells_html") + for row in rows: + display_row = [] + for column, value in zip(columns, row): + display_value = value + # Let the plugins have a go + # pylint: disable=no-member + plugin_display_value = None + for candidate in pm.hook.render_cell( + row=row, + value=value, + column=column, + table=None, + database=database, + datasette=datasette, + request=request, + ): + candidate = await await_me_maybe(candidate) + if candidate is not None: + plugin_display_value = candidate + break + if plugin_display_value is not None: + display_value = plugin_display_value + else: + if value in ("", None): + display_value = markupsafe.Markup(" ") + elif is_url(str(display_value).strip()): + display_value = markupsafe.Markup( + '{truncated_url}'.format( + url=markupsafe.escape(value.strip()), + truncated_url=markupsafe.escape( + truncate_url(value.strip(), truncate_cells) + ), + ) + ) + elif isinstance(display_value, bytes): + blob_url = path_with_format( + request=request, + format="blob", + extra_qs={ + "_blob_column": column, + "_blob_hash": hashlib.sha256(display_value).hexdigest(), + }, + ) + formatted = format_bytes(len(value)) + display_value = markupsafe.Markup( + '<Binary: {:,} byte{}>'.format( + blob_url, + ' title="{}"'.format(formatted) + if "bytes" not in formatted + else "", + len(value), + "" if len(value) == 1 else "s", + ) + ) + else: + display_value = str(value) + if truncate_cells and len(display_value) > truncate_cells: + display_value = display_value[:truncate_cells] + "\u2026" + display_row.append(display_value) + display_rows.append(display_row) + return display_rows diff --git a/datasette/views/table.py b/datasette/views/table.py index c102c103..77acfd95 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -833,6 +833,8 @@ async def 
table_view_traced(datasette, request): table=resolved.table, request=request, view_name="table", + truncated=False, + error=None, # These will be deprecated in Datasette 1.0: args=request.args, data=data, diff --git a/docs/plugin_hooks.rst b/docs/plugin_hooks.rst index 97306529..9bbe6fc6 100644 --- a/docs/plugin_hooks.rst +++ b/docs/plugin_hooks.rst @@ -516,6 +516,12 @@ When a request is received, the ``"render"`` callback function is called with ze ``request`` - :ref:`internals_request` The current HTTP request. +``error`` - string or None + If an error occurred this string will contain the error message. + +``truncated`` - bool or None + If the query response was truncated - for example a SQL query returning more than 1,000 results where pagination is not available - this will be ``True``. + ``view_name`` - string The name of the current view being called. ``index``, ``database``, ``table``, and ``row`` are the most important ones. diff --git a/tests/test_api.py b/tests/test_api.py index 40a3e2b8..28415a0b 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -638,22 +638,21 @@ def test_database_page_for_database_with_dot_in_name(app_client_with_dot): @pytest.mark.asyncio async def test_custom_sql(ds_client): response = await ds_client.get( - "/fixtures.json?sql=select+content+from+simple_primary_key&_shape=objects" + "/fixtures.json?sql=select+content+from+simple_primary_key" ) data = response.json() - assert {"sql": "select content from simple_primary_key", "params": {}} == data[ - "query" - ] - assert [ - {"content": "hello"}, - {"content": "world"}, - {"content": ""}, - {"content": "RENDER_CELL_DEMO"}, - {"content": "RENDER_CELL_ASYNC"}, - ] == data["rows"] - assert ["content"] == data["columns"] - assert "fixtures" == data["database"] - assert not data["truncated"] + assert data == { + "rows": [ + {"content": "hello"}, + {"content": "world"}, + {"content": ""}, + {"content": "RENDER_CELL_DEMO"}, + {"content": "RENDER_CELL_ASYNC"}, + ], + "columns": ["content"], + "ok": True, + "truncated": False, + } def test_sql_time_limit(app_client_shorter_time_limit): diff --git a/tests/test_cli_serve_get.py b/tests/test_cli_serve_get.py index ac44e1e2..2e0390bb 100644 --- a/tests/test_cli_serve_get.py +++ b/tests/test_cli_serve_get.py @@ -36,7 +36,6 @@ def test_serve_with_get(tmp_path_factory): ) assert 0 == result.exit_code, result.output assert { - "database": "_memory", "truncated": False, "columns": ["sqlite_version()"], }.items() <= json.loads(result.output).items() diff --git a/tests/test_html.py b/tests/test_html.py index eadbd720..6c3860d7 100644 --- a/tests/test_html.py +++ b/tests/test_html.py @@ -248,6 +248,9 @@ async def test_css_classes_on_body(ds_client, path, expected_classes): assert classes == expected_classes +templates_considered_re = re.compile(r"") + + @pytest.mark.asyncio @pytest.mark.parametrize( "path,expected_considered", @@ -271,7 +274,10 @@ async def test_css_classes_on_body(ds_client, path, expected_classes): async def test_templates_considered(ds_client, path, expected_considered): response = await ds_client.get(path) assert response.status_code == 200 - assert f"" in response.text + match = templates_considered_re.search(response.text) + assert match, "No templates considered comment found" + actual_considered = match.group(1) + assert actual_considered == expected_considered @pytest.mark.asyncio diff --git a/tests/test_internals_datasette.py b/tests/test_internals_datasette.py index 3d5bb2da..d59ff729 100644 --- a/tests/test_internals_datasette.py +++ 
b/tests/test_internals_datasette.py @@ -1,10 +1,12 @@ """ Tests for the datasette.app.Datasette class """ -from datasette import Forbidden +import dataclasses +from datasette import Forbidden, Context from datasette.app import Datasette, Database from itsdangerous import BadSignature import pytest +from typing import Optional @pytest.fixture @@ -136,6 +138,22 @@ async def test_datasette_render_template_no_request(): assert "Error " in rendered +@pytest.mark.asyncio +async def test_datasette_render_template_with_dataclass(): + @dataclasses.dataclass + class ExampleContext(Context): + title: str + status: int + error: str + + context = ExampleContext(title="Hello", status=200, error="Error message") + ds = Datasette(memory=True) + await ds.invoke_startup() + rendered = await ds.render_template("error.html", context) + assert "

<h1>Hello</h1>
" in rendered + assert "Error message" in rendered + + def test_datasette_error_if_string_not_list(tmpdir): # https://github.com/simonw/datasette/issues/1985 db_path = str(tmpdir / "data.db") diff --git a/tests/test_messages.py b/tests/test_messages.py index 8417b9ae..a7e4d046 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -12,7 +12,7 @@ import pytest ], ) async def test_add_message_sets_cookie(ds_client, qs, expected): - response = await ds_client.get(f"/fixtures.message?{qs}") + response = await ds_client.get(f"/fixtures.message?sql=select+1&{qs}") signed = response.cookies["ds_messages"] decoded = ds_client.ds.unsign(signed, "messages") assert expected == decoded @@ -21,7 +21,9 @@ async def test_add_message_sets_cookie(ds_client, qs, expected): @pytest.mark.asyncio async def test_messages_are_displayed_and_cleared(ds_client): # First set the message cookie - set_msg_response = await ds_client.get("/fixtures.message?add_msg=xmessagex") + set_msg_response = await ds_client.get( + "/fixtures.message?sql=select+1&add_msg=xmessagex" + ) # Now access a page that displays messages response = await ds_client.get("/", cookies=set_msg_response.cookies) # Messages should be in that HTML diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 6971bbf7..28fe720f 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -121,9 +121,8 @@ async def test_hook_extra_css_urls(ds_client, path, expected_decoded_object): ][0]["href"] # This link has a base64-encoded JSON blob in it encoded = special_href.split("/")[3] - assert expected_decoded_object == json.loads( - base64.b64decode(encoded).decode("utf8") - ) + actual_decoded_object = json.loads(base64.b64decode(encoded).decode("utf8")) + assert expected_decoded_object == actual_decoded_object @pytest.mark.asyncio diff --git a/tests/test_table_api.py b/tests/test_table_api.py index cd664ffb..46d1c9b8 100644 --- a/tests/test_table_api.py +++ b/tests/test_table_api.py @@ -700,7 +700,6 @@ async def test_max_returned_rows(ds_client): "/fixtures.json?sql=select+content+from+no_primary_key" ) data = response.json() - assert {"sql": "select content from no_primary_key", "params": {}} == data["query"] assert data["truncated"] assert 100 == len(data["rows"]) From cd57b0f71234273156cb1eba3f9153b9e27ac14d Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 8 Aug 2023 06:45:04 -0700 Subject: [PATCH 117/603] Brought back parameter fields, closes #2132 --- datasette/views/database.py | 19 +++++++++++++++++-- tests/test_html.py | 16 ++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/datasette/views/database.py b/datasette/views/database.py index 77f3f5b0..0770a380 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -506,6 +506,21 @@ async def query_view( show_hide_text = "hide" hide_sql = show_hide_text == "show" + # Extract any :named parameters + named_parameters = await derive_named_parameters( + datasette.get_database(database), sql + ) + named_parameter_values = { + named_parameter: params.get(named_parameter) or "" + for named_parameter in named_parameters + if not named_parameter.startswith("_") + } + + # Set to blank string if missing from params + for named_parameter in named_parameters: + if named_parameter not in params and not named_parameter.startswith("_"): + params[named_parameter] = "" + r = Response.html( await datasette.render_template( template, @@ -513,7 +528,7 @@ async def query_view( database=database, query={ "sql": sql, - # TODO: Params? 
+ "params": params, }, canned_query=None, private=private, @@ -526,7 +541,7 @@ async def query_view( editable=True, # TODO allow_execute_sql=allow_execute_sql, tables=await get_tables(datasette, request, db), - named_parameter_values={}, # TODO + named_parameter_values=named_parameter_values, edit_sql_url="todo", display_rows=await display_rows( datasette, database, request, rows, columns diff --git a/tests/test_html.py b/tests/test_html.py index 6c3860d7..7856bc27 100644 --- a/tests/test_html.py +++ b/tests/test_html.py @@ -295,6 +295,22 @@ async def test_query_json_csv_export_links(ds_client): assert 'CSV' in response.text +@pytest.mark.asyncio +async def test_query_parameter_form_fields(ds_client): + response = await ds_client.get("/fixtures?sql=select+:name") + assert response.status_code == 200 + assert ( + ' ' + in response.text + ) + response2 = await ds_client.get("/fixtures?sql=select+:name&name=hello") + assert response2.status_code == 200 + assert ( + ' ' + in response2.text + ) + + @pytest.mark.asyncio async def test_row_html_simple_primary_key(ds_client): response = await ds_client.get("/fixtures/simple_primary_key/1") From 26be9f0445b753fb84c802c356b0791a72269f25 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 9 Aug 2023 08:26:52 -0700 Subject: [PATCH 118/603] Refactored canned query code, replaced old QueryView, closes #2114 --- datasette/templates/query.html | 10 +- datasette/views/database.py | 840 +++++++++++++-------------------- datasette/views/table.py | 60 +-- tests/test_canned_queries.py | 8 +- 4 files changed, 345 insertions(+), 573 deletions(-) diff --git a/datasette/templates/query.html b/datasette/templates/query.html index 7ffc250a..fc3b8527 100644 --- a/datasette/templates/query.html +++ b/datasette/templates/query.html @@ -24,7 +24,7 @@ {% block content %} -{% if canned_write and db_is_immutable %} +{% if canned_query_write and db_is_immutable %}

This query cannot be executed because the database is immutable.
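A hedged aside, not part of this patch: the message above applies to "write" canned queries, which are defined by setting "write": true in the query's metadata. A minimal Python sketch, using made-up database and query names:

    from datasette.app import Datasette

    # A canned query with "write": True runs INSERT/UPDATE/DELETE statements,
    # so it needs a mutable database; against an immutable one the template
    # above renders this error message instead.
    ds = Datasette(
        memory=True,
        metadata={
            "databases": {
                "mydb": {
                    "queries": {
                        "add_name": {
                            "sql": "insert into names (name) values (:name)",
                            "write": True,
                        }
                    }
                }
            }
        },
    )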

{% endif %} @@ -32,7 +32,7 @@ {% block description_source_license %}{% include "_description_source_license.html" %}{% endblock %} - +

Custom SQL query{% if display_rows %} returning {% if truncated %}more than {% endif %}{{ "{:,}".format(display_rows|length) }} row{% if display_rows|length == 1 %}{% else %}s{% endif %}{% endif %}{% if not query_error %} ({{ show_hide_text }}) {% endif %}
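A hedged aside, not part of this patch: the truncated flag used in the heading above is also passed to output renderer callbacks, per the docs/plugin_hooks.rst addition earlier in this series. A minimal sketch of a renderer that consumes it; the ".status" extension name is invented:

    from datasette import hookimpl


    @hookimpl
    def register_output_renderer(datasette):
        return {"extension": "status", "render": render_status}


    def render_status(rows, error, truncated):
        # error is the error message string (or None); truncated is True when
        # the result set was cut off at the row limit
        body = "error={} truncated={} rows={}".format(error, truncated, len(rows))
        return {"body": body, "content_type": "text/plain"}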

@@ -61,8 +61,8 @@ {% endif %}

{% if not hide_sql %}{% endif %} - {% if canned_write %}{% endif %} - + {% if canned_query_write %}{% endif %} + {{ show_hide_hidden }} {% if canned_query and edit_sql_url %}Edit SQL{% endif %}

@@ -87,7 +87,7 @@ {% else %} - {% if not canned_write and not error %} + {% if not canned_query_write and not error %}

0 results

{% endif %} {% endif %} diff --git a/datasette/views/database.py b/datasette/views/database.py index 0770a380..658c35e6 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -1,4 +1,3 @@ -from asyncinject import Registry from dataclasses import dataclass, field from typing import Callable from urllib.parse import parse_qsl, urlencode @@ -33,7 +32,7 @@ from datasette.utils import ( from datasette.utils.asgi import AsgiFileDownload, NotFound, Response, Forbidden from datasette.plugins import pm -from .base import BaseView, DatasetteError, DataView, View, _error, stream_csv +from .base import BaseView, DatasetteError, View, _error, stream_csv class DatabaseView(View): @@ -57,7 +56,7 @@ class DatabaseView(View): sql = (request.args.get("sql") or "").strip() if sql: - return await query_view(request, datasette) + return await QueryView()(request, datasette) if format_ not in ("html", "json"): raise NotFound("Invalid format: {}".format(format_)) @@ -65,10 +64,6 @@ class DatabaseView(View): metadata = (datasette.metadata("databases") or {}).get(database, {}) datasette.update_with_inherited_metadata(metadata) - table_counts = await db.table_counts(5) - hidden_table_names = set(await db.hidden_table_names()) - all_foreign_keys = await db.get_all_foreign_keys() - sql_views = [] for view_name in await db.view_names(): view_visible, view_private = await datasette.check_visibility( @@ -196,8 +191,13 @@ class QueryContext: # urls: dict = field( # metadata={"help": "Object containing URL helpers like `database()`"} # ) - canned_write: bool = field( - metadata={"help": "Boolean indicating if this canned query allows writes"} + canned_query_write: bool = field( + metadata={ + "help": "Boolean indicating if this is a canned query that allows writes" + } + ) + metadata: dict = field( + metadata={"help": "Metadata about the database or the canned query"} ) db_is_immutable: bool = field( metadata={"help": "Boolean indicating if this database is immutable"} @@ -232,7 +232,6 @@ class QueryContext: show_hide_hidden: str = field( metadata={"help": "Hidden input field for the _show_sql parameter"} ) - metadata: dict = field(metadata={"help": "Metadata about the query/database"}) database_color: Callable = field( metadata={"help": "Function that returns a color for a given database name"} ) @@ -242,6 +241,12 @@ class QueryContext: alternate_url_json: str = field( metadata={"help": "URL for alternate JSON version of this page"} ) + # TODO: refactor this to somewhere else, probably ds.render_template() + select_templates: list = field( + metadata={ + "help": "List of templates that were considered for rendering this page" + } + ) async def get_tables(datasette, request, db): @@ -320,287 +325,105 @@ async def database_download(request, datasette): ) -async def query_view( - request, - datasette, - # canned_query=None, - # _size=None, - # named_parameters=None, - # write=False, -): - db = await datasette.resolve_database(request) - database = db.name - # Flattened because of ?sql=&name1=value1&name2=value2 feature - params = {key: request.args.get(key) for key in request.args} - sql = None - if "sql" in params: - sql = params.pop("sql") - if "_shape" in params: - params.pop("_shape") +class QueryView(View): + async def post(self, request, datasette): + from datasette.app import TableNotFound - # extras come from original request.args to avoid being flattened - extras = request.args.getlist("_extra") + db = await datasette.resolve_database(request) - # TODO: Behave differently for canned query 
here: - await datasette.ensure_permissions(request.actor, [("execute-sql", database)]) - - _, private = await datasette.check_visibility( - request.actor, - permissions=[ - ("view-database", database), - "view-instance", - ], - ) - - extra_args = {} - if params.get("_timelimit"): - extra_args["custom_time_limit"] = int(params["_timelimit"]) - - format_ = request.url_vars.get("format") or "html" - query_error = None - try: - validate_sql_select(sql) - results = await datasette.execute( - database, sql, params, truncate=True, **extra_args - ) - columns = results.columns - rows = results.rows - except QueryInterrupted as ex: - raise DatasetteError( - textwrap.dedent( - """ -

SQL query took too long. The time limit is controlled by the - sql_time_limit_ms - configuration option.
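A hedged aside, not part of this patch: the limit this message refers to comes from the sql_time_limit_ms setting, and the ?_timelimit= parameter handled a few lines above feeds a per-request limit through as custom_time_limit. A minimal sketch with a hypothetical database file:

    from datasette.app import Datasette

    # Cap arbitrary SQL queries at one second for this instance; an individual
    # request can still pass ?_timelimit=50 to ask for a 50ms limit instead.
    ds = Datasette(files=["fixtures.db"], settings={"sql_time_limit_ms": 1000})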

- - - """.format( - markupsafe.escape(ex.sql) - ) - ).strip(), - title="SQL Interrupted", - status=400, - message_is_html=True, - ) - except sqlite3.DatabaseError as ex: - query_error = str(ex) - results = None - rows = [] - columns = [] - except (sqlite3.OperationalError, InvalidSql) as ex: - raise DatasetteError(str(ex), title="Invalid SQL", status=400) - except sqlite3.OperationalError as ex: - raise DatasetteError(str(ex)) - except DatasetteError: - raise - - # Handle formats from plugins - if format_ == "csv": - - async def fetch_data_for_csv(request, _next=None): - results = await db.execute(sql, params, truncate=True) - data = {"rows": results.rows, "columns": results.columns} - return data, None, None - - return await stream_csv(datasette, fetch_data_for_csv, request, db.name) - elif format_ in datasette.renderers.keys(): - # Dispatch request to the correct output format renderer - # (CSV is not handled here due to streaming) - result = call_with_supported_arguments( - datasette.renderers[format_][0], - datasette=datasette, - columns=columns, - rows=rows, - sql=sql, - query_name=None, - database=database, - table=None, - request=request, - view_name="table", - truncated=results.truncated if results else False, - error=query_error, - # These will be deprecated in Datasette 1.0: - args=request.args, - data={"rows": rows, "columns": columns}, - ) - if asyncio.iscoroutine(result): - result = await result - if result is None: - raise NotFound("No data") - if isinstance(result, dict): - r = Response( - body=result.get("body"), - status=result.get("status_code") or 200, - content_type=result.get("content_type", "text/plain"), - headers=result.get("headers"), + # We must be a canned query + table_found = False + try: + await datasette.resolve_table(request) + table_found = True + except TableNotFound as table_not_found: + canned_query = await datasette.get_canned_query( + table_not_found.database_name, table_not_found.table, request.actor ) - elif isinstance(result, Response): - r = result - # if status_code is not None: - # # Over-ride the status code - # r.status = status_code - else: - assert False, f"{result} should be dict or Response" - elif format_ == "html": - headers = {} - templates = [f"query-{to_css_class(database)}.html", "query.html"] - template = datasette.jinja_env.select_template(templates) - alternate_url_json = datasette.absolute_url( - request, - datasette.urls.path(path_with_format(request=request, format="json")), - ) - data = {} - headers.update( - { - "Link": '{}; rel="alternate"; type="application/json+datasette"'.format( - alternate_url_json - ) - } - ) - metadata = (datasette.metadata("databases") or {}).get(database, {}) - datasette.update_with_inherited_metadata(metadata) + if canned_query is None: + raise + if table_found: + # That should not have happened + raise DatasetteError("Unexpected table found on POST", status=404) - renderers = {} - for key, (_, can_render) in datasette.renderers.items(): - it_can_render = call_with_supported_arguments( - can_render, - datasette=datasette, - columns=data.get("columns") or [], - rows=data.get("rows") or [], - sql=data.get("query", {}).get("sql", None), - query_name=data.get("query_name"), - database=database, - table=data.get("table"), - request=request, - view_name="database", + # If database is immutable, return an error + if not db.is_mutable: + raise Forbidden("Database is immutable") + + # Process the POST + body = await request.post_body() + body = body.decode("utf-8").strip() + if body.startswith("{") and 
body.endswith("}"): + params = json.loads(body) + # But we want key=value strings + for key, value in params.items(): + params[key] = str(value) + else: + params = dict(parse_qsl(body, keep_blank_values=True)) + # Should we return JSON? + should_return_json = ( + request.headers.get("accept") == "application/json" + or request.args.get("_json") + or params.get("_json") + ) + params_for_query = MagicParameters(params, request, datasette) + ok = None + redirect_url = None + try: + cursor = await db.execute_write(canned_query["sql"], params_for_query) + message = canned_query.get( + "on_success_message" + ) or "Query executed, {} row{} affected".format( + cursor.rowcount, "" if cursor.rowcount == 1 else "s" + ) + message_type = datasette.INFO + redirect_url = canned_query.get("on_success_redirect") + ok = True + except Exception as ex: + message = canned_query.get("on_error_message") or str(ex) + message_type = datasette.ERROR + redirect_url = canned_query.get("on_error_redirect") + ok = False + if should_return_json: + return Response.json( + { + "ok": ok, + "message": message, + "redirect": redirect_url, + } ) - it_can_render = await await_me_maybe(it_can_render) - if it_can_render: - renderers[key] = datasette.urls.path( - path_with_format(request=request, format=key) - ) - - allow_execute_sql = await datasette.permission_allowed( - request.actor, "execute-sql", database - ) - - show_hide_hidden = "" - if metadata.get("hide_sql"): - if bool(params.get("_show_sql")): - show_hide_link = path_with_removed_args(request, {"_show_sql"}) - show_hide_text = "hide" - show_hide_hidden = '' - else: - show_hide_link = path_with_added_args(request, {"_show_sql": 1}) - show_hide_text = "show" else: - if bool(params.get("_hide_sql")): - show_hide_link = path_with_removed_args(request, {"_hide_sql"}) - show_hide_text = "show" - show_hide_hidden = '' - else: - show_hide_link = path_with_added_args(request, {"_hide_sql": 1}) - show_hide_text = "hide" - hide_sql = show_hide_text == "show" + datasette.add_message(request, message, message_type) + return Response.redirect(redirect_url or request.path) - # Extract any :named parameters - named_parameters = await derive_named_parameters( - datasette.get_database(database), sql - ) - named_parameter_values = { - named_parameter: params.get(named_parameter) or "" - for named_parameter in named_parameters - if not named_parameter.startswith("_") - } + async def get(self, request, datasette): + from datasette.app import TableNotFound - # Set to blank string if missing from params - for named_parameter in named_parameters: - if named_parameter not in params and not named_parameter.startswith("_"): - params[named_parameter] = "" - - r = Response.html( - await datasette.render_template( - template, - QueryContext( - database=database, - query={ - "sql": sql, - "params": params, - }, - canned_query=None, - private=private, - canned_write=False, - db_is_immutable=not db.is_mutable, - error=query_error, - hide_sql=hide_sql, - show_hide_link=datasette.urls.path(show_hide_link), - show_hide_text=show_hide_text, - editable=True, # TODO - allow_execute_sql=allow_execute_sql, - tables=await get_tables(datasette, request, db), - named_parameter_values=named_parameter_values, - edit_sql_url="todo", - display_rows=await display_rows( - datasette, database, request, rows, columns - ), - table_columns=await _table_columns(datasette, database) - if allow_execute_sql - else {}, - columns=columns, - renderers=renderers, - url_csv=datasette.urls.path( - path_with_format( - 
request=request, format="csv", extra_qs={"_size": "max"} - ) - ), - show_hide_hidden=markupsafe.Markup(show_hide_hidden), - metadata=metadata, - database_color=lambda _: "#ff0000", - alternate_url_json=alternate_url_json, - ), - request=request, - view_name="database", - ), - headers=headers, - ) - else: - assert False, "Invalid format: {}".format(format_) - if datasette.cors: - add_cors_headers(r.headers) - return r - - -class QueryView(DataView): - async def data( - self, - request, - sql, - editable=True, - canned_query=None, - metadata=None, - _size=None, - named_parameters=None, - write=False, - default_labels=None, - ): - db = await self.ds.resolve_database(request) + db = await datasette.resolve_database(request) database = db.name - params = {key: request.args.get(key) for key in request.args} - if "sql" in params: - params.pop("sql") - if "_shape" in params: - params.pop("_shape") + + # Are we a canned query? + canned_query = None + canned_query_write = False + if "table" in request.url_vars: + try: + await datasette.resolve_table(request) + except TableNotFound as table_not_found: + # Was this actually a canned query? + canned_query = await datasette.get_canned_query( + table_not_found.database_name, table_not_found.table, request.actor + ) + if canned_query is None: + raise + canned_query_write = bool(canned_query.get("write")) private = False if canned_query: # Respect canned query permissions - visible, private = await self.ds.check_visibility( + visible, private = await datasette.check_visibility( request.actor, permissions=[ - ("view-query", (database, canned_query)), + ("view-query", (database, canned_query["name"])), ("view-database", database), "view-instance", ], @@ -609,18 +432,32 @@ class QueryView(DataView): raise Forbidden("You do not have permission to view this query") else: - await self.ds.ensure_permissions(request.actor, [("execute-sql", database)]) + await datasette.ensure_permissions( + request.actor, [("execute-sql", database)] + ) + + # Flattened because of ?sql=&name1=value1&name2=value2 feature + params = {key: request.args.get(key) for key in request.args} + sql = None + + if canned_query: + sql = canned_query["sql"] + elif "sql" in params: + sql = params.pop("sql") # Extract any :named parameters - named_parameters = named_parameters or await derive_named_parameters( - self.ds.get_database(database), sql - ) + named_parameters = [] + if canned_query and canned_query.get("params"): + named_parameters = canned_query["params"] + if not named_parameters: + named_parameters = await derive_named_parameters( + datasette.get_database(database), sql + ) named_parameter_values = { named_parameter: params.get(named_parameter) or "" for named_parameter in named_parameters if not named_parameter.startswith("_") } - # Set to blank string if missing from params for named_parameter in named_parameters: if named_parameter not in params and not named_parameter.startswith("_"): @@ -629,212 +466,159 @@ class QueryView(DataView): extra_args = {} if params.get("_timelimit"): extra_args["custom_time_limit"] = int(params["_timelimit"]) - if _size: - extra_args["page_size"] = _size - templates = [f"query-{to_css_class(database)}.html", "query.html"] - if canned_query: - templates.insert( - 0, - f"query-{to_css_class(database)}-{to_css_class(canned_query)}.html", - ) + format_ = request.url_vars.get("format") or "html" query_error = None + results = None + rows = [] + columns = [] - # Execute query - as write or as read - if write: - if request.method == "POST": - # If database 
is immutable, return an error - if not db.is_mutable: - raise Forbidden("Database is immutable") - body = await request.post_body() - body = body.decode("utf-8").strip() - if body.startswith("{") and body.endswith("}"): - params = json.loads(body) - # But we want key=value strings - for key, value in params.items(): - params[key] = str(value) - else: - params = dict(parse_qsl(body, keep_blank_values=True)) - # Should we return JSON? - should_return_json = ( - request.headers.get("accept") == "application/json" - or request.args.get("_json") - or params.get("_json") - ) - if canned_query: - params_for_query = MagicParameters(params, request, self.ds) - else: - params_for_query = params - ok = None - try: - cursor = await self.ds.databases[database].execute_write( - sql, params_for_query - ) - message = metadata.get( - "on_success_message" - ) or "Query executed, {} row{} affected".format( - cursor.rowcount, "" if cursor.rowcount == 1 else "s" - ) - message_type = self.ds.INFO - redirect_url = metadata.get("on_success_redirect") - ok = True - except Exception as e: - message = metadata.get("on_error_message") or str(e) - message_type = self.ds.ERROR - redirect_url = metadata.get("on_error_redirect") - ok = False - if should_return_json: - return Response.json( - { - "ok": ok, - "message": message, - "redirect": redirect_url, - } - ) - else: - self.ds.add_message(request, message, message_type) - return self.redirect(request, redirect_url or request.path) - else: + params_for_query = params - async def extra_template(): - return { - "request": request, - "db_is_immutable": not db.is_mutable, - "path_with_added_args": path_with_added_args, - "path_with_removed_args": path_with_removed_args, - "named_parameter_values": named_parameter_values, - "canned_query": canned_query, - "success_message": request.args.get("_success") or "", - "canned_write": True, - } - - return ( - { - "database": database, - "rows": [], - "truncated": False, - "columns": [], - "query": {"sql": sql, "params": params}, - "private": private, - }, - extra_template, - templates, - ) - else: # Not a write - if canned_query: - params_for_query = MagicParameters(params, request, self.ds) - else: - params_for_query = params + if not canned_query_write: try: - results = await self.ds.execute( + if not canned_query: + # For regular queries we only allow SELECT, plus other rules + validate_sql_select(sql) + else: + # Canned queries can run magic parameters + params_for_query = MagicParameters(params, request, datasette) + results = await datasette.execute( database, sql, params_for_query, truncate=True, **extra_args ) - columns = [r[0] for r in results.description] - except sqlite3.DatabaseError as e: - query_error = e + columns = results.columns + rows = results.rows + except QueryInterrupted as ex: + raise DatasetteError( + textwrap.dedent( + """ +

SQL query took too long. The time limit is controlled by the + sql_time_limit_ms + configuration option.

+ + + """.format( + markupsafe.escape(ex.sql) + ) + ).strip(), + title="SQL Interrupted", + status=400, + message_is_html=True, + ) + except sqlite3.DatabaseError as ex: + query_error = str(ex) results = None + rows = [] columns = [] + except (sqlite3.OperationalError, InvalidSql) as ex: + raise DatasetteError(str(ex), title="Invalid SQL", status=400) + except sqlite3.OperationalError as ex: + raise DatasetteError(str(ex)) + except DatasetteError: + raise - allow_execute_sql = await self.ds.permission_allowed( - request.actor, "execute-sql", database - ) + # Handle formats from plugins + if format_ == "csv": - async def extra_template(): - display_rows = [] - truncate_cells = self.ds.setting("truncate_cells_html") - for row in results.rows if results else []: - display_row = [] - for column, value in zip(results.columns, row): - display_value = value - # Let the plugins have a go - # pylint: disable=no-member - plugin_display_value = None - for candidate in pm.hook.render_cell( - row=row, - value=value, - column=column, - table=None, - database=database, - datasette=self.ds, - request=request, - ): - candidate = await await_me_maybe(candidate) - if candidate is not None: - plugin_display_value = candidate - break - if plugin_display_value is not None: - display_value = plugin_display_value - else: - if value in ("", None): - display_value = markupsafe.Markup(" ") - elif is_url(str(display_value).strip()): - display_value = markupsafe.Markup( - '{truncated_url}'.format( - url=markupsafe.escape(value.strip()), - truncated_url=markupsafe.escape( - truncate_url(value.strip(), truncate_cells) - ), - ) - ) - elif isinstance(display_value, bytes): - blob_url = path_with_format( - request=request, - format="blob", - extra_qs={ - "_blob_column": column, - "_blob_hash": hashlib.sha256( - display_value - ).hexdigest(), - }, - ) - formatted = format_bytes(len(value)) - display_value = markupsafe.Markup( - '<Binary: {:,} byte{}>'.format( - blob_url, - ' title="{}"'.format(formatted) - if "bytes" not in formatted - else "", - len(value), - "" if len(value) == 1 else "s", - ) - ) - else: - display_value = str(value) - if truncate_cells and len(display_value) > truncate_cells: - display_value = ( - display_value[:truncate_cells] + "\u2026" - ) - display_row.append(display_value) - display_rows.append(display_row) + async def fetch_data_for_csv(request, _next=None): + results = await db.execute(sql, params, truncate=True) + data = {"rows": results.rows, "columns": results.columns} + return data, None, None - # Show 'Edit SQL' button only if: - # - User is allowed to execute SQL - # - SQL is an approved SELECT statement - # - No magic parameters, so no :_ in the SQL string - edit_sql_url = None - is_validated_sql = False - try: - validate_sql_select(sql) - is_validated_sql = True - except InvalidSql: - pass - if allow_execute_sql and is_validated_sql and ":_" not in sql: - edit_sql_url = ( - self.ds.urls.database(database) - + "?" 
- + urlencode( - { - **{ - "sql": sql, - }, - **named_parameter_values, - } - ) + return await stream_csv(datasette, fetch_data_for_csv, request, db.name) + elif format_ in datasette.renderers.keys(): + # Dispatch request to the correct output format renderer + # (CSV is not handled here due to streaming) + result = call_with_supported_arguments( + datasette.renderers[format_][0], + datasette=datasette, + columns=columns, + rows=rows, + sql=sql, + query_name=canned_query["name"] if canned_query else None, + database=database, + table=None, + request=request, + view_name="table", + truncated=results.truncated if results else False, + error=query_error, + # These will be deprecated in Datasette 1.0: + args=request.args, + data={"rows": rows, "columns": columns}, + ) + if asyncio.iscoroutine(result): + result = await result + if result is None: + raise NotFound("No data") + if isinstance(result, dict): + r = Response( + body=result.get("body"), + status=result.get("status_code") or 200, + content_type=result.get("content_type", "text/plain"), + headers=result.get("headers"), + ) + elif isinstance(result, Response): + r = result + # if status_code is not None: + # # Over-ride the status code + # r.status = status_code + else: + assert False, f"{result} should be dict or Response" + elif format_ == "html": + headers = {} + templates = [f"query-{to_css_class(database)}.html", "query.html"] + if canned_query: + templates.insert( + 0, + f"query-{to_css_class(database)}-{to_css_class(canned_query['name'])}.html", ) + template = datasette.jinja_env.select_template(templates) + alternate_url_json = datasette.absolute_url( + request, + datasette.urls.path(path_with_format(request=request, format="json")), + ) + data = {} + headers.update( + { + "Link": '{}; rel="alternate"; type="application/json+datasette"'.format( + alternate_url_json + ) + } + ) + metadata = (datasette.metadata("databases") or {}).get(database, {}) + datasette.update_with_inherited_metadata(metadata) + + renderers = {} + for key, (_, can_render) in datasette.renderers.items(): + it_can_render = call_with_supported_arguments( + can_render, + datasette=datasette, + columns=data.get("columns") or [], + rows=data.get("rows") or [], + sql=data.get("query", {}).get("sql", None), + query_name=data.get("query_name"), + database=database, + table=data.get("table"), + request=request, + view_name="database", + ) + it_can_render = await await_me_maybe(it_can_render) + if it_can_render: + renderers[key] = datasette.urls.path( + path_with_format(request=request, format=key) + ) + + allow_execute_sql = await datasette.permission_allowed( + request.actor, "execute-sql", database + ) + show_hide_hidden = "" - if metadata.get("hide_sql"): + if canned_query and canned_query.get("hide_sql"): if bool(params.get("_show_sql")): show_hide_link = path_with_removed_args(request, {"_show_sql"}) show_hide_text = "hide" @@ -855,42 +639,86 @@ class QueryView(DataView): show_hide_link = path_with_added_args(request, {"_hide_sql": 1}) show_hide_text = "hide" hide_sql = show_hide_text == "show" - return { - "display_rows": display_rows, - "custom_sql": True, - "named_parameter_values": named_parameter_values, - "editable": editable, - "canned_query": canned_query, - "edit_sql_url": edit_sql_url, - "metadata": metadata, - "settings": self.ds.settings_dict(), - "request": request, - "show_hide_link": self.ds.urls.path(show_hide_link), - "show_hide_text": show_hide_text, - "show_hide_hidden": markupsafe.Markup(show_hide_hidden), - "hide_sql": hide_sql, - 
"table_columns": await _table_columns(self.ds, database) - if allow_execute_sql - else {}, - } - return ( - { - "ok": not query_error, - "database": database, - "query_name": canned_query, - "rows": results.rows if results else [], - "truncated": results.truncated if results else False, - "columns": columns, - "query": {"sql": sql, "params": params}, - "error": str(query_error) if query_error else None, - "private": private, - "allow_execute_sql": allow_execute_sql, - }, - extra_template, - templates, - 400 if query_error else 200, - ) + # Show 'Edit SQL' button only if: + # - User is allowed to execute SQL + # - SQL is an approved SELECT statement + # - No magic parameters, so no :_ in the SQL string + edit_sql_url = None + is_validated_sql = False + try: + validate_sql_select(sql) + is_validated_sql = True + except InvalidSql: + pass + if allow_execute_sql and is_validated_sql and ":_" not in sql: + edit_sql_url = ( + datasette.urls.database(database) + + "?" + + urlencode( + { + **{ + "sql": sql, + }, + **named_parameter_values, + } + ) + ) + + r = Response.html( + await datasette.render_template( + template, + QueryContext( + database=database, + query={ + "sql": sql, + "params": params, + }, + canned_query=canned_query["name"] if canned_query else None, + private=private, + canned_query_write=canned_query_write, + db_is_immutable=not db.is_mutable, + error=query_error, + hide_sql=hide_sql, + show_hide_link=datasette.urls.path(show_hide_link), + show_hide_text=show_hide_text, + editable=not canned_query, + allow_execute_sql=allow_execute_sql, + tables=await get_tables(datasette, request, db), + named_parameter_values=named_parameter_values, + edit_sql_url=edit_sql_url, + display_rows=await display_rows( + datasette, database, request, rows, columns + ), + table_columns=await _table_columns(datasette, database) + if allow_execute_sql + else {}, + columns=columns, + renderers=renderers, + url_csv=datasette.urls.path( + path_with_format( + request=request, format="csv", extra_qs={"_size": "max"} + ) + ), + show_hide_hidden=markupsafe.Markup(show_hide_hidden), + metadata=canned_query or metadata, + database_color=lambda _: "#ff0000", + alternate_url_json=alternate_url_json, + select_templates=[ + f"{'*' if template_name == template.name else ''}{template_name}" + for template_name in templates + ], + ), + request=request, + view_name="database", + ), + headers=headers, + ) + else: + assert False, "Invalid format: {}".format(format_) + if datasette.cors: + add_cors_headers(r.headers) + return r class MagicParameters(dict): diff --git a/datasette/views/table.py b/datasette/views/table.py index 77acfd95..28264e92 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -9,7 +9,6 @@ import markupsafe from datasette.plugins import pm from datasette.database import QueryInterrupted from datasette import tracer -from datasette.renderer import json_renderer from datasette.utils import ( add_cors_headers, await_me_maybe, @@ -21,7 +20,6 @@ from datasette.utils import ( tilde_encode, escape_sqlite, filters_should_redirect, - format_bytes, is_url, path_from_row_pks, path_with_added_args, @@ -38,7 +36,7 @@ from datasette.utils import ( from datasette.utils.asgi import BadRequest, Forbidden, NotFound, Response from datasette.filters import Filters import sqlite_utils -from .base import BaseView, DataView, DatasetteError, ureg, _error, stream_csv +from .base import BaseView, DatasetteError, ureg, _error, stream_csv from .database import QueryView LINK_WITH_LABEL = ( @@ -698,57 +696,6 @@ 
async def table_view(datasette, request): return response -class CannedQueryView(DataView): - def __init__(self, datasette): - self.ds = datasette - - async def post(self, request): - from datasette.app import TableNotFound - - try: - await self.ds.resolve_table(request) - except TableNotFound as e: - # Was this actually a canned query? - canned_query = await self.ds.get_canned_query( - e.database_name, e.table, request.actor - ) - if canned_query: - # Handle POST to a canned query - return await QueryView(self.ds).data( - request, - canned_query["sql"], - metadata=canned_query, - editable=False, - canned_query=e.table, - named_parameters=canned_query.get("params"), - write=bool(canned_query.get("write")), - ) - - return Response.text("Method not allowed", status=405) - - async def data(self, request, **kwargs): - from datasette.app import TableNotFound - - try: - await self.ds.resolve_table(request) - except TableNotFound as not_found: - canned_query = await self.ds.get_canned_query( - not_found.database_name, not_found.table, request.actor - ) - if canned_query: - return await QueryView(self.ds).data( - request, - canned_query["sql"], - metadata=canned_query, - editable=False, - canned_query=not_found.table, - named_parameters=canned_query.get("params"), - write=bool(canned_query.get("write")), - ) - else: - raise - - async def table_view_traced(datasette, request): from datasette.app import TableNotFound @@ -761,10 +708,7 @@ async def table_view_traced(datasette, request): ) # If this is a canned query, not a table, then dispatch to QueryView instead if canned_query: - if request.method == "POST": - return await CannedQueryView(datasette).post(request) - else: - return await CannedQueryView(datasette).get(request) + return await QueryView()(request, datasette) else: raise diff --git a/tests/test_canned_queries.py b/tests/test_canned_queries.py index d6a88733..e9ad3239 100644 --- a/tests/test_canned_queries.py +++ b/tests/test_canned_queries.py @@ -95,12 +95,12 @@ def test_insert(canned_write_client): csrftoken_from=True, cookies={"foo": "bar"}, ) - assert 302 == response.status - assert "/data/add_name?success" == response.headers["Location"] messages = canned_write_client.ds.unsign( response.cookies["ds_messages"], "messages" ) - assert [["Query executed, 1 row affected", 1]] == messages + assert messages == [["Query executed, 1 row affected", 1]] + assert response.status == 302 + assert response.headers["Location"] == "/data/add_name?success" @pytest.mark.parametrize( @@ -382,11 +382,11 @@ def test_magic_parameters_cannot_be_used_in_arbitrary_queries(magic_parameters_c def test_canned_write_custom_template(canned_write_client): response = canned_write_client.get("/data/update_name") assert response.status == 200 + assert "!!!CUSTOM_UPDATE_NAME_TEMPLATE!!!" in response.text assert ( "" in response.text ) - assert "!!!CUSTOM_UPDATE_NAME_TEMPLATE!!!" in response.text # And test for link rel=alternate while we're here: assert ( '' From 8920d425f4d417cfd998b61016c5ff3530cd34e1 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 9 Aug 2023 10:20:58 -0700 Subject: [PATCH 119/603] 1.0a3 release notes, smaller changes section - refs #2135 --- docs/changelog.rst | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index ee48d075..b4416f94 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,25 @@ Changelog ========= +.. 
_v1_0_a3: + +1.0a3 (2023-08-09) +------------------ + +This alpha release previews the updated design for Datasette's default JSON API. + +Smaller changes +~~~~~~~~~~~~~~~ + +- Datasette documentation now shows YAML examples for :ref:`metadata` by default, with a tab interface for switching to JSON. (:issue:`1153`) +- :ref:`plugin_register_output_renderer` plugins now have access to ``error`` and ``truncated`` arguments, allowing them to display error messages and take into account truncated results. (:issue:`2130`) +- ``render_cell()`` plugin hook now also supports an optional ``request`` argument. (:issue:`2007`) +- New ``Justfile`` to support development workflows for Datasette using `Just `__. +- ``datasette.render_template()`` can now accepts a ``datasette.views.Context`` subclass as an alternative to a dictionary. (:issue:`2127`) +- ``datasette install -e path`` option for editable installations, useful while developing plugins. (:issue:`2106`) +- When started with the ``--cors`` option Datasette now serves an ``Access-Control-Max-Age: 3600`` header, ensuring CORS OPTIONS requests are repeated no more than once an hour. (:issue:`2079`) +- Fixed a bug where the ``_internal`` database could display ``None`` instead of ``null`` for in-memory databases. (:issue:`1970`) + .. _v0_64_2: 0.64.2 (2023-03-08) From e34d09c6ec16ff5e7717e112afdad67f7c05a62a Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 9 Aug 2023 12:01:59 -0700 Subject: [PATCH 120/603] Don't include columns in query JSON, refs #2136 --- datasette/renderer.py | 8 +++++++- datasette/views/database.py | 2 +- tests/test_api.py | 1 - tests/test_cli_serve_get.py | 11 ++++++----- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/datasette/renderer.py b/datasette/renderer.py index 0bd74e81..224031a7 100644 --- a/datasette/renderer.py +++ b/datasette/renderer.py @@ -27,7 +27,7 @@ def convert_specific_columns_to_json(rows, columns, json_cols): return new_rows -def json_renderer(args, data, error, truncated=None): +def json_renderer(request, args, data, error, truncated=None): """Render a response as JSON""" status_code = 200 @@ -106,6 +106,12 @@ def json_renderer(args, data, error, truncated=None): "status": 400, "title": None, } + + # Don't include "columns" in output + # https://github.com/simonw/datasette/issues/2136 + if isinstance(data, dict) and "columns" not in request.args.getlist("_extra"): + data.pop("columns", None) + # Handle _nl option for _shape=array nl = args.get("_nl", "") if nl and shape == "array": diff --git a/datasette/views/database.py b/datasette/views/database.py index 658c35e6..cf76f3c2 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -548,7 +548,7 @@ class QueryView(View): error=query_error, # These will be deprecated in Datasette 1.0: args=request.args, - data={"rows": rows, "columns": columns}, + data={"ok": True, "rows": rows, "columns": columns}, ) if asyncio.iscoroutine(result): result = await result diff --git a/tests/test_api.py b/tests/test_api.py index 28415a0b..f96f571e 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -649,7 +649,6 @@ async def test_custom_sql(ds_client): {"content": "RENDER_CELL_DEMO"}, {"content": "RENDER_CELL_ASYNC"}, ], - "columns": ["content"], "ok": True, "truncated": False, } diff --git a/tests/test_cli_serve_get.py b/tests/test_cli_serve_get.py index 2e0390bb..dc7fc1e2 100644 --- a/tests/test_cli_serve_get.py +++ b/tests/test_cli_serve_get.py @@ -34,11 +34,12 @@ def test_serve_with_get(tmp_path_factory): 
"/_memory.json?sql=select+sqlite_version()", ], ) - assert 0 == result.exit_code, result.output - assert { - "truncated": False, - "columns": ["sqlite_version()"], - }.items() <= json.loads(result.output).items() + assert result.exit_code == 0, result.output + data = json.loads(result.output) + # Should have a single row with a single column + assert len(data["rows"]) == 1 + assert list(data["rows"][0].keys()) == ["sqlite_version()"] + assert set(data.keys()) == {"rows", "ok", "truncated"} # The plugin should have created hello.txt assert (plugins_dir / "hello.txt").read_text() == "hello" From 856ca68d94708c6e94673cb6bc28bf3e3ca17845 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 9 Aug 2023 12:04:40 -0700 Subject: [PATCH 121/603] Update default JSON representation docs, refs #2135 --- docs/json_api.rst | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/docs/json_api.rst b/docs/json_api.rst index c273c2a8..16b997eb 100644 --- a/docs/json_api.rst +++ b/docs/json_api.rst @@ -9,10 +9,10 @@ through the Datasette user interface can also be accessed as JSON via the API. To access the API for a page, either click on the ``.json`` link on that page or edit the URL and add a ``.json`` extension to it. -.. _json_api_shapes: +.. _json_api_default: -Different shapes ----------------- +Default representation +---------------------- The default JSON representation of data from a SQLite table or custom query looks like this: @@ -21,7 +21,6 @@ looks like this: { "ok": true, - "next": null, "rows": [ { "id": 3, @@ -39,13 +38,22 @@ looks like this: "id": 1, "name": "San Francisco" } - ] + ], + "truncated": false } -The ``rows`` key is a list of objects, each one representing a row. ``next`` indicates if -there is another page, and ``ok`` is always ``true`` if an error did not occur. +``"ok"`` is always ``true`` if an error did not occur. -If ``next`` is present then the next page in the pagination set can be retrieved using ``?_next=VALUE``. +The ``"rows"`` key is a list of objects, each one representing a row. + +The ``"truncated"`` key lets you know if the query was truncated. This can happen if a SQL query returns more than 1,000 results (or the :ref:`setting_max_returned_rows` setting). + +For table pages, an additional key ``"next"`` may be present. This indicates that the next page in the pagination set can be retrieved using ``?_next=VALUE``. + +.. _json_api_shapes: + +Different shapes +---------------- The ``_shape`` parameter can be used to access alternative formats for the ``rows`` key which may be more convenient for your application. There are three From 90cb9ca58d910f49e8f117bbdd94df6f0855cf99 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 9 Aug 2023 12:11:16 -0700 Subject: [PATCH 122/603] JSON changes in release notes, refs #2135 --- docs/changelog.rst | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index b4416f94..4c70855b 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -9,7 +9,40 @@ Changelog 1.0a3 (2023-08-09) ------------------ -This alpha release previews the updated design for Datasette's default JSON API. +This alpha release previews the updated design for Datasette's default JSON API. (:issue:`782`) + +The new :ref:`default JSON representation ` for both table pages (``/dbname/table.json``) and arbitrary SQL queries (``/dbname.json?sql=...``) is now shaped like this: + +.. 
code-block:: json + + { + "ok": true, + "rows": [ + { + "id": 3, + "name": "Detroit" + }, + { + "id": 2, + "name": "Los Angeles" + }, + { + "id": 4, + "name": "Memnonia" + }, + { + "id": 1, + "name": "San Francisco" + } + ], + "truncated": false + } + +Tables will include an additional ``"next"`` key for pagination, which can be passed to ``?_next=`` to fetch the next page of results. + +The various ``?_shape=`` options continue to work as before - see :ref:`json_api_shapes` for details. + +A new ``?_extra=`` mechanism is available for tables, but has not yet been stabilized or documented. Details on that are available in :issue:`262`. Smaller changes ~~~~~~~~~~~~~~~ From 19ab4552e212c9845a59461cc73e82d5ae8c278a Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 9 Aug 2023 12:13:11 -0700 Subject: [PATCH 123/603] Release 1.0a3 Closes #2135 Refs #262, #782, #1153, #1970, #2007, #2079, #2106, #2127, #2130 --- datasette/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasette/version.py b/datasette/version.py index 3b81ab21..61dee464 100644 --- a/datasette/version.py +++ b/datasette/version.py @@ -1,2 +1,2 @@ -__version__ = "1.0a2" +__version__ = "1.0a3" __version_info__ = tuple(__version__.split(".")) From 4a42476bb7ce4c5ed941f944115dedd9bce34656 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 9 Aug 2023 15:04:16 -0700 Subject: [PATCH 124/603] datasette plugins --requirements, closes #2133 --- datasette/cli.py | 12 ++++++++++-- docs/cli-reference.rst | 1 + docs/plugins.rst | 32 ++++++++++++++++++++++++++++---- tests/test_cli.py | 3 +++ 4 files changed, 42 insertions(+), 6 deletions(-) diff --git a/datasette/cli.py b/datasette/cli.py index 32266888..21fd25d6 100644 --- a/datasette/cli.py +++ b/datasette/cli.py @@ -223,15 +223,23 @@ pm.hook.publish_subcommand(publish=publish) @cli.command() @click.option("--all", help="Include built-in default plugins", is_flag=True) +@click.option( + "--requirements", help="Output requirements.txt of installed plugins", is_flag=True +) @click.option( "--plugins-dir", type=click.Path(exists=True, file_okay=False, dir_okay=True), help="Path to directory containing custom plugins", ) -def plugins(all, plugins_dir): +def plugins(all, requirements, plugins_dir): """List currently installed plugins""" app = Datasette([], plugins_dir=plugins_dir) - click.echo(json.dumps(app._plugins(all=all), indent=4)) + if requirements: + for plugin in app._plugins(): + if plugin["version"]: + click.echo("{}=={}".format(plugin["name"], plugin["version"])) + else: + click.echo(json.dumps(app._plugins(all=all), indent=4)) @cli.command() diff --git a/docs/cli-reference.rst b/docs/cli-reference.rst index 2177fc9e..7a96d311 100644 --- a/docs/cli-reference.rst +++ b/docs/cli-reference.rst @@ -282,6 +282,7 @@ Output JSON showing all currently installed plugins, their versions, whether the Options: --all Include built-in default plugins + --requirements Output requirements.txt of installed plugins --plugins-dir DIRECTORY Path to directory containing custom plugins --help Show this message and exit. diff --git a/docs/plugins.rst b/docs/plugins.rst index 979f94dd..19bfdd0c 100644 --- a/docs/plugins.rst +++ b/docs/plugins.rst @@ -90,7 +90,12 @@ You can see a list of installed plugins by navigating to the ``/-/plugins`` page You can also use the ``datasette plugins`` command:: - $ datasette plugins + datasette plugins + +Which outputs: + +.. 
code-block:: json + [ { "name": "datasette_json_html", @@ -107,7 +112,8 @@ You can also use the ``datasette plugins`` command:: cog.out("\n") result = CliRunner().invoke(cli.cli, ["plugins", "--all"]) # cog.out() with text containing newlines was unindenting for some reason - cog.outl("If you run ``datasette plugins --all`` it will include default plugins that ship as part of Datasette::\n") + cog.outl("If you run ``datasette plugins --all`` it will include default plugins that ship as part of Datasette:\n") + cog.outl(".. code-block:: json\n") plugins = [p for p in json.loads(result.output) if p["name"].startswith("datasette.")] indented = textwrap.indent(json.dumps(plugins, indent=4), " ") for line in indented.split("\n"): @@ -115,7 +121,9 @@ You can also use the ``datasette plugins`` command:: cog.out("\n\n") .. ]]] -If you run ``datasette plugins --all`` it will include default plugins that ship as part of Datasette:: +If you run ``datasette plugins --all`` it will include default plugins that ship as part of Datasette: + +.. code-block:: json [ { @@ -236,6 +244,22 @@ If you run ``datasette plugins --all`` it will include default plugins that ship You can add the ``--plugins-dir=`` option to include any plugins found in that directory. +Add ``--requirements`` to output a list of installed plugins that can then be installed in another Datasette instance using ``datasette install -r requirements.txt``:: + + datasette plugins --requirements + +The output will look something like this:: + + datasette-codespaces==0.1.1 + datasette-graphql==2.2 + datasette-json-html==1.0.1 + datasette-pretty-json==0.2.2 + datasette-x-forwarded-host==0.1 + +To write that to a ``requirements.txt`` file, run this:: + + datasette plugins --requirements > requirements.txt + .. _plugins_configuration: Plugin configuration @@ -390,7 +414,7 @@ Any values embedded in ``metadata.yaml`` will be visible to anyone who views the If you are publishing your data using the :ref:`datasette publish ` family of commands, you can use the ``--plugin-secret`` option to set these secrets at publish time. 
For example, using Heroku you might run the following command:: - $ datasette publish heroku my_database.db \ + datasette publish heroku my_database.db \ --name my-heroku-app-demo \ --install=datasette-auth-github \ --plugin-secret datasette-auth-github client_id your_client_id \ diff --git a/tests/test_cli.py b/tests/test_cli.py index 75724f61..056e2821 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -108,6 +108,9 @@ def test_plugins_cli(app_client): assert set(names).issuperset({p["name"] for p in EXPECTED_PLUGINS}) # And the following too: assert set(names).issuperset(DEFAULT_PLUGINS) + # --requirements should be empty because there are no installed non-plugins-dir plugins + result3 = runner.invoke(cli, ["plugins", "--requirements"]) + assert result3.output == "" def test_metadata_yaml(): From a3593c901580ea50854c3e0774b0ba0126e8a76f Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 9 Aug 2023 17:32:07 -0700 Subject: [PATCH 125/603] on_success_message_sql, closes #2138 --- datasette/views/database.py | 29 ++++++++++++++++---- docs/sql_queries.rst | 21 ++++++++++---- tests/test_canned_queries.py | 53 +++++++++++++++++++++++++++++++----- 3 files changed, 85 insertions(+), 18 deletions(-) diff --git a/datasette/views/database.py b/datasette/views/database.py index cf76f3c2..79b3f88d 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -360,6 +360,10 @@ class QueryView(View): params[key] = str(value) else: params = dict(parse_qsl(body, keep_blank_values=True)) + + # Don't ever send csrftoken as a SQL parameter + params.pop("csrftoken", None) + # Should we return JSON? should_return_json = ( request.headers.get("accept") == "application/json" @@ -371,12 +375,27 @@ class QueryView(View): redirect_url = None try: cursor = await db.execute_write(canned_query["sql"], params_for_query) - message = canned_query.get( - "on_success_message" - ) or "Query executed, {} row{} affected".format( - cursor.rowcount, "" if cursor.rowcount == 1 else "s" - ) + # success message can come from on_success_message or on_success_message_sql + message = None message_type = datasette.INFO + on_success_message_sql = canned_query.get("on_success_message_sql") + if on_success_message_sql: + try: + message_result = ( + await db.execute(on_success_message_sql, params_for_query) + ).first() + if message_result: + message = message_result[0] + except Exception as ex: + message = "Error running on_success_message_sql: {}".format(ex) + message_type = datasette.ERROR + if not message: + message = canned_query.get( + "on_success_message" + ) or "Query executed, {} row{} affected".format( + cursor.rowcount, "" if cursor.rowcount == 1 else "s" + ) + redirect_url = canned_query.get("on_success_redirect") ok = True except Exception as ex: diff --git a/docs/sql_queries.rst b/docs/sql_queries.rst index 3c2cb228..1ae07e1f 100644 --- a/docs/sql_queries.rst +++ b/docs/sql_queries.rst @@ -392,6 +392,7 @@ This configuration will create a page at ``/mydatabase/add_name`` displaying a f You can customize how Datasette represents success and errors using the following optional properties: - ``on_success_message`` - the message shown when a query is successful +- ``on_success_message_sql`` - alternative to ``on_success_message``: a SQL query that should be executed to generate the message - ``on_success_redirect`` - the path or URL the user is redirected to on success - ``on_error_message`` - the message shown when a query throws an error - ``on_error_redirect`` - the path or URL the user is 
redirected to on error @@ -405,11 +406,12 @@ For example: "queries": { "add_name": { "sql": "INSERT INTO names (name) VALUES (:name)", + "params": ["name"], "write": True, - "on_success_message": "Name inserted", + "on_success_message_sql": "select 'Name inserted: ' || :name", "on_success_redirect": "/mydatabase/names", "on_error_message": "Name insert failed", - "on_error_redirect": "/mydatabase" + "on_error_redirect": "/mydatabase", } } } @@ -426,8 +428,10 @@ For example: queries: add_name: sql: INSERT INTO names (name) VALUES (:name) + params: + - name write: true - on_success_message: Name inserted + on_success_message_sql: 'select ''Name inserted: '' || :name' on_success_redirect: /mydatabase/names on_error_message: Name insert failed on_error_redirect: /mydatabase @@ -443,8 +447,11 @@ For example: "queries": { "add_name": { "sql": "INSERT INTO names (name) VALUES (:name)", + "params": [ + "name" + ], "write": true, - "on_success_message": "Name inserted", + "on_success_message_sql": "select 'Name inserted: ' || :name", "on_success_redirect": "/mydatabase/names", "on_error_message": "Name insert failed", "on_error_redirect": "/mydatabase" @@ -455,10 +462,12 @@ For example: } .. [[[end]]] -You can use ``"params"`` to explicitly list the named parameters that should be displayed as form fields - otherwise they will be automatically detected. +You can use ``"params"`` to explicitly list the named parameters that should be displayed as form fields - otherwise they will be automatically detected. ``"params"`` is not necessary in the above example, since without it ``"name"`` would be automatically detected from the query. You can pre-populate form fields when the page first loads using a query string, e.g. ``/mydatabase/add_name?name=Prepopulated``. The user will have to submit the form to execute the query. +If you specify a query in ``"on_success_message_sql"``, that query will be executed after the main query. The first column of the first row return by that query will be displayed as a success message. Named parameters from the main query will be made available to the success message query as well. + .. _canned_queries_magic_parameters: Magic parameters @@ -589,7 +598,7 @@ The JSON response will look like this: "redirect": "/data/add_name" } -The ``"message"`` and ``"redirect"`` values here will take into account ``on_success_message``, ``on_success_redirect``, ``on_error_message`` and ``on_error_redirect``, if they have been set. +The ``"message"`` and ``"redirect"`` values here will take into account ``on_success_message``, ``on_success_message_sql``, ``on_success_redirect``, ``on_error_message`` and ``on_error_redirect``, if they have been set. .. 
_pagination: diff --git a/tests/test_canned_queries.py b/tests/test_canned_queries.py index e9ad3239..5256c24c 100644 --- a/tests/test_canned_queries.py +++ b/tests/test_canned_queries.py @@ -31,9 +31,15 @@ def canned_write_client(tmpdir): }, "add_name_specify_id": { "sql": "insert into names (rowid, name) values (:rowid, :name)", + "on_success_message_sql": "select 'Name added: ' || :name || ' with rowid ' || :rowid", "write": True, "on_error_redirect": "/data/add_name_specify_id?error", }, + "add_name_specify_id_with_error_in_on_success_message_sql": { + "sql": "insert into names (rowid, name) values (:rowid, :name)", + "on_success_message_sql": "select this is bad SQL", + "write": True, + }, "delete_name": { "sql": "delete from names where rowid = :rowid", "write": True, @@ -179,6 +185,34 @@ def test_insert_error(canned_write_client): ) +def test_on_success_message_sql(canned_write_client): + response = canned_write_client.post( + "/data/add_name_specify_id", + {"rowid": 5, "name": "Should be OK"}, + csrftoken_from=True, + ) + assert response.status == 302 + assert response.headers["Location"] == "/data/add_name_specify_id" + messages = canned_write_client.ds.unsign( + response.cookies["ds_messages"], "messages" + ) + assert messages == [["Name added: Should be OK with rowid 5", 1]] + + +def test_error_in_on_success_message_sql(canned_write_client): + response = canned_write_client.post( + "/data/add_name_specify_id_with_error_in_on_success_message_sql", + {"rowid": 1, "name": "Should fail"}, + csrftoken_from=True, + ) + messages = canned_write_client.ds.unsign( + response.cookies["ds_messages"], "messages" + ) + assert messages == [ + ["Error running on_success_message_sql: no such column: bad", 3] + ] + + def test_custom_params(canned_write_client): response = canned_write_client.get("/data/update_name?extra=foo") assert '' in response.text @@ -232,21 +266,22 @@ def test_canned_query_permissions_on_database_page(canned_write_client): query_names = { q["name"] for q in canned_write_client.get("/data.json").json["queries"] } - assert { + assert query_names == { + "add_name_specify_id_with_error_in_on_success_message_sql", + "from_hook", + "update_name", + "add_name_specify_id", + "from_async_hook", "canned_read", "add_name", - "add_name_specify_id", - "update_name", - "from_async_hook", - "from_hook", - } == query_names + } # With auth shows four response = canned_write_client.get( "/data.json", cookies={"ds_actor": canned_write_client.actor_cookie({"id": "root"})}, ) - assert 200 == response.status + assert response.status == 200 query_names_and_private = sorted( [ {"name": q["name"], "private": q["private"]} @@ -257,6 +292,10 @@ def test_canned_query_permissions_on_database_page(canned_write_client): assert query_names_and_private == [ {"name": "add_name", "private": False}, {"name": "add_name_specify_id", "private": False}, + { + "name": "add_name_specify_id_with_error_in_on_success_message_sql", + "private": False, + }, {"name": "canned_read", "private": False}, {"name": "delete_name", "private": True}, {"name": "from_async_hook", "private": False}, From 33251d04e78d575cca62bb59069bb43a7d924746 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 9 Aug 2023 17:56:27 -0700 Subject: [PATCH 126/603] Canned query write counters demo, refs #2134 --- .github/workflows/deploy-latest.yml | 30 +++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/.github/workflows/deploy-latest.yml b/.github/workflows/deploy-latest.yml index ed60376c..4746aa07 100644 --- 
a/.github/workflows/deploy-latest.yml +++ b/.github/workflows/deploy-latest.yml @@ -57,6 +57,36 @@ jobs: db.route = "alternative-route" ' > plugins/alternative_route.py cp fixtures.db fixtures2.db + - name: And the counters writable canned query demo + run: | + cat > plugins/counters.py < Date: Thu, 10 Aug 2023 22:16:19 -0700 Subject: [PATCH 127/603] Fixed display of database color Closes #2139, closes #2119 --- datasette/database.py | 7 +++++++ datasette/templates/database.html | 2 +- datasette/templates/query.html | 2 +- datasette/templates/row.html | 2 +- datasette/templates/table.html | 2 +- datasette/views/base.py | 4 ---- datasette/views/database.py | 8 +++----- datasette/views/index.py | 4 +--- datasette/views/row.py | 4 +++- datasette/views/table.py | 2 +- tests/test_html.py | 20 ++++++++++++++++++++ 11 files changed, 39 insertions(+), 18 deletions(-) diff --git a/datasette/database.py b/datasette/database.py index d8043c24..af39ac9e 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -1,6 +1,7 @@ import asyncio from collections import namedtuple from pathlib import Path +import hashlib import janus import queue import sys @@ -62,6 +63,12 @@ class Database: } return self._cached_table_counts + @property + def color(self): + if self.hash: + return self.hash[:6] + return hashlib.md5(self.name.encode("utf8")).hexdigest()[:6] + def suggest_name(self): if self.path: return Path(self.path).stem diff --git a/datasette/templates/database.html b/datasette/templates/database.html index 7acf0369..3d4dae07 100644 --- a/datasette/templates/database.html +++ b/datasette/templates/database.html @@ -10,7 +10,7 @@ {% block body_class %}db db-{{ database|to_css_class }}{% endblock %} {% block content %} -