diff --git a/datasette/app.py b/datasette/app.py index 6efaa430..d4425cb4 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -713,7 +713,46 @@ class Datasette: """ return _in_datasette_client.get() - def create_token( + def _abbreviate_action(self, action: str) -> str: + """Return the abbreviated form of an action name if one exists.""" + action_obj = self.actions.get(action) + if not action_obj: + return action + return action_obj.abbr or action + + def build_token_restrictions( + self, + restrict_all: Iterable[str] | None = None, + restrict_database: Dict[str, Iterable[str]] | None = None, + restrict_resource: Dict[str, Dict[str, Iterable[str]]] | None = None, + ) -> dict | None: + """Build an abbreviated restrictions dict for use in token payloads. + + Returns a dict like ``{"a": [...], "d": {...}, "r": {...}}`` + or ``None`` if there are no restrictions. + """ + if not (restrict_all or restrict_database or restrict_resource): + return None + restrictions: dict = {} + if restrict_all: + restrictions["a"] = [ + self._abbreviate_action(a) for a in restrict_all + ] + if restrict_database: + restrictions["d"] = { + database: [self._abbreviate_action(a) for a in actions] + for database, actions in restrict_database.items() + } + if restrict_resource: + restrictions["r"] = {} + for database, resources in restrict_resource.items(): + for resource, actions in resources.items(): + restrictions["r"].setdefault(database, {})[resource] = [ + self._abbreviate_action(a) for a in actions + ] + return restrictions + + def create_signed_token( self, actor_id: str, *, @@ -721,37 +760,52 @@ class Datasette: restrict_all: Iterable[str] | None = None, restrict_database: Dict[str, Iterable[str]] | None = None, restrict_resource: Dict[str, Dict[str, Iterable[str]]] | None = None, - ): - token = {"a": actor_id, "t": int(time.time())} + ) -> str: + """Create a signed ``dstok_`` API token. + + This always creates a signed token regardless of installed plugins. + Use :meth:`create_token` to go through the plugin hook instead. + """ + token: dict = {"a": actor_id, "t": int(time.time())} if expires_after: token["d"] = expires_after - - def abbreviate_action(action): - # rename to abbr if possible - action_obj = self.actions.get(action) - if not action_obj: - return action - return action_obj.abbr or action - - if expires_after: - token["d"] = expires_after - if restrict_all or restrict_database or restrict_resource: - token["_r"] = {} - if restrict_all: - token["_r"]["a"] = [abbreviate_action(a) for a in restrict_all] - if restrict_database: - token["_r"]["d"] = {} - for database, actions in restrict_database.items(): - token["_r"]["d"][database] = [abbreviate_action(a) for a in actions] - if restrict_resource: - token["_r"]["r"] = {} - for database, resources in restrict_resource.items(): - for resource, actions in resources.items(): - token["_r"]["r"].setdefault(database, {})[resource] = [ - abbreviate_action(a) for a in actions - ] + restrictions = self.build_token_restrictions( + restrict_all=restrict_all, + restrict_database=restrict_database, + restrict_resource=restrict_resource, + ) + if restrictions: + token["_r"] = restrictions return "dstok_{}".format(self.sign(token, namespace="token")) + async def create_token( + self, + actor_id: str, + *, + expires_after: int | None = None, + restrict_all: Iterable[str] | None = None, + restrict_database: Dict[str, Iterable[str]] | None = None, + restrict_resource: Dict[str, Dict[str, Iterable[str]]] | None = None, + ) -> str: + """Create an API token, dispatching through the ``create_token`` hook. + + If a plugin implements the ``create_token`` hook, it can return + a database-backed token instead of a signed one. The default + implementation creates a signed ``dstok_`` token. + """ + for result in pm.hook.create_token( + datasette=self, + actor_id=actor_id, + expires_after=expires_after, + restrict_all=restrict_all or [], + restrict_database=restrict_database or {}, + restrict_resource=restrict_resource or {}, + ): + result = await await_me_maybe(result) + if result is not None: + return result + raise RuntimeError("No create_token hook implementation returned a token") + def get_database(self, name=None, route=None): if route is not None: matches = [db for db in self.databases.values() if db.route == route] diff --git a/datasette/cli.py b/datasette/cli.py index 121911ab..d105715b 100644 --- a/datasette/cli.py +++ b/datasette/cli.py @@ -841,7 +841,7 @@ def create_token( action ) - token = ds.create_token( + token = ds.create_signed_token( id, expires_after=expires_after, restrict_all=alls, diff --git a/datasette/default_permissions/__init__.py b/datasette/default_permissions/__init__.py index 40373fa7..144fcafe 100644 --- a/datasette/default_permissions/__init__.py +++ b/datasette/default_permissions/__init__.py @@ -38,6 +38,7 @@ from .defaults import ( DEFAULT_ALLOW_ACTIONS as DEFAULT_ALLOW_ACTIONS, ) from .tokens import actor_from_signed_api_token as actor_from_signed_api_token +from .tokens import create_signed_api_token as create_signed_api_token @hookimpl diff --git a/datasette/default_permissions/tokens.py b/datasette/default_permissions/tokens.py index 474b0c23..f2769284 100644 --- a/datasette/default_permissions/tokens.py +++ b/datasette/default_permissions/tokens.py @@ -93,3 +93,19 @@ def actor_from_signed_api_token(datasette: "Datasette", request) -> Optional[dic actor["token_expires"] = created + duration return actor + + +@hookimpl(trylast=True, specname="create_token") +def create_signed_api_token(datasette, actor_id, expires_after, restrict_all, restrict_database, restrict_resource): + """Default create_token implementation: creates a signed dstok_ token. + + Runs last so that plugins like datasette-auth-tokens can override + by returning a database-backed token first. + """ + return datasette.create_signed_token( + actor_id, + expires_after=expires_after, + restrict_all=restrict_all, + restrict_database=restrict_database, + restrict_resource=restrict_resource, + ) diff --git a/datasette/hookspecs.py b/datasette/hookspecs.py index 89be6a65..8903b14f 100644 --- a/datasette/hookspecs.py +++ b/datasette/hookspecs.py @@ -242,3 +242,25 @@ def write_wrapper(datasette, database, request, transaction): Return ``None`` to skip wrapping. """ + + +@hookspec +def create_token(datasette, actor_id, expires_after, restrict_all, restrict_database, restrict_resource): + """Create an API token for the given actor. + + Return a token string, or ``None`` to let the next implementation + handle token creation. Implementations may be synchronous or + async (return an awaitable). + + Parameters mirror ``Datasette.create_token()``: + - ``actor_id``: the actor ID to embed in the token. + - ``expires_after``: seconds until expiry, or ``None``. + - ``restrict_all``: list of action names to restrict globally. + - ``restrict_database``: ``{database: [actions]}`` restrictions. + - ``restrict_resource``: ``{database: {resource: [actions]}}`` + restrictions. + + The default (``trylast``) implementation creates a signed + ``dstok_`` token. Plugins like ``datasette-auth-tokens`` can + override this to create database-backed tokens instead. + """ diff --git a/datasette/views/special.py b/datasette/views/special.py index 640c82eb..80040d20 100644 --- a/datasette/views/special.py +++ b/datasette/views/special.py @@ -731,7 +731,7 @@ class CreateTokenView(BaseView): resource, [] ).append(action) - token = self.ds.create_token( + token = self.ds.create_signed_token( request.actor["id"], expires_after=expires_after, restrict_all=restrict_all, diff --git a/tests/test_api_write.py b/tests/test_api_write.py index 05835e51..afc853f5 100644 --- a/tests/test_api_write.py +++ b/tests/test_api_write.py @@ -1362,7 +1362,7 @@ async def test_create_table( async def test_create_table_permissions( ds_write, permissions, body, expected_status, expected_errors ): - token = ds_write.create_token("root", restrict_all=["view-instance"] + permissions) + token = ds_write.create_signed_token("root", restrict_all=["view-instance"] + permissions) response = await ds_write.client.post( "/data/-/create", json=body, diff --git a/tests/test_auth.py b/tests/test_auth.py index 1e1cd622..cda09e08 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -491,3 +491,152 @@ async def test_root_without_root_enabled_no_special_permissions(ds_client): ) is not True ), "Root without root_enabled should not automatically get drop-table" + + +@pytest.mark.asyncio +async def test_create_signed_token_method(ds_client): + """create_signed_token() creates a dstok_ token synchronously""" + token = ds_client.ds.create_signed_token("test_actor") + assert token.startswith("dstok_") + decoded = ds_client.ds.unsign(token[len("dstok_"):], namespace="token") + assert decoded["a"] == "test_actor" + assert "t" in decoded + + +@pytest.mark.asyncio +async def test_create_signed_token_with_restrictions(ds_client): + """create_signed_token() respects restriction parameters""" + token = ds_client.ds.create_signed_token( + "test_actor", + expires_after=3600, + restrict_all=["view-instance"], + ) + decoded = ds_client.ds.unsign(token[len("dstok_"):], namespace="token") + assert decoded["a"] == "test_actor" + assert decoded["d"] == 3600 + assert "_r" in decoded + assert "a" in decoded["_r"] + + +@pytest.mark.asyncio +async def test_build_token_restrictions(ds_client): + """build_token_restrictions() returns abbreviated restriction dicts""" + ds = ds_client.ds + # No restrictions returns None + assert ds.build_token_restrictions() is None + + # With restrict_all + result = ds.build_token_restrictions(restrict_all=["view-instance"]) + assert result is not None + assert "a" in result + + # With restrict_database + result = ds.build_token_restrictions( + restrict_database={"mydb": ["view-table"]} + ) + assert result is not None + assert "d" in result + assert "mydb" in result["d"] + + # With restrict_resource + result = ds.build_token_restrictions( + restrict_resource={"mydb": {"mytable": ["insert-row"]}} + ) + assert result is not None + assert "r" in result + assert "mydb" in result["r"] + assert "mytable" in result["r"]["mydb"] + + +@pytest.mark.asyncio +async def test_create_token_hook_default(ds_client): + """The async create_token() method uses the default hook to create signed tokens""" + token = await ds_client.ds.create_token("test_actor") + assert token.startswith("dstok_") + decoded = ds_client.ds.unsign(token[len("dstok_"):], namespace="token") + assert decoded["a"] == "test_actor" + + +@pytest.mark.asyncio +async def test_create_token_hook_default_with_restrictions(ds_client): + """The async create_token() with restrictions creates proper signed tokens""" + token = await ds_client.ds.create_token( + "test_actor", + expires_after=7200, + restrict_all=["view-instance"], + restrict_database={"fixtures": ["view-table"]}, + ) + assert token.startswith("dstok_") + decoded = ds_client.ds.unsign(token[len("dstok_"):], namespace="token") + assert decoded["a"] == "test_actor" + assert decoded["d"] == 7200 + assert "_r" in decoded + + +@pytest.mark.asyncio +async def test_create_token_hook_can_be_overridden(ds_client): + """A plugin can override create_token to return a custom token""" + from datasette import hookimpl + from datasette.plugins import pm + + class CustomTokenPlugin: + __name__ = "custom_token_test_plugin" + + @staticmethod + @hookimpl(specname="create_token") + def custom_create_token(datasette, actor_id, expires_after, restrict_all, restrict_database, restrict_resource): + return "custom_token_for_{}".format(actor_id) + + pm.register(CustomTokenPlugin, name="custom_token_test_plugin") + try: + token = await ds_client.ds.create_token("myactor") + assert token == "custom_token_for_myactor" + finally: + pm.unregister(name="custom_token_test_plugin") + + +@pytest.mark.asyncio +async def test_create_token_hook_async_override(ds_client): + """A plugin can override create_token with an async implementation""" + from datasette import hookimpl + from datasette.plugins import pm + + class AsyncTokenPlugin: + __name__ = "async_token_test_plugin" + + @staticmethod + @hookimpl(specname="create_token") + def async_create_token(datasette, actor_id, expires_after, restrict_all, restrict_database, restrict_resource): + async def inner(): + return "async_token_for_{}".format(actor_id) + return inner() + + pm.register(AsyncTokenPlugin, name="async_token_test_plugin") + try: + token = await ds_client.ds.create_token("myactor") + assert token == "async_token_for_myactor" + finally: + pm.unregister(name="async_token_test_plugin") + + +@pytest.mark.asyncio +async def test_create_token_hook_none_falls_through(ds_client): + """If a plugin returns None, the default signed token implementation is used""" + from datasette import hookimpl + from datasette.plugins import pm + + class NoneTokenPlugin: + __name__ = "none_token_test_plugin" + + @staticmethod + @hookimpl(specname="create_token") + def none_create_token(datasette, actor_id, expires_after, restrict_all, restrict_database, restrict_resource): + return None + + pm.register(NoneTokenPlugin, name="none_token_test_plugin") + try: + token = await ds_client.ds.create_token("test_actor") + # Should fall through to default signed token + assert token.startswith("dstok_") + finally: + pm.unregister(name="none_token_test_plugin") diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 96c0cf6f..5adcac5b 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -1657,7 +1657,7 @@ async def test_permission_check_view_requires_debug_permission(): # Root user should have access (root has all permissions) ds_with_root = Datasette() ds_with_root.root_enabled = True - root_token = ds_with_root.create_token("root") + root_token = ds_with_root.create_signed_token("root") response = await ds_with_root.client.get( "/-/check.json?action=view-instance", headers={"Authorization": f"Bearer {root_token}"},