diff --git a/datasette/utils/asyncdi.py b/datasette/utils/asyncdi.py
new file mode 100644
index 00000000..4c7f3a60
--- /dev/null
+++ b/datasette/utils/asyncdi.py
@@ -0,0 +1,135 @@
+import asyncio
+from functools import wraps
+import inspect
+
+try:
+    import graphlib
+except ImportError:
+    from . import vendored_graphlib as graphlib
+
+
+class AsyncMeta(type):
+    """Metaclass that wires up dependency injection for ``async def`` methods.
+
+    Every coroutine function in the class body, other than one named
+    ``resolve``, is wrapped with ``make_method`` and recorded in ``_registry``.
+    ``_graph`` maps each registered name to the parameter names it depends on.
+    """
+
+    def __new__(cls, name, bases, attrs):
+        # Decorate any items that are 'async def' methods
+        _registry = {}
+        new_attrs = {"_registry": _registry}
+        for key, value in attrs.items():
+            if inspect.iscoroutinefunction(value) and value.__name__ != "resolve":
+                new_attrs[key] = make_method(value)
+                _registry[key] = new_attrs[key]
+            else:
+                new_attrs[key] = value
+        # Gather graph for later dependency resolution
+        graph = {
+            key: {
+                p
+                for p in inspect.signature(method).parameters
+                if p != "self" and not p.startswith("_")
+            }
+            for key, method in _registry.items()
+        }
+        new_attrs["_graph"] = graph
+        return super().__new__(cls, name, bases, new_attrs)
+
+
+def make_method(method):
+    """Wrap coroutine ``method`` so that missing arguments are resolved by name.
+
+    Parameters not passed as keyword arguments must be names in the instance's
+    ``_registry``; they are computed with ``self.resolve()`` before ``method``
+    is awaited. When ``_results`` is not None the return value is also stored
+    in it under ``method.__name__``.
+
+    Raises AssertionError if an unsupplied parameter is not in the registry.
+    """
+    parameters = inspect.signature(method).parameters.keys()
+
+    @wraps(method)
+    async def inner(self, _results=None, **kwargs):
+        # Any parameters not provided by kwargs are resolved from registry
+        to_resolve = [p for p in parameters if p not in kwargs and p != "self"]
+        missing = [p for p in to_resolve if p not in self._registry]
+        if missing:
+            # Raised explicitly rather than with an assert statement, so the
+            # check is not stripped when Python runs with -O
+            raise AssertionError(
+                "The following DI parameters could not be "
+                "found in the registry: {}".format(missing)
+            )
+
+        results = dict(kwargs)
+        if to_resolve:
+            resolved_parameters = await self.resolve(to_resolve, _results)
+            results.update(resolved_parameters)
+        return_value = await method(self, **results)
+        if _results is not None:
+            _results[method.__name__] = return_value
+        return return_value
+
+    return inner
+
+
+class AsyncBase(metaclass=AsyncMeta):
+    """Base class whose ``async def`` methods can depend on each other by name."""
+
+    async def resolve(self, names, results=None):
+        """Run the registered methods in ``names`` along with their dependencies.
+
+        Methods with no outstanding dependencies are awaited together using
+        ``asyncio.gather``.
+
+        ``results`` may map names to values that are already known. Those names
+        are not executed again, and the dict is updated in place with every
+        value computed here.
+
+        Returns a dict holding just the entries for ``names``.
+        """
+        if results is None:
+            results = {}
+        wanted = set(names)
+
+        # Come up with an execution plan, just for these nodes
+        ts = graphlib.TopologicalSorter()
+        to_do = set(wanted)
+        done = set()
+        while to_do:
+            item = to_do.pop()
+            done.add(item)
+            if item in results:
+                # Already known, so its own dependencies are not needed
+                continue
+            dependencies = self._graph[item]
+            ts.add(item, *dependencies)
+            # Add any not-done dependencies to the queue
+            to_do.update(dependencies - done)
+
+        ts.prepare()
+        plan = []
+        while ts.is_active():
+            node_group = ts.get_ready()
+            plan.append(node_group)
+            ts.done(*node_group)
+
+        for node_group in plan:
+            # Skip anything the caller supplied or that has already finished
+            pending = [name for name in node_group if name not in results]
+            awaitables = [
+                self._registry[name](
+                    self,
+                    _results=results,
+                    **{k: v for k, v in results.items() if k in self._graph[name]},
+                )
+                for name in pending
+            ]
+            awaitable_results = await asyncio.gather(*awaitables)
+            # Pair each value with the name it was scheduled under
+            results.update(zip(pending, awaitable_results))
+
+        return {key: value for key, value in results.items() if key in wanted}
diff --git a/tests/test_asyncdi.py b/tests/test_asyncdi.py
new file mode 100644
index 00000000..68dcb2fd
--- /dev/null
+++ b/tests/test_asyncdi.py
@@ -0,0 +1,98 @@
+import asyncio
+from datasette.utils.asyncdi import AsyncBase
+import pytest
+from random import random
+
+
+class Simple(AsyncBase):
+    """'one' depends on 'two'; both record the order in which they ran."""
+
+    def __init__(self):
+        self.log = []
+
+    async def two(self):
+        self.log.append("two")
+
+    async def one(self, two):
+        self.log.append("one")
+        return self.log
+
+
+class Complex(AsyncBase):
+    """'go' needs 'a', 'a' needs 'b' and 'c', 'b' needs 'c' and 'd'."""
+
+    def __init__(self):
+        self.log = []
+
+    async def d(self):
+        await asyncio.sleep(random() * 0.1)
+        self.log.append("d")
+
+    async def c(self):
+        await asyncio.sleep(random() * 0.1)
+        self.log.append("c")
+
+    async def b(self, c, d):
+        self.log.append("b")
+
+    async def a(self, b, c):
+        self.log.append("a")
+
+    async def go(self, a):
+        self.log.append("go")
+        return self.log
+
+
+class WithParameters(AsyncBase):
+    """'go' mixes resolved dependencies with a caller-supplied parameter."""
+
+    async def go(self, calc1, calc2, param1):
+        return param1 + calc1 + calc2
+
+    async def calc1(self):
+        return 5
+
+    async def calc2(self):
+        return 6
+
+
+@pytest.mark.asyncio
+async def test_simple():
+    assert await Simple().one() == ["two", "one"]
+
+
+@pytest.mark.asyncio
+async def test_complex():
+    result = await Complex().go()
+    # 'c' should only be called once
+    assert tuple(result) in (
+        # c and d could happen in either order
+        ("c", "d", "b", "a", "go"),
+        ("d", "c", "b", "a", "go"),
+    )
+
+
+@pytest.mark.asyncio
+async def test_with_parameters():
+    result = await WithParameters().go(param1=4)
+    assert result == 15
+
+    # Should throw an error if that parameter is missing
+    with pytest.raises(AssertionError) as e:
+        await WithParameters().go()
+    # Checked after the block exits: nothing following the raising call
+    # inside it would run, and the exception lives on e.value
+    assert e.value.args[0] == (
+        "The following DI parameters could not be "
+        "found in the registry: ['param1']"
+    )
+
+
+@pytest.mark.asyncio
+async def test_resolve_reuses_supplied_results():
+    instance = Simple()
+    results = {"two": "cached"}
+    assert await instance.resolve(["one"], results) == {"one": ["one"]}
+    # 'two' was supplied, so it should not have been executed again
+    assert instance.log == ["one"]
+    assert results["two"] == "cached"