mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
Compare commits
4 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8f757da075 | ||
|
|
f4c5f58887 | ||
|
|
22f41f7983 | ||
|
|
86aaa7c7b2 |
3 changed files with 453 additions and 0 deletions
111
datasette/utils/asyncdi.py
Normal file
111
datasette/utils/asyncdi.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
import asyncio
|
||||
from functools import wraps
|
||||
import inspect
|
||||
|
||||
try:
|
||||
import graphlib
|
||||
except ImportError:
|
||||
from . import vendored_graphlib as graphlib
|
||||
|
||||
|
||||
def inject(fn):
    """Mark *fn* so that AsyncMeta registers it for dependency injection.

    Returns *fn* unchanged, so it works as a plain decorator.
    """
    setattr(fn, "_inject", True)
    return fn
|
||||
|
||||
|
||||
class AsyncMeta(type):
    """Metaclass that wraps selected ``async def`` methods for dependency
    injection and records their parameter dependency graph.

    A method is wrapped when it is a coroutine function, is not named
    ``resolve``, and either the class sets ``inject_all`` truthy or the
    method was marked with @inject.
    """

    def __new__(cls, name, bases, attrs):
        registry = {}
        new_attrs = {"_registry": registry}
        inject_all = attrs.get("inject_all")

        for attr_name, attr_value in attrs.items():
            eligible = (
                inspect.iscoroutinefunction(attr_value)
                and attr_value.__name__ != "resolve"
                and (inject_all or getattr(attr_value, "_inject", None))
            )
            if eligible:
                wrapped = make_method(attr_value)
                new_attrs[attr_name] = wrapped
                registry[attr_name] = wrapped
            else:
                # Everything else is passed through untouched.
                new_attrs[attr_name] = attr_value

        # Dependency graph for later resolution: method name -> the set of
        # parameter names it needs (ignoring self and _private parameters).
        new_attrs["_graph"] = {
            key: {
                param
                for param in inspect.signature(method).parameters
                if param != "self" and not param.startswith("_")
            }
            for key, method in registry.items()
        }
        return super().__new__(cls, name, bases, new_attrs)
|
||||
|
||||
|
||||
def make_method(method):
    """Wrap coroutine *method* so that any signature parameters not supplied
    by the caller are resolved from the instance's DI registry.

    The wrapper accepts an optional ``_results`` dict; when given, the
    method's return value is recorded there under the method's name.
    """
    param_names = inspect.signature(method).parameters.keys()

    @wraps(method)
    async def inner(self, _results=None, **kwargs):
        # Parameters the caller did not supply must come from the registry.
        needed = [name for name in param_names if name not in kwargs and name != "self"]
        unknown = [name for name in needed if name not in self._registry]
        assert not unknown, "The following DI parameters could not be found in the registry: {}".format(
            unknown
        )

        call_args = dict(kwargs)
        if needed:
            call_args.update(await self.resolve(needed, _results))
        value = await method(self, **call_args)
        if _results is not None:
            _results[method.__name__] = value
        return value

    return inner
|
||||
|
||||
|
||||
class AsyncBase(metaclass=AsyncMeta):
    """Base class for objects whose registered async methods can be executed
    as a dependency graph: each registered method's parameters are satisfied
    by awaiting the same-named registered methods."""

    async def resolve(self, names, results=None):
        """Execute the registered methods in *names*, plus their transitive
        dependencies, and return ``{name: value}`` for just the requested names.

        ``results`` may seed already-computed values (this is how wrapped
        methods pass ``_results`` through); seeded nodes are not executed
        again and their values are fed to their dependents.

        Independent nodes within each step of the plan run concurrently via
        ``asyncio.gather``.
        """
        if results is None:
            results = {}

        # Come up with an execution plan, just for these nodes.
        # Nodes already present in results are considered done and are
        # not planned for execution again.
        ts = graphlib.TopologicalSorter()
        to_do = set(names) - set(results)
        done = set(results)
        while to_do:
            item = to_do.pop()
            dependencies = self._graph[item]
            ts.add(item, *dependencies)
            done.add(item)
            # Add any not-done dependencies to the queue
            to_do.update({k for k in dependencies if k not in done})

        ts.prepare()
        plan = []
        while ts.is_active():
            node_group = ts.get_ready()
            plan.append(node_group)
            ts.done(*node_group)

        # NOTE: previously this re-bound results to a fresh {} here, which
        # silently discarded any seeded values passed in by the caller.
        for node_group in plan:
            awaitables = [
                self._registry[name](
                    self,
                    _results=results,
                    # Feed each node the already-resolved values it depends on.
                    **{k: v for k, v in results.items() if k in self._graph[name]},
                )
                for name in node_group
                # Seeded nodes may still appear in the plan (as someone's
                # dependency) but must not be executed again.
                if name not in results
            ]
            awaitable_results = await asyncio.gather(*awaitables)
            results.update(
                {p[0].__name__: p[1] for p in zip(awaitables, awaitable_results)}
            )

        return {key: value for key, value in results.items() if key in names}
