mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
New AsyncBase class, refs #878
This commit is contained in:
parent
0156c6b5e5
commit
86aaa7c7b2
2 changed files with 181 additions and 0 deletions
101
datasette/utils/asyncdi.py
Normal file
101
datasette/utils/asyncdi.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
import asyncio
|
||||
from functools import wraps
|
||||
import inspect
|
||||
|
||||
try:
|
||||
import graphlib
|
||||
except ImportError:
|
||||
from . import vendored_graphlib as graphlib
|
||||
|
||||
|
||||
class AsyncMeta(type):
|
||||
def __new__(cls, name, bases, attrs):
|
||||
# Decorate any items that are 'async def' methods
|
||||
_registry = {}
|
||||
new_attrs = {"_registry": _registry}
|
||||
for key, value in attrs.items():
|
||||
if inspect.iscoroutinefunction(value) and not value.__name__ == "resolve":
|
||||
new_attrs[key] = make_method(value)
|
||||
_registry[key] = new_attrs[key]
|
||||
else:
|
||||
new_attrs[key] = value
|
||||
# Gather graph for later dependency resolution
|
||||
graph = {
|
||||
key: {
|
||||
p
|
||||
for p in inspect.signature(method).parameters.keys()
|
||||
if p != "self" and not p.startswith("_")
|
||||
}
|
||||
for key, method in _registry.items()
|
||||
}
|
||||
new_attrs["_graph"] = graph
|
||||
return super().__new__(cls, name, bases, new_attrs)
|
||||
|
||||
|
||||
def make_method(method):
|
||||
parameters = inspect.signature(method).parameters.keys()
|
||||
|
||||
@wraps(method)
|
||||
async def inner(self, _results=None, **kwargs):
|
||||
# Any parameters not provided by kwargs are resolved from registry
|
||||
to_resolve = [p for p in parameters if p not in kwargs and p != "self"]
|
||||
missing = [p for p in to_resolve if p not in self._registry]
|
||||
assert (
|
||||
not missing
|
||||
), "The following DI parameters could not be found in the registry: {}".format(
|
||||
missing
|
||||
)
|
||||
|
||||
results = {}
|
||||
results.update(kwargs)
|
||||
if to_resolve:
|
||||
resolved_parameters = await self.resolve(to_resolve, _results)
|
||||
results.update(resolved_parameters)
|
||||
return_value = await method(self, **results)
|
||||
if _results is not None:
|
||||
_results[method.__name__] = return_value
|
||||
return return_value
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
class AsyncBase(metaclass=AsyncMeta):
|
||||
async def resolve(self, names, results=None):
|
||||
if results is None:
|
||||
results = {}
|
||||
|
||||
# Come up with an execution plan, just for these nodes
|
||||
ts = graphlib.TopologicalSorter()
|
||||
to_do = set(names)
|
||||
done = set()
|
||||
while to_do:
|
||||
item = to_do.pop()
|
||||
dependencies = self._graph[item]
|
||||
ts.add(item, *dependencies)
|
||||
done.add(item)
|
||||
# Add any not-done dependencies to the queue
|
||||
to_do.update({k for k in dependencies if k not in done})
|
||||
|
||||
ts.prepare()
|
||||
plan = []
|
||||
while ts.is_active():
|
||||
node_group = ts.get_ready()
|
||||
plan.append(node_group)
|
||||
ts.done(*node_group)
|
||||
|
||||
results = {}
|
||||
for node_group in plan:
|
||||
awaitables = [
|
||||
self._registry[name](
|
||||
self,
|
||||
_results=results,
|
||||
**{k: v for k, v in results.items() if k in self._graph[name]},
|
||||
)
|
||||
for name in node_group
|
||||
]
|
||||
awaitable_results = await asyncio.gather(*awaitables)
|
||||
results.update(
|
||||
{p[0].__name__: p[1] for p in zip(awaitables, awaitable_results)}
|
||||
)
|
||||
|
||||
return {key: value for key, value in results.items() if key in names}
|
||||
80
tests/test_asyncdi.py
Normal file
80
tests/test_asyncdi.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
import asyncio
|
||||
from datasette.utils.asyncdi import AsyncBase
|
||||
import pytest
|
||||
from random import random
|
||||
|
||||
|
||||
class Simple(AsyncBase):
|
||||
def __init__(self):
|
||||
self.log = []
|
||||
|
||||
async def two(self):
|
||||
self.log.append("two")
|
||||
|
||||
async def one(self, two):
|
||||
self.log.append("one")
|
||||
return self.log
|
||||
|
||||
|
||||
class Complex(AsyncBase):
|
||||
def __init__(self):
|
||||
self.log = []
|
||||
|
||||
async def d(self):
|
||||
await asyncio.sleep(random() * 0.1)
|
||||
self.log.append("d")
|
||||
|
||||
async def c(self):
|
||||
await asyncio.sleep(random() * 0.1)
|
||||
self.log.append("c")
|
||||
|
||||
async def b(self, c, d):
|
||||
self.log.append("b")
|
||||
|
||||
async def a(self, b, c):
|
||||
self.log.append("a")
|
||||
|
||||
async def go(self, a):
|
||||
self.log.append("go")
|
||||
return self.log
|
||||
|
||||
|
||||
class WithParameters(AsyncBase):
|
||||
async def go(self, calc1, calc2, param1):
|
||||
return param1 + calc1 + calc2
|
||||
|
||||
async def calc1(self):
|
||||
return 5
|
||||
|
||||
async def calc2(self):
|
||||
return 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple():
|
||||
assert await Simple().one() == ["two", "one"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complex():
|
||||
result = await Complex().go()
|
||||
# 'c' should only be called once
|
||||
assert tuple(result) in (
|
||||
# c and d could happen in either order
|
||||
("c", "d", "b", "a", "go"),
|
||||
("d", "c", "b", "a", "go"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_parameters():
|
||||
result = await WithParameters().go(param1=4)
|
||||
assert result == 15
|
||||
|
||||
# Should throw an error if that parameter is missing
|
||||
with pytest.raises(AssertionError) as e:
|
||||
await WithParameters().go()
|
||||
assert e.args[0] == (
|
||||
"The following DI parameters could not be "
|
||||
"found in the registry: ['param1']"
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue