diff --git a/datasette/app.py b/datasette/app.py index 5f2a484e..a5efdad5 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -2,6 +2,7 @@ from __future__ import annotations from asgi_csrf import Errors import asyncio +import contextvars from typing import TYPE_CHECKING, Any, Dict, Iterable, List if TYPE_CHECKING: @@ -130,6 +131,22 @@ from .resources import DatabaseResource, TableResource app_root = Path(__file__).parent.parent +# Context variable to track when code is executing within a datasette.client request +_in_datasette_client = contextvars.ContextVar("in_datasette_client", default=False) + + +class _DatasetteClientContext: + """Context manager to mark code as executing within a datasette.client request.""" + + def __enter__(self): + self.token = _in_datasette_client.set(True) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + _in_datasette_client.reset(self.token) + return False + + @dataclasses.dataclass class PermissionCheck: """Represents a logged permission check for debugging purposes.""" @@ -666,6 +683,14 @@ class Datasette: def unsign(self, signed, namespace="default"): return URLSafeSerializer(self._secret, namespace).loads(signed) + def in_client(self) -> bool: + """Check if the current code is executing within a datasette.client request. + + Returns: + bool: True if currently executing within a datasette.client request, False otherwise. + """ + return _in_datasette_client.get() + def create_token( self, actor_id: str, @@ -2406,19 +2431,20 @@ class DatasetteClient: async def _request(self, method, path, skip_permission_checks=False, **kwargs): from datasette.permissions import SkipPermissions - if skip_permission_checks: - with SkipPermissions(): + with _DatasetteClientContext(): + if skip_permission_checks: + with SkipPermissions(): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=self.app), + cookies=kwargs.pop("cookies", None), + ) as client: + return await getattr(client, method)(self._fix(path), **kwargs) + else: async with httpx.AsyncClient( transport=httpx.ASGITransport(app=self.app), cookies=kwargs.pop("cookies", None), ) as client: return await getattr(client, method)(self._fix(path), **kwargs) - else: - async with httpx.AsyncClient( - transport=httpx.ASGITransport(app=self.app), - cookies=kwargs.pop("cookies", None), - ) as client: - return await getattr(client, method)(self._fix(path), **kwargs) async def get(self, path, skip_permission_checks=False, **kwargs): return await self._request( @@ -2470,8 +2496,17 @@ class DatasetteClient: from datasette.permissions import SkipPermissions avoid_path_rewrites = kwargs.pop("avoid_path_rewrites", None) - if skip_permission_checks: - with SkipPermissions(): + with _DatasetteClientContext(): + if skip_permission_checks: + with SkipPermissions(): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=self.app), + cookies=kwargs.pop("cookies", None), + ) as client: + return await client.request( + method, self._fix(path, avoid_path_rewrites), **kwargs + ) + else: async with httpx.AsyncClient( transport=httpx.ASGITransport(app=self.app), cookies=kwargs.pop("cookies", None), @@ -2479,11 +2514,3 @@ class DatasetteClient: return await client.request( method, self._fix(path, avoid_path_rewrites), **kwargs ) - else: - async with httpx.AsyncClient( - transport=httpx.ASGITransport(app=self.app), - cookies=kwargs.pop("cookies", None), - ) as client: - return await client.request( - method, self._fix(path, avoid_path_rewrites), **kwargs - ) diff --git a/docs/internals.rst b/docs/internals.rst index 2e01a8e8..09fb7572 100644 --- a/docs/internals.rst +++ b/docs/internals.rst @@ -1077,6 +1077,28 @@ This parameter works with all HTTP methods (``get``, ``post``, ``put``, ``patch` Use ``skip_permission_checks=True`` with caution. It completely bypasses Datasette's permission system and should only be used in trusted plugin code or internal operations where you need guaranteed access to resources. +Detecting internal client requests +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +``datasette.in_client()`` - returns bool + Returns ``True`` if the current code is executing within a ``datasette.client`` request, ``False`` otherwise. + +This method is useful for plugins that need to behave differently when called through ``datasette.client`` versus when handling external HTTP requests. + +Example usage: + +.. code-block:: python + + async def fetch_documents(datasette): + if not datasette.in_client(): + return Response.text( + "Only available via internal client requests", + status=403 + ) + ... + +Note that ``datasette.in_client()`` is independent of ``skip_permission_checks``. A request made through ``datasette.client`` will always have ``in_client()`` return ``True``, regardless of whether ``skip_permission_checks`` is set. + .. _internals_datasette_urls: datasette.urls diff --git a/tests/test_internals_datasette_client.py b/tests/test_internals_datasette_client.py index a15d294f..b254c5e4 100644 --- a/tests/test_internals_datasette_client.py +++ b/tests/test_internals_datasette_client.py @@ -227,3 +227,89 @@ async def test_skip_permission_checks_shows_denied_tables(): table_names = [match["name"] for match in data["matches"]] # Should see fixtures tables when permission checks are skipped assert "fixtures: test_table" in table_names + + +@pytest.mark.asyncio +async def test_in_client_returns_false_outside_request(datasette): + """Test that datasette.in_client() returns False outside of a client request""" + assert datasette.in_client() is False + + +@pytest.mark.asyncio +async def test_in_client_returns_true_inside_request(): + """Test that datasette.in_client() returns True inside a client request""" + from datasette import hookimpl, Response + from datasette.plugins import pm + + class TestPlugin: + __name__ = "test_in_client_plugin" + + @hookimpl + def register_routes(self): + async def test_view(datasette): + # Assert in_client() returns True within the view + assert datasette.in_client() is True + return Response.json({"in_client": datasette.in_client()}) + + return [ + (r"^/-/test-in-client$", test_view), + ] + + pm.register(TestPlugin(), name="test_in_client_plugin") + try: + ds = Datasette() + await ds.invoke_startup() + + # Outside of a client request, should be False + assert ds.in_client() is False + + # Make a request via datasette.client + response = await ds.client.get("/-/test-in-client") + assert response.status_code == 200 + assert response.json()["in_client"] is True + + # After the request, should be False again + assert ds.in_client() is False + finally: + pm.unregister(name="test_in_client_plugin") + + +@pytest.mark.asyncio +async def test_in_client_with_skip_permission_checks(): + """Test that in_client() works regardless of skip_permission_checks value""" + from datasette import hookimpl + from datasette.plugins import pm + from datasette.utils.asgi import Response + + in_client_values = [] + + class TestPlugin: + __name__ = "test_in_client_skip_plugin" + + @hookimpl + def register_routes(self): + async def test_view(datasette): + in_client_values.append(datasette.in_client()) + return Response.json({"in_client": datasette.in_client()}) + + return [ + (r"^/-/test-in-client$", test_view), + ] + + pm.register(TestPlugin(), name="test_in_client_skip_plugin") + try: + ds = Datasette(config={"databases": {"test_db": {"allow": {"id": "admin"}}}}) + await ds.invoke_startup() + + # Request without skip_permission_checks + await ds.client.get("/-/test-in-client") + # Request with skip_permission_checks=True + await ds.client.get("/-/test-in-client", skip_permission_checks=True) + + # Both should have detected in_client as True + assert ( + len(in_client_values) == 2 + ), f"Expected 2 values, got {len(in_client_values)}" + assert all(in_client_values), f"Expected all True, got {in_client_values}" + finally: + pm.unregister(name="test_in_client_skip_plugin")