diff --git a/datasette/app.py b/datasette/app.py index 6efaa430..458f63da 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -6,7 +6,7 @@ import contextvars from typing import TYPE_CHECKING, Any, Dict, Iterable, List if TYPE_CHECKING: - from datasette.permissions import Resource + from datasette.permissions import Resource, TokenRestrictions import asgi_csrf import collections import dataclasses @@ -713,45 +713,95 @@ class Datasette: """ return _in_datasette_client.get() - def create_token( + def _abbreviate_action(self, action: str) -> str: + """Return the abbreviated form of an action name if one exists.""" + action_obj = self.actions.get(action) + if not action_obj: + return action + return action_obj.abbr or action + + def build_token_restrictions( + self, + restrictions: "TokenRestrictions", + ) -> dict | None: + """Build an abbreviated restrictions dict for use in token payloads. + + Returns a dict like ``{"a": [...], "d": {...}, "r": {...}}`` + or ``None`` if there are no restrictions. + """ + if not (restrictions.all or restrictions.database or restrictions.resource): + return None + result: dict = {} + if restrictions.all: + result["a"] = [ + self._abbreviate_action(a) for a in restrictions.all + ] + if restrictions.database: + result["d"] = { + database: [self._abbreviate_action(a) for a in actions] + for database, actions in restrictions.database.items() + } + if restrictions.resource: + result["r"] = {} + for database, resources in restrictions.resource.items(): + for resource, actions in resources.items(): + result["r"].setdefault(database, {})[resource] = [ + self._abbreviate_action(a) for a in actions + ] + return result + + def create_signed_token( self, actor_id: str, *, expires_after: int | None = None, - restrict_all: Iterable[str] | None = None, - restrict_database: Dict[str, Iterable[str]] | None = None, - restrict_resource: Dict[str, Dict[str, Iterable[str]]] | None = None, - ): - token = {"a": actor_id, "t": int(time.time())} + restrictions: "TokenRestrictions | None" = None, + ) -> str: + """Create a signed ``dstok_`` API token. + + This always creates a signed token regardless of installed plugins. + Use :meth:`create_token` to go through the plugin hook instead. + """ + from datasette.permissions import TokenRestrictions + + token: dict = {"a": actor_id, "t": int(time.time())} if expires_after: token["d"] = expires_after - - def abbreviate_action(action): - # rename to abbr if possible - action_obj = self.actions.get(action) - if not action_obj: - return action - return action_obj.abbr or action - - if expires_after: - token["d"] = expires_after - if restrict_all or restrict_database or restrict_resource: - token["_r"] = {} - if restrict_all: - token["_r"]["a"] = [abbreviate_action(a) for a in restrict_all] - if restrict_database: - token["_r"]["d"] = {} - for database, actions in restrict_database.items(): - token["_r"]["d"][database] = [abbreviate_action(a) for a in actions] - if restrict_resource: - token["_r"]["r"] = {} - for database, resources in restrict_resource.items(): - for resource, actions in resources.items(): - token["_r"]["r"].setdefault(database, {})[resource] = [ - abbreviate_action(a) for a in actions - ] + if restrictions is None: + restrictions = TokenRestrictions(all=[], database={}, resource={}) + abbreviated = self.build_token_restrictions(restrictions) + if abbreviated: + token["_r"] = abbreviated return "dstok_{}".format(self.sign(token, namespace="token")) + async def create_token( + self, + actor_id: str, + *, + expires_after: int | None = None, + restrictions: "TokenRestrictions | None" = None, + ) -> str: + """Create an API token, dispatching through the ``create_token`` hook. + + If a plugin implements the ``create_token`` hook, it can return + a database-backed token instead of a signed one. The default + implementation creates a signed ``dstok_`` token. + """ + from datasette.permissions import TokenRestrictions + + if restrictions is None: + restrictions = TokenRestrictions(all=[], database={}, resource={}) + for result in pm.hook.create_token( + datasette=self, + actor_id=actor_id, + expires_after=expires_after, + restrictions=restrictions, + ): + result = await await_me_maybe(result) + if result is not None: + return result + raise RuntimeError("No create_token hook implementation returned a token") + def get_database(self, name=None, route=None): if route is not None: matches = [db for db in self.databases.values() if db.route == route] diff --git a/datasette/cli.py b/datasette/cli.py index 121911ab..b8a46b12 100644 --- a/datasette/cli.py +++ b/datasette/cli.py @@ -841,12 +841,17 @@ def create_token( action ) - token = ds.create_token( + from datasette.permissions import TokenRestrictions + + restrictions = TokenRestrictions( + all=alls, + database=restrict_database, + resource=restrict_resource, + ) + token = ds.create_signed_token( id, expires_after=expires_after, - restrict_all=alls, - restrict_database=restrict_database, - restrict_resource=restrict_resource, + restrictions=restrictions, ) click.echo(token) if debug: diff --git a/datasette/default_permissions/__init__.py b/datasette/default_permissions/__init__.py index 40373fa7..144fcafe 100644 --- a/datasette/default_permissions/__init__.py +++ b/datasette/default_permissions/__init__.py @@ -38,6 +38,7 @@ from .defaults import ( DEFAULT_ALLOW_ACTIONS as DEFAULT_ALLOW_ACTIONS, ) from .tokens import actor_from_signed_api_token as actor_from_signed_api_token +from .tokens import create_signed_api_token as create_signed_api_token @hookimpl diff --git a/datasette/default_permissions/tokens.py b/datasette/default_permissions/tokens.py index 474b0c23..67e99e5a 100644 --- a/datasette/default_permissions/tokens.py +++ b/datasette/default_permissions/tokens.py @@ -93,3 +93,17 @@ def actor_from_signed_api_token(datasette: "Datasette", request) -> Optional[dic actor["token_expires"] = created + duration return actor + + +@hookimpl(trylast=True, specname="create_token") +def create_signed_api_token(datasette, actor_id, expires_after, restrictions): + """Default create_token implementation: creates a signed dstok_ token. + + Runs last so that plugins like datasette-auth-tokens can override + by returning a database-backed token first. + """ + return datasette.create_signed_token( + actor_id, + expires_after=expires_after, + restrictions=restrictions, + ) diff --git a/datasette/events.py b/datasette/events.py index 5cd5ba3d..35ddb085 100644 --- a/datasette/events.py +++ b/datasette/events.py @@ -53,19 +53,13 @@ class CreateTokenEvent(Event): :ivar expires_after: Number of seconds after which this token will expire. :type expires_after: int or None - :ivar restrict_all: Restricted permissions for this token. - :type restrict_all: list - :ivar restrict_database: Restricted database permissions for this token. - :type restrict_database: dict - :ivar restrict_resource: Restricted resource permissions for this token. - :type restrict_resource: dict + :ivar restrictions: Token restrictions (a :class:`TokenRestrictions` instance). + :type restrictions: TokenRestrictions """ name = "create-token" expires_after: int | None - restrict_all: list - restrict_database: dict - restrict_resource: dict + restrictions: object # TokenRestrictions @dataclass diff --git a/datasette/hookspecs.py b/datasette/hookspecs.py index 89be6a65..88457116 100644 --- a/datasette/hookspecs.py +++ b/datasette/hookspecs.py @@ -242,3 +242,23 @@ def write_wrapper(datasette, database, request, transaction): Return ``None`` to skip wrapping. """ + + +@hookspec +def create_token(datasette, actor_id, expires_after, restrictions): + """Create an API token for the given actor. + + Return a token string, or ``None`` to let the next implementation + handle token creation. Implementations may be synchronous or + async (return an awaitable). + + Parameters mirror ``Datasette.create_token()``: + - ``actor_id``: the actor ID to embed in the token. + - ``expires_after``: seconds until expiry, or ``None``. + - ``restrictions``: a :class:`TokenRestrictions` dataclass with + ``all``, ``database``, and ``resource`` fields. + + The default (``trylast``) implementation creates a signed + ``dstok_`` token. Plugins like ``datasette-auth-tokens`` can + override this to create database-backed tokens instead. + """ diff --git a/datasette/permissions.py b/datasette/permissions.py index b5e72b8e..dd2db966 100644 --- a/datasette/permissions.py +++ b/datasette/permissions.py @@ -121,6 +121,19 @@ class AllowedResource(NamedTuple): reason: str +@dataclass +class TokenRestrictions: + """Restrictions that can be applied to an API token. + + ``all`` restricts globally, ``database`` restricts per-database, + and ``resource`` restricts per-resource within a database. + """ + + all: list[str] + database: dict[str, list[str]] + resource: dict[str, dict[str, list[str]]] + + @dataclass(frozen=True, kw_only=True) class Action: name: str diff --git a/datasette/views/special.py b/datasette/views/special.py index 640c82eb..e8962873 100644 --- a/datasette/views/special.py +++ b/datasette/views/special.py @@ -1,6 +1,7 @@ import json import logging from datasette.events import LogoutEvent, LoginEvent, CreateTokenEvent +from datasette.permissions import TokenRestrictions from datasette.resources import DatabaseResource, TableResource from datasette.utils.asgi import Response, Forbidden from datasette.utils import ( @@ -731,21 +732,22 @@ class CreateTokenView(BaseView): resource, [] ).append(action) - token = self.ds.create_token( + restrictions = TokenRestrictions( + all=restrict_all, + database=restrict_database, + resource=restrict_resource, + ) + token = self.ds.create_signed_token( request.actor["id"], expires_after=expires_after, - restrict_all=restrict_all, - restrict_database=restrict_database, - restrict_resource=restrict_resource, + restrictions=restrictions, ) token_bits = self.ds.unsign(token[len("dstok_") :], namespace="token") await self.ds.track_event( CreateTokenEvent( actor=request.actor, expires_after=expires_after, - restrict_all=restrict_all, - restrict_database=restrict_database, - restrict_resource=restrict_resource, + restrictions=restrictions, ) ) context = await self.shared(request) diff --git a/tests/test_api_write.py b/tests/test_api_write.py index 05835e51..ecde33c4 100644 --- a/tests/test_api_write.py +++ b/tests/test_api_write.py @@ -1,4 +1,5 @@ from datasette.app import Datasette +from datasette.permissions import TokenRestrictions from datasette.utils import sqlite3 from .utils import last_event import pytest @@ -1362,7 +1363,7 @@ async def test_create_table( async def test_create_table_permissions( ds_write, permissions, body, expected_status, expected_errors ): - token = ds_write.create_token("root", restrict_all=["view-instance"] + permissions) + token = ds_write.create_signed_token("root", restrictions=TokenRestrictions(all=["view-instance"] + permissions, database={}, resource={})) response = await ds_write.client.post( "/data/-/create", json=body, diff --git a/tests/test_auth.py b/tests/test_auth.py index 1e1cd622..4459d6d1 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,6 +1,7 @@ from bs4 import BeautifulSoup as Soup from .utils import cookie_was_deleted, last_event from click.testing import CliRunner +from datasette.permissions import TokenRestrictions from datasette.utils import baseconv from datasette.cli import cli from datasette.resources import ( @@ -209,9 +210,7 @@ def test_auth_create_token( event = last_event(app_client.ds) assert event.name == "create-token" assert event.expires_after == expected_duration - assert isinstance(event.restrict_all, list) - assert isinstance(event.restrict_database, dict) - assert isinstance(event.restrict_resource, dict) + assert isinstance(event.restrictions, TokenRestrictions) # Extract token from page token = response2.text.split('value="dstok_')[1].split('"')[0] details = app_client.ds.unsign(token, "token") @@ -491,3 +490,159 @@ async def test_root_without_root_enabled_no_special_permissions(ds_client): ) is not True ), "Root without root_enabled should not automatically get drop-table" + + +@pytest.mark.asyncio +async def test_create_signed_token_method(ds_client): + """create_signed_token() creates a dstok_ token synchronously""" + token = ds_client.ds.create_signed_token("test_actor") + assert token.startswith("dstok_") + decoded = ds_client.ds.unsign(token[len("dstok_"):], namespace="token") + assert decoded["a"] == "test_actor" + assert "t" in decoded + + +@pytest.mark.asyncio +async def test_create_signed_token_with_restrictions(ds_client): + """create_signed_token() respects restriction parameters""" + token = ds_client.ds.create_signed_token( + "test_actor", + expires_after=3600, + restrictions=TokenRestrictions( + all=["view-instance"], + database={}, + resource={}, + ), + ) + decoded = ds_client.ds.unsign(token[len("dstok_"):], namespace="token") + assert decoded["a"] == "test_actor" + assert decoded["d"] == 3600 + assert "_r" in decoded + assert "a" in decoded["_r"] + + +@pytest.mark.asyncio +async def test_build_token_restrictions(ds_client): + """build_token_restrictions() returns abbreviated restriction dicts""" + ds = ds_client.ds + # No restrictions returns None + assert ds.build_token_restrictions(TokenRestrictions(all=[], database={}, resource={})) is None + + # With all + result = ds.build_token_restrictions(TokenRestrictions(all=["view-instance"], database={}, resource={})) + assert result is not None + assert "a" in result + + # With database + result = ds.build_token_restrictions( + TokenRestrictions(all=[], database={"mydb": ["view-table"]}, resource={}) + ) + assert result is not None + assert "d" in result + assert "mydb" in result["d"] + + # With resource + result = ds.build_token_restrictions( + TokenRestrictions(all=[], database={}, resource={"mydb": {"mytable": ["insert-row"]}}) + ) + assert result is not None + assert "r" in result + assert "mydb" in result["r"] + assert "mytable" in result["r"]["mydb"] + + +@pytest.mark.asyncio +async def test_create_token_hook_default(ds_client): + """The async create_token() method uses the default hook to create signed tokens""" + token = await ds_client.ds.create_token("test_actor") + assert token.startswith("dstok_") + decoded = ds_client.ds.unsign(token[len("dstok_"):], namespace="token") + assert decoded["a"] == "test_actor" + + +@pytest.mark.asyncio +async def test_create_token_hook_default_with_restrictions(ds_client): + """The async create_token() with restrictions creates proper signed tokens""" + token = await ds_client.ds.create_token( + "test_actor", + expires_after=7200, + restrictions=TokenRestrictions( + all=["view-instance"], + database={"fixtures": ["view-table"]}, + resource={}, + ), + ) + assert token.startswith("dstok_") + decoded = ds_client.ds.unsign(token[len("dstok_"):], namespace="token") + assert decoded["a"] == "test_actor" + assert decoded["d"] == 7200 + assert "_r" in decoded + + +@pytest.mark.asyncio +async def test_create_token_hook_can_be_overridden(ds_client): + """A plugin can override create_token to return a custom token""" + from datasette import hookimpl + from datasette.plugins import pm + + class CustomTokenPlugin: + __name__ = "custom_token_test_plugin" + + @staticmethod + @hookimpl(specname="create_token") + def custom_create_token(datasette, actor_id, expires_after, restrictions): + return "custom_token_for_{}".format(actor_id) + + pm.register(CustomTokenPlugin, name="custom_token_test_plugin") + try: + token = await ds_client.ds.create_token("myactor") + assert token == "custom_token_for_myactor" + finally: + pm.unregister(name="custom_token_test_plugin") + + +@pytest.mark.asyncio +async def test_create_token_hook_async_override(ds_client): + """A plugin can override create_token with an async implementation""" + from datasette import hookimpl + from datasette.plugins import pm + + class AsyncTokenPlugin: + __name__ = "async_token_test_plugin" + + @staticmethod + @hookimpl(specname="create_token") + def async_create_token(datasette, actor_id, expires_after, restrictions): + async def inner(): + return "async_token_for_{}".format(actor_id) + return inner() + + pm.register(AsyncTokenPlugin, name="async_token_test_plugin") + try: + token = await ds_client.ds.create_token("myactor") + assert token == "async_token_for_myactor" + finally: + pm.unregister(name="async_token_test_plugin") + + +@pytest.mark.asyncio +async def test_create_token_hook_none_falls_through(ds_client): + """If a plugin returns None, the default signed token implementation is used""" + from datasette import hookimpl + from datasette.plugins import pm + + class NoneTokenPlugin: + __name__ = "none_token_test_plugin" + + @staticmethod + @hookimpl(specname="create_token") + def none_create_token(datasette, actor_id, expires_after, restrictions): + return None + + pm.register(NoneTokenPlugin, name="none_token_test_plugin") + try: + token = await ds_client.ds.create_token("test_actor") + # Should fall through to default signed token + assert token.startswith("dstok_") + finally: + pm.unregister(name="none_token_test_plugin") diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 96c0cf6f..5adcac5b 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -1657,7 +1657,7 @@ async def test_permission_check_view_requires_debug_permission(): # Root user should have access (root has all permissions) ds_with_root = Datasette() ds_with_root.root_enabled = True - root_token = ds_with_root.create_token("root") + root_token = ds_with_root.create_signed_token("root") response = await ds_with_root.client.get( "/-/check.json?action=view-instance", headers={"Authorization": f"Bearer {root_token}"},