From dce718961cc9dbadb7ade1f12ed09074ef746532 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Fri, 15 Nov 2024 13:17:45 -0800 Subject: [PATCH] Async support for magic parameters Closes #2441 --- datasette/views/database.py | 25 ++++++++++++++++++++++--- docs/plugin_hooks.rst | 6 +++++- tests/plugins/my_plugin.py | 4 ++++ tests/test_plugins.py | 7 +++++++ 4 files changed, 38 insertions(+), 4 deletions(-) diff --git a/datasette/views/database.py b/datasette/views/database.py index 61fe15e4..7b081eae 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -391,7 +391,10 @@ class QueryView(View): or request.args.get("_json") or params.get("_json") ) - params_for_query = MagicParameters(params, request, datasette) + params_for_query = MagicParameters( + canned_query["sql"], params, request, datasette + ) + await params_for_query.execute_params() ok = None redirect_url = None try: @@ -523,7 +526,8 @@ class QueryView(View): validate_sql_select(sql) else: # Canned queries can run magic parameters - params_for_query = MagicParameters(params, request, datasette) + params_for_query = MagicParameters(sql, params, request, datasette) + await params_for_query.execute_params() results = await datasette.execute( database, sql, params_for_query, truncate=True, **extra_args ) @@ -792,14 +796,26 @@ class QueryView(View): class MagicParameters(dict): - def __init__(self, data, request, datasette): + def __init__(self, sql, data, request, datasette): super().__init__(data) + self._sql = sql self._request = request self._magics = dict( itertools.chain.from_iterable( pm.hook.register_magic_parameters(datasette=datasette) ) ) + self._prepared = {} + + async def execute_params(self): + for key in derive_named_parameters(self._sql): + if key.startswith("_") and key.count("_") >= 2: + prefix, suffix = key[1:].split("_", 1) + if prefix in self._magics: + result = await await_me_maybe( + self._magics[prefix](suffix, self._request) + ) + self._prepared[key] = result def __len__(self): # Workaround for 'Incorrect number of bindings' error @@ -808,6 +824,9 @@ class MagicParameters(dict): def __getitem__(self, key): if key.startswith("_") and key.count("_") >= 2: + if key in self._prepared: + return self._prepared[key] + # Try the other route prefix, suffix = key[1:].split("_", 1) if prefix in self._magics: try: diff --git a/docs/plugin_hooks.rst b/docs/plugin_hooks.rst index 5f735a31..a844828f 100644 --- a/docs/plugin_hooks.rst +++ b/docs/plugin_hooks.rst @@ -1315,7 +1315,7 @@ Magic parameters all take this format: ``_prefix_rest_of_parameter``. The prefix To register a new function, return it as a tuple of ``(string prefix, function)`` from this hook. The function you register should take two arguments: ``key`` and ``request``, where ``key`` is the ``rest_of_parameter`` portion of the parameter and ``request`` is the current :ref:`internals_request`. -This example registers two new magic parameters: ``:_request_http_version`` returning the HTTP version of the current request, and ``:_uuid_new`` which returns a new UUID: +This example registers two new magic parameters: ``:_request_http_version`` returning the HTTP version of the current request, and ``:_uuid_new`` which returns a new UUID. It also registers an `:_asynclookup_key` parameter, demonstrating that these functions can be asynchronous: .. code-block:: python @@ -1337,11 +1337,15 @@ This example registers two new magic parameters: ``:_request_http_version`` retu raise KeyError + async def asynclookup(key, request): + return await do_something_async(key) + @hookimpl def register_magic_parameters(datasette): return [ ("request", request), ("uuid", uuid), + ("asynclookup", asynclookup), ] .. _plugin_hook_forbidden: diff --git a/tests/plugins/my_plugin.py b/tests/plugins/my_plugin.py index e87353ea..54c59227 100644 --- a/tests/plugins/my_plugin.py +++ b/tests/plugins/my_plugin.py @@ -360,9 +360,13 @@ def register_magic_parameters(): else: raise KeyError + async def asyncrequest(key, request): + return key + return [ ("request", request), ("uuid", uuid), + ("asyncrequest", asyncrequest), ] diff --git a/tests/test_plugins.py b/tests/test_plugins.py index aa8f1578..639e6677 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -857,6 +857,9 @@ def test_hook_register_magic_parameters(restore_working_directory): "get_uuid": { "sql": "select :_uuid_new", }, + "asyncrequest": { + "sql": "select :_asyncrequest_key", + }, } } } @@ -871,6 +874,10 @@ def test_hook_register_magic_parameters(restore_working_directory): assert 200 == response_get.status new_uuid = response_get.json[0][":_uuid_new"] assert 4 == new_uuid.count("-") + # And test the async one + response_async = client.get("/data/asyncrequest.json?_shape=array") + assert 200 == response_async.status + assert response_async.json[0][":_asyncrequest_key"] == "key" def test_hook_forbidden(restore_working_directory):