|
||||
252
datasette/utils/vendored_graphlib.py
Normal file
252
datasette/utils/vendored_graphlib.py
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
# Vendored from https://raw.githubusercontent.com/python/cpython/3.10/Lib/graphlib.py
|
||||
# Modified to work on Python 3.6 (I removed := operator)
|
||||
# License: https://github.com/python/cpython/blob/main/LICENSE
|
||||
|
||||
__all__ = ["TopologicalSorter", "CycleError"]
|
||||
|
||||
_NODE_OUT = -1
|
||||
_NODE_DONE = -2
|
||||
|
||||
|
||||
class _NodeInfo:
    """Per-node bookkeeping record kept by TopologicalSorter."""

    __slots__ = "node", "npredecessors", "successors"

    def __init__(self, node):
        # The node this class is augmenting.
        self.node = node

        # Number of predecessors, generally >= 0. When this value falls to 0,
        # and is returned by get_ready(), this is set to _NODE_OUT and when the
        # node is marked done by a call to done(), set to _NODE_DONE.
        self.npredecessors = 0

        # List of successor nodes. The list can contain duplicated elements as
        # long as they're all reflected in the successor's npredecessors attribute.
        self.successors = []
|
||||
|
||||
|
||||
class CycleError(ValueError):
    """Subclass of ValueError raised by TopologicalSorter.prepare if cycles
    exist in the working graph.

    If multiple cycles exist, only one undefined choice among them will be reported
    and included in the exception. The detected cycle can be accessed via the second
    element in the *args* attribute of the exception instance and consists in a list
    of nodes, such that each node is, in the graph, an immediate predecessor of the
    next node in the list. In the reported list, the first and the last node will be
    the same, to make it clear that it is cyclic.
    """

    # No behavior beyond ValueError; the detected cycle travels in args[1].
    pass
