New AsyncBase class, refs #878

This commit is contained in:
Simon Willison 2021-11-16 13:54:14 -08:00
commit 86aaa7c7b2
2 changed files with 181 additions and 0 deletions

101
datasette/utils/asyncdi.py Normal file
View file

@ -0,0 +1,101 @@
import asyncio
from functools import wraps
import inspect
try:
import graphlib
except ImportError:
from . import vendored_graphlib as graphlib
class AsyncMeta(type):
def __new__(cls, name, bases, attrs):
# Decorate any items that are 'async def' methods
_registry = {}
new_attrs = {"_registry": _registry}
for key, value in attrs.items():
if inspect.iscoroutinefunction(value) and not value.__name__ == "resolve":
new_attrs[key] = make_method(value)
_registry[key] = new_attrs[key]
else:
new_attrs[key] = value
# Gather graph for later dependency resolution
graph = {
key: {
p
for p in inspect.signature(method).parameters.keys()
if p != "self" and not p.startswith("_")
}
for key, method in _registry.items()
}
new_attrs["_graph"] = graph
return super().__new__(cls, name, bases, new_attrs)
def make_method(method):
parameters = inspect.signature(method).parameters.keys()
@wraps(method)
async def inner(self, _results=None, **kwargs):
# Any parameters not provided by kwargs are resolved from registry
to_resolve = [p for p in parameters if p not in kwargs and p != "self"]
missing = [p for p in to_resolve if p not in self._registry]
assert (
not missing
), "The following DI parameters could not be found in the registry: {}".format(
missing
)
results = {}
results.update(kwargs)
if to_resolve:
resolved_parameters = await self.resolve(to_resolve, _results)
results.update(resolved_parameters)
return_value = await method(self, **results)
if _results is not None:
_results[method.__name__] = return_value
return return_value
return inner
class AsyncBase(metaclass=AsyncMeta):
async def resolve(self, names, results=None):
if results is None:
results = {}
# Come up with an execution plan, just for these nodes
ts = graphlib.TopologicalSorter()
to_do = set(names)
done = set()
while to_do:
item = to_do.pop()
dependencies = self._graph[item]
ts.add(item, *dependencies)
done.add(item)
# Add any not-done dependencies to the queue
to_do.update({k for k in dependencies if k not in done})
ts.prepare()
plan = []
while ts.is_active():
node_group = ts.get_ready()
plan.append(node_group)
ts.done(*node_group)
results = {}
for node_group in plan:
awaitables = [
self._registry[name](
self,
_results=results,
**{k: v for k, v in results.items() if k in self._graph[name]},
)
for name in node_group
]
awaitable_results = await asyncio.gather(*awaitables)
results.update(
{p[0].__name__: p[1] for p in zip(awaitables, awaitable_results)}
)
return {key: value for key, value in results.items() if key in names}

80
tests/test_asyncdi.py Normal file
View file

@ -0,0 +1,80 @@
import asyncio
from datasette.utils.asyncdi import AsyncBase
import pytest
from random import random
class Simple(AsyncBase):
def __init__(self):
self.log = []
async def two(self):
self.log.append("two")
async def one(self, two):
self.log.append("one")
return self.log
class Complex(AsyncBase):
def __init__(self):
self.log = []
async def d(self):
await asyncio.sleep(random() * 0.1)
self.log.append("d")
async def c(self):
await asyncio.sleep(random() * 0.1)
self.log.append("c")
async def b(self, c, d):
self.log.append("b")
async def a(self, b, c):
self.log.append("a")
async def go(self, a):
self.log.append("go")
return self.log
class WithParameters(AsyncBase):
async def go(self, calc1, calc2, param1):
return param1 + calc1 + calc2
async def calc1(self):
return 5
async def calc2(self):
return 6
@pytest.mark.asyncio
async def test_simple():
assert await Simple().one() == ["two", "one"]
@pytest.mark.asyncio
async def test_complex():
result = await Complex().go()
# 'c' should only be called once
assert tuple(result) in (
# c and d could happen in either order
("c", "d", "b", "a", "go"),
("d", "c", "b", "a", "go"),
)
@pytest.mark.asyncio
async def test_with_parameters():
result = await WithParameters().go(param1=4)
assert result == 15
# Should throw an error if that parameter is missing
with pytest.raises(AssertionError) as e:
await WithParameters().go()
assert e.args[0] == (
"The following DI parameters could not be "
"found in the registry: ['param1']"
)