diff --git a/datasette/extras.py b/datasette/extras.py index fee92939..2c3450b2 100644 --- a/datasette/extras.py +++ b/datasette/extras.py @@ -1,3 +1,4 @@ +import contextvars import re from dataclasses import dataclass from enum import Enum @@ -5,6 +6,11 @@ from typing import ClassVar from asyncinject import Registry +# Per-request context for Extra.resolve(), so the asyncinject registries can +# be shared across requests. asyncio tasks copy the caller's context, so +# concurrent resolve() calls each see their own value. +_resolve_context = contextvars.ContextVar("datasette_extras_context") + def extra_names_from_request(request): extra_bits = request.args.getlist("_extra") @@ -62,6 +68,13 @@ class ExtraRegistry: def __init__(self, classes): self.classes = list(classes) self.classes_by_name = {cls.key(): cls for cls in self.classes} + # Lazily-built shared state, keyed by scope. Safe to share across + # requests because Extra instances are stateless and asyncinject's + # Registry keeps per-call state local to each resolve_multi() call. + # If extras classes ever become registerable at runtime (e.g. via a + # plugin hook) these caches will need invalidating. + self._scope_registries = {} + self._allowed_names = {} def classes_for_scope(self, scope, include_internal=True): classes = [ @@ -74,23 +87,43 @@ class ExtraRegistry: def public_classes_for_scope(self, scope): return self.classes_for_scope(scope, include_internal=False) + def _registry_for_scope(self, scope): + registry = self._scope_registries.get(scope) + if registry is None: + registry = Registry() + + async def context_provider(): + return _resolve_context.get() + + registry.register(context_provider, name="context") + for cls in self.classes_for_scope(scope): + registry.register(cls().resolve, name=cls.key()) + self._scope_registries[scope] = registry + return registry + + def _allowed_names_for_scope(self, scope, include_internal): + key = (scope, include_internal) + names = self._allowed_names.get(key) + if names is None: + names = { + cls.key() + for cls in self.classes_for_scope( + scope, include_internal=include_internal + ) + } + self._allowed_names[key] = names + return names + async def resolve(self, requested, context, scope, include_internal=False): - registry = Registry() - - async def context_provider(): - return context - - registry.register(context_provider, name="context") - - for cls in self.classes_for_scope(scope): - registry.register(cls().resolve, name=cls.key()) - - allowed_names = { - cls.key() - for cls in self.classes_for_scope(scope, include_internal=include_internal) - } + allowed_names = self._allowed_names_for_scope(scope, include_internal) requested_names = [name for name in requested if name in allowed_names] - resolved = await registry.resolve_multi(requested_names) + token = _resolve_context.set(context) + try: + resolved = await self._registry_for_scope(scope).resolve_multi( + requested_names + ) + finally: + _resolve_context.reset(token) return {name: resolved[name] for name in requested_names} diff --git a/tests/test_extras.py b/tests/test_extras.py new file mode 100644 index 00000000..ad8a9f00 --- /dev/null +++ b/tests/test_extras.py @@ -0,0 +1,65 @@ +import asyncio + +import pytest + +from datasette.extras import Extra, ExtraRegistry, ExtraScope + + +class SlowValueExtra(Extra): + description = "Returns context['value'], optionally slowly" + scopes = {ExtraScope.TABLE} + + async def resolve(self, context): + if context["slow"]: + await asyncio.sleep(0.05) + return context["value"] + + +class DependentExtra(Extra): + description = "Depends on slow_value" + scopes = {ExtraScope.TABLE} + + async def resolve(self, context, slow_value): + return slow_value + 1 + + +def test_registry_is_built_once_per_scope(): + registry = ExtraRegistry([SlowValueExtra, DependentExtra]) + first = registry._registry_for_scope(ExtraScope.TABLE) + second = registry._registry_for_scope(ExtraScope.TABLE) + assert first is second + + +@pytest.mark.asyncio +async def test_concurrent_resolves_do_not_share_state(): + # The asyncinject registry is shared across requests - resolved values + # must not leak between concurrent resolve() calls with different contexts + registry = ExtraRegistry([SlowValueExtra, DependentExtra]) + slow, fast = await asyncio.gather( + registry.resolve( + {"slow_value", "dependent"}, + {"value": 100, "slow": True}, + ExtraScope.TABLE, + ), + registry.resolve( + {"slow_value", "dependent"}, + {"value": 200, "slow": False}, + ExtraScope.TABLE, + ), + ) + assert slow == {"slow_value": 100, "dependent": 101} + assert fast == {"slow_value": 200, "dependent": 201} + + +@pytest.mark.asyncio +async def test_table_row_and_query_scopes_use_separate_registries(): + from datasette.views.table_extras import table_extra_registry + + registries = { + scope: table_extra_registry._registry_for_scope(scope) for scope in ExtraScope + } + assert len(set(map(id, registries.values()))) == 3 + # Scope-specific extras only registered where they belong + assert "count" in registries[ExtraScope.TABLE]._registry + assert "count" not in registries[ExtraScope.QUERY]._registry + assert "foreign_key_tables" in registries[ExtraScope.ROW]._registry