|
||||
|
||||
|
||||
class TopologicalSorter:
    """Provides functionality to topologically sort a graph of hashable nodes"""

    def __init__(self, graph=None):
        # node -> _NodeInfo (predecessor count + successor list).
        self._node2info = {}
        # None until prepare() is called; afterwards, the list of nodes
        # currently ready to be handed out by get_ready().
        self._ready_nodes = None
        # How many nodes have been handed out by get_ready().
        self._npassedout = 0
        # How many nodes have been marked done().
        self._nfinished = 0

        if graph is not None:
            for node, predecessors in graph.items():
                self.add(node, *predecessors)

    def _get_nodeinfo(self, node):
        # Fetch the _NodeInfo for *node*, creating it on first sight.
        result = self._node2info.get(node)
        if result is None:
            self._node2info[node] = result = _NodeInfo(node)
        return result

    def add(self, node, *predecessors):
        """Add a new node and its predecessors to the graph.

        Both the *node* and all elements in *predecessors* must be hashable.

        If called multiple times with the same node argument, the set of dependencies
        will be the union of all dependencies passed in.

        It is possible to add a node with no dependencies (*predecessors* is not provided)
        as well as provide a dependency twice. If a node that has not been provided before
        is included among *predecessors* it will be automatically added to the graph with
        no predecessors of its own.

        Raises ValueError if called after "prepare".
        """
        if self._ready_nodes is not None:
            raise ValueError("Nodes cannot be added after a call to prepare()")

        # Create the node -> predecessor edges
        nodeinfo = self._get_nodeinfo(node)
        nodeinfo.npredecessors += len(predecessors)

        # Create the predecessor -> node edges
        for pred in predecessors:
            pred_info = self._get_nodeinfo(pred)
            pred_info.successors.append(node)

    def prepare(self):
        """Mark the graph as finished and check for cycles in the graph.

        If any cycle is detected, "CycleError" will be raised, but "get_ready" can
        still be used to obtain as many nodes as possible until cycles block more
        progress. After a call to this function, the graph cannot be modified and
        therefore no more nodes can be added using "add".
        """
        if self._ready_nodes is not None:
            raise ValueError("cannot prepare() more than once")

        self._ready_nodes = [
            i.node for i in self._node2info.values() if i.npredecessors == 0
        ]
        # ready_nodes is set before we look for cycles on purpose:
        # if the user wants to catch the CycleError, that's fine,
        # they can continue using the instance to grab as many
        # nodes as possible before cycles block more progress
        cycle = self._find_cycle()
        if cycle:
            raise CycleError(f"nodes are in a cycle", cycle)

    def get_ready(self):
        """Return a tuple of all the nodes that are ready.

        Initially it returns all nodes with no predecessors; once those are marked
        as processed by calling "done", further calls will return all new nodes that
        have all their predecessors already processed. Once no more progress can be made,
        empty tuples are returned.

        Raises ValueError if called without calling "prepare" previously.
        """
        if self._ready_nodes is None:
            raise ValueError("prepare() must be called first")

        # Get the nodes that are ready and mark them
        result = tuple(self._ready_nodes)
        n2i = self._node2info
        for node in result:
            n2i[node].npredecessors = _NODE_OUT

        # Clean the list of nodes that are ready and update
        # the counter of nodes that we have returned.
        self._ready_nodes.clear()
        self._npassedout += len(result)

        return result

    def is_active(self):
        """Return ``True`` if more progress can be made and ``False`` otherwise.

        Progress can be made if cycles do not block the resolution and either there
        are still nodes ready that haven't yet been returned by "get_ready" or the
        number of nodes marked "done" is less than the number that have been returned
        by "get_ready".

        Raises ValueError if called without calling "prepare" previously.
        """
        if self._ready_nodes is None:
            raise ValueError("prepare() must be called first")
        return self._nfinished < self._npassedout or bool(self._ready_nodes)

    def __bool__(self):
        # Truthiness mirrors is_active(): "true" while progress is possible.
        return self.is_active()

    def done(self, *nodes):
        """Marks a set of nodes returned by "get_ready" as processed.

        This method unblocks any successor of each node in *nodes* for being returned
        in the future by a call to "get_ready".

        Raises :exc:`ValueError` if any node in *nodes* has already been marked as
        processed by a previous call to this method, if a node was not added to the
        graph by using "add" or if called without calling "prepare" previously or if
        node has not yet been returned by "get_ready".
        """

        if self._ready_nodes is None:
            raise ValueError("prepare() must be called first")

        n2i = self._node2info

        for node in nodes:

            # Check if we know about this node (it was added previously using add()
            nodeinfo = n2i.get(node)
            if nodeinfo is None:
                raise ValueError(f"node {node!r} was not added using add()")

            # If the node has not been returned (marked as ready) previously, inform the user.
            stat = nodeinfo.npredecessors
            if stat != _NODE_OUT:
                if stat >= 0:
                    raise ValueError(
                        f"node {node!r} was not passed out (still not ready)"
                    )
                elif stat == _NODE_DONE:
                    raise ValueError(f"node {node!r} was already marked done")
                else:
                    assert False, f"node {node!r}: unknown status {stat}"

            # Mark the node as processed
            nodeinfo.npredecessors = _NODE_DONE

            # Go to all the successors and reduce the number of predecessors, collecting all the ones
            # that are ready to be returned in the next get_ready() call.
            for successor in nodeinfo.successors:
                successor_info = n2i[successor]
                successor_info.npredecessors -= 1
                if successor_info.npredecessors == 0:
                    self._ready_nodes.append(successor)
            self._nfinished += 1

    def _find_cycle(self):
        # Iterative depth-first search over every node; returns the first
        # cycle found as a node list whose first and last entries match,
        # or None when the graph is acyclic.
        n2i = self._node2info
        stack = []
        itstack = []
        seen = set()
        node2stacki = {}

        for node in n2i:
            if node in seen:
                continue

            while True:
                if node in seen:
                    # If we have seen already the node and is in the
                    # current stack we have found a cycle.
                    if node in node2stacki:
                        return stack[node2stacki[node] :] + [node]
                    # else go on to get next successor
                else:
                    seen.add(node)
                    itstack.append(iter(n2i[node].successors).__next__)
                    node2stacki[node] = len(stack)
                    stack.append(node)

                # Backtrack to the topmost stack entry with
                # at least another successor.
                while stack:
                    try:
                        node = itstack[-1]()
                        break
                    except StopIteration:
                        del node2stacki[stack.pop()]
                        itstack.pop()
                else:
                    break
        return None

    def static_order(self):
        """Returns an iterable of nodes in a topological order.

        The particular order that is returned may depend on the specific
        order in which the items were inserted in the graph.

        Using this method does not require to call "prepare" or "done". If any
        cycle is detected, :exc:`CycleError` will be raised.
        """
        self.prepare()
        while self.is_active():
            node_group = self.get_ready()
            yield from node_group
            self.done(*node_group)
