diff --git a/datasette/utils/asyncdi.py b/datasette/utils/asyncdi.py index 4c7f3a60..ba743450 100644 --- a/datasette/utils/asyncdi.py +++ b/datasette/utils/asyncdi.py @@ -8,13 +8,23 @@ except ImportError: from . import vendored_graphlib as graphlib +def inject(fn): + fn._inject = True + return fn + + class AsyncMeta(type): def __new__(cls, name, bases, attrs): # Decorate any items that are 'async def' methods _registry = {} new_attrs = {"_registry": _registry} + inject_all = attrs.get("inject_all") for key, value in attrs.items(): - if inspect.iscoroutinefunction(value) and not value.__name__ == "resolve": + if ( + inspect.iscoroutinefunction(value) + and not value.__name__ == "resolve" + and (inject_all or getattr(value, "_inject", None)) + ): new_attrs[key] = make_method(value) _registry[key] = new_attrs[key] else: diff --git a/tests/test_asyncdi.py b/tests/test_asyncdi.py index 68dcb2fd..baf8a469 100644 --- a/tests/test_asyncdi.py +++ b/tests/test_asyncdi.py @@ -1,5 +1,5 @@ import asyncio -from datasette.utils.asyncdi import AsyncBase +from datasette.utils.asyncdi import AsyncBase, inject import pytest from random import random @@ -8,15 +8,22 @@ class Simple(AsyncBase): def __init__(self): self.log = [] + @inject async def two(self): self.log.append("two") + @inject async def one(self, two): self.log.append("one") return self.log + async def not_inject(self, one, two): + return one + two + class Complex(AsyncBase): + inject_all = True + def __init__(self): self.log = [] @@ -40,6 +47,8 @@ class Complex(AsyncBase): class WithParameters(AsyncBase): + inject_all = True + async def go(self, calc1, calc2, param1): return param1 + calc1 + calc2 @@ -53,6 +62,7 @@ class WithParameters(AsyncBase): @pytest.mark.asyncio async def test_simple(): assert await Simple().one() == ["two", "one"] + assert await Simple().not_inject(6, 7) == 13 @pytest.mark.asyncio