diff --git a/datasette/tokens.py b/datasette/tokens.py index 5a12d8e0..38a55529 100644 --- a/datasette/tokens.py +++ b/datasette/tokens.py @@ -52,6 +52,38 @@ class TokenRestrictions: self.resource.setdefault(database, {}).setdefault(resource, []).append(action) return self + def abbreviated(self, datasette: "Datasette") -> Optional[dict]: + """ + Return the abbreviated ``_r`` dictionary shape for this set of + restrictions, using action abbreviations registered with ``datasette``. + Returns ``None`` if no restrictions are set. + """ + if not (self.all or self.database or self.resource): + return None + + def abbreviate_action(action): + action_obj = datasette.actions.get(action) + if not action_obj: + return action + return action_obj.abbr or action + + result: dict = {} + if self.all: + result["a"] = [abbreviate_action(a) for a in self.all] + if self.database: + result["d"] = { + database: [abbreviate_action(a) for a in actions] + for database, actions in self.database.items() + } + if self.resource: + result["r"] = {} + for database, resources in self.resource.items(): + for resource, actions in resources.items(): + result["r"].setdefault(database, {})[resource] = [ + abbreviate_action(a) for a in actions + ] + return result + class TokenHandler: """ @@ -104,31 +136,12 @@ class SignedTokenHandler(TokenHandler): token = {"a": actor_id, "t": int(time.time())} - def abbreviate_action(action): - action_obj = datasette.actions.get(action) - if not action_obj: - return action - return action_obj.abbr or action - if expires_after: token["d"] = expires_after - if restrictions and ( - restrictions.all or restrictions.database or restrictions.resource - ): - token["_r"] = {} - if restrictions.all: - token["_r"]["a"] = [abbreviate_action(a) for a in restrictions.all] - if restrictions.database: - token["_r"]["d"] = {} - for database, actions in restrictions.database.items(): - token["_r"]["d"][database] = [abbreviate_action(a) for a in actions] - if restrictions.resource: - token["_r"]["r"] = {} - for database, resources in restrictions.resource.items(): - for resource, actions in resources.items(): - token["_r"]["r"].setdefault(database, {})[resource] = [ - abbreviate_action(a) for a in actions - ] + if restrictions is not None: + abbreviated = restrictions.abbreviated(datasette) + if abbreviated is not None: + token["_r"] = abbreviated return "dstok_{}".format(datasette.sign(token, namespace="token")) async def verify_token(self, datasette: "Datasette", token: str) -> Optional[dict]: diff --git a/docs/internals.rst b/docs/internals.rst index 2710345b..e0123a7b 100644 --- a/docs/internals.rst +++ b/docs/internals.rst @@ -729,6 +729,30 @@ The builder methods are: Each method returns the ``TokenRestrictions`` instance so calls can be chained. +``TokenRestrictions`` also provides an ``abbreviated(datasette)`` method which returns the restrictions as a dictionary using the compact format described in :ref:`authentication_cli_create_token_restrict`, with action names replaced by their registered abbreviations. It returns the inner dictionary only - the ``"_r"`` wrapping key shown in that section is not included. Returns ``None`` if no restrictions are set. This is useful when writing a custom :ref:`plugin_hook_register_token_handler` that needs to embed restrictions in a token payload. + +For example, the following restrictions: + +.. code-block:: python + + restrictions = ( + TokenRestrictions() + .allow_all("view-instance") + .allow_database("docs", "view-query") + .allow_resource("docs", "attachments", "insert-row") + ) + restrictions.abbreviated(datasette) + +Returns this dictionary, using the abbreviations registered for each action: + +.. code-block:: python + + { + "a": ["vi"], + "d": {"docs": ["vq"]}, + "r": {"docs": {"attachments": ["ir"]}}, + } + The following example creates a token that can access ``view-instance`` and ``view-table`` across everything, can additionally use ``view-query`` for anything in the ``docs`` database and is allowed to execute ``insert-row`` and ``update-row`` in the ``attachments`` table in that database: .. code-block:: python diff --git a/tests/test_token_handler.py b/tests/test_token_handler.py index 83f09046..5c87f577 100644 --- a/tests/test_token_handler.py +++ b/tests/test_token_handler.py @@ -291,6 +291,43 @@ async def test_expires_after_round_trip(datasette): assert "token_expires" in actor +@pytest.mark.asyncio +@pytest.mark.parametrize( + "build_restrictions,expected", + [ + (lambda r: r, None), + (lambda r: r.allow_all("view-instance"), {"a": ["vi"]}), + ( + lambda r: r.allow_database("docs", "view-query"), + {"d": {"docs": ["vq"]}}, + ), + ( + lambda r: r.allow_resource("docs", "attachments", "insert-row"), + {"r": {"docs": {"attachments": ["ir"]}}}, + ), + ( + lambda r: r.allow_all("view-instance") + .allow_database("docs", "view-query") + .allow_resource("docs", "attachments", "insert-row"), + { + "a": ["vi"], + "d": {"docs": ["vq"]}, + "r": {"docs": {"attachments": ["ir"]}}, + }, + ), + ( + lambda r: r.allow_all("not-a-real-action"), + {"a": ["not-a-real-action"]}, + ), + ], + ids=["empty", "all", "database", "resource", "combined", "unknown_action"], +) +async def test_token_restrictions_abbreviated(datasette, build_restrictions, expected): + await datasette.invoke_startup() + restrictions = build_restrictions(TokenRestrictions()) + assert restrictions.abbreviated(datasette) == expected + + @pytest.mark.asyncio async def test_signed_tokens_disabled(): """create_token and verify_token should fail/skip when signed tokens are disabled."""