|
||||
90
tests/test_asyncdi.py
Normal file
90
tests/test_asyncdi.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
import asyncio
|
||||
from datasette.utils.asyncdi import AsyncBase, inject
|
||||
import pytest
|
||||
from random import random
|
||||
|
||||
|
||||
class Simple(AsyncBase):
    # Minimal DI fixture: one() depends on two() purely via its parameter
    # name, and both record their execution order into self.log.
    def __init__(self):
        self.log = []

    @inject
    async def two(self):
        self.log.append("two")

    @inject
    async def one(self, two):
        self.log.append("one")
        return self.log

    async def not_inject(self, one, two):
        # Not decorated with @inject, so one/two stay ordinary arguments
        # and the method is left out of the registry.
        return one + two
|
||||
|
||||
|
||||
class Complex(AsyncBase):
    # Diamond-shaped dependency fixture: go -> a -> (b, c), b -> (c, d).
    # inject_all registers every async method without needing @inject.
    inject_all = True

    def __init__(self):
        self.log = []

    async def d(self):
        # Random sleep so the concurrent c/d ordering is nondeterministic.
        await asyncio.sleep(random() * 0.1)
        self.log.append("d")

    async def c(self):
        await asyncio.sleep(random() * 0.1)
        self.log.append("c")

    async def b(self, c, d):
        self.log.append("b")

    async def a(self, b, c):
        self.log.append("a")

    async def go(self, a):
        self.log.append("go")
        return self.log
|
||||
|
||||
|
||||
class WithParameters(AsyncBase):
    # calc1/calc2 are injected from the registry; param1 is not a registered
    # method, so callers must supply it as a keyword argument.
    inject_all = True

    async def go(self, calc1, calc2, param1):
        return param1 + calc1 + calc2

    async def calc1(self):
        return 5

    async def calc2(self):
        return 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_simple():
    # two() runs before one() because one() depends on it by name.
    log = await Simple().one()
    assert log == ["two", "one"]
    # Plain async methods are untouched by the registry machinery.
    assert await Simple().not_inject(6, 7) == 13
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_complex():
    log = await Complex().go()
    # 'c' should only be called once even though both a and b need it;
    # c and d run concurrently and may land in either order.
    assert tuple(log) in {
        ("c", "d", "b", "a", "go"),
        ("d", "c", "b", "a", "go"),
    }
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_with_parameters():
    # param1 is not registered, so it must be passed by the caller.
    result = await WithParameters().go(param1=4)
    assert result == 15

    # Should throw an error if that parameter is missing
    with pytest.raises(AssertionError) as e:
        await WithParameters().go()
    # Fix: ExceptionInfo has no .args attribute — the raised exception's
    # args live on e.value, so the original `e.args[0]` never checked
    # the message (it raised AttributeError instead).
    assert e.value.args[0] == (
        "The following DI parameters could not be "
        "found in the registry: ['param1']"
    )
|
||||
Loading…
Add table
Add a link
Reference in a new issue