From 86aaa7c7b2eb0f96bace9ec07ebc3392be58b546 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 16 Nov 2021 13:54:14 -0800 Subject: [PATCH 1/4] New AsyncBase class, refs #878 --- datasette/utils/asyncdi.py | 101 +++++++++++++++++++++++++++++++++++++ tests/test_asyncdi.py | 80 +++++++++++++++++++++++++++++ 2 files changed, 181 insertions(+) create mode 100644 datasette/utils/asyncdi.py create mode 100644 tests/test_asyncdi.py diff --git a/datasette/utils/asyncdi.py b/datasette/utils/asyncdi.py new file mode 100644 index 00000000..4c7f3a60 --- /dev/null +++ b/datasette/utils/asyncdi.py @@ -0,0 +1,101 @@ +import asyncio +from functools import wraps +import inspect + +try: + import graphlib +except ImportError: + from . import vendored_graphlib as graphlib + + +class AsyncMeta(type): + def __new__(cls, name, bases, attrs): + # Decorate any items that are 'async def' methods + _registry = {} + new_attrs = {"_registry": _registry} + for key, value in attrs.items(): + if inspect.iscoroutinefunction(value) and not value.__name__ == "resolve": + new_attrs[key] = make_method(value) + _registry[key] = new_attrs[key] + else: + new_attrs[key] = value + # Gather graph for later dependency resolution + graph = { + key: { + p + for p in inspect.signature(method).parameters.keys() + if p != "self" and not p.startswith("_") + } + for key, method in _registry.items() + } + new_attrs["_graph"] = graph + return super().__new__(cls, name, bases, new_attrs) + + +def make_method(method): + parameters = inspect.signature(method).parameters.keys() + + @wraps(method) + async def inner(self, _results=None, **kwargs): + # Any parameters not provided by kwargs are resolved from registry + to_resolve = [p for p in parameters if p not in kwargs and p != "self"] + missing = [p for p in to_resolve if p not in self._registry] + assert ( + not missing + ), "The following DI parameters could not be found in the registry: {}".format( + missing + ) + + results = {} + results.update(kwargs) + if to_resolve: + resolved_parameters = await self.resolve(to_resolve, _results) + results.update(resolved_parameters) + return_value = await method(self, **results) + if _results is not None: + _results[method.__name__] = return_value + return return_value + + return inner + + +class AsyncBase(metaclass=AsyncMeta): + async def resolve(self, names, results=None): + if results is None: + results = {} + + # Come up with an execution plan, just for these nodes + ts = graphlib.TopologicalSorter() + to_do = set(names) + done = set() + while to_do: + item = to_do.pop() + dependencies = self._graph[item] + ts.add(item, *dependencies) + done.add(item) + # Add any not-done dependencies to the queue + to_do.update({k for k in dependencies if k not in done}) + + ts.prepare() + plan = [] + while ts.is_active(): + node_group = ts.get_ready() + plan.append(node_group) + ts.done(*node_group) + + results = {} + for node_group in plan: + awaitables = [ + self._registry[name]( + self, + _results=results, + **{k: v for k, v in results.items() if k in self._graph[name]}, + ) + for name in node_group + ] + awaitable_results = await asyncio.gather(*awaitables) + results.update( + {p[0].__name__: p[1] for p in zip(awaitables, awaitable_results)} + ) + + return {key: value for key, value in results.items() if key in names} diff --git a/tests/test_asyncdi.py b/tests/test_asyncdi.py new file mode 100644 index 00000000..68dcb2fd --- /dev/null +++ b/tests/test_asyncdi.py @@ -0,0 +1,80 @@ +import asyncio +from datasette.utils.asyncdi import AsyncBase +import pytest +from random import random + + +class Simple(AsyncBase): + def __init__(self): + self.log = [] + + async def two(self): + self.log.append("two") + + async def one(self, two): + self.log.append("one") + return self.log + + +class Complex(AsyncBase): + def __init__(self): + self.log = [] + + async def d(self): + await asyncio.sleep(random() * 0.1) + self.log.append("d") + + async def c(self): + await asyncio.sleep(random() * 0.1) + self.log.append("c") + + async def b(self, c, d): + self.log.append("b") + + async def a(self, b, c): + self.log.append("a") + + async def go(self, a): + self.log.append("go") + return self.log + + +class WithParameters(AsyncBase): + async def go(self, calc1, calc2, param1): + return param1 + calc1 + calc2 + + async def calc1(self): + return 5 + + async def calc2(self): + return 6 + + +@pytest.mark.asyncio +async def test_simple(): + assert await Simple().one() == ["two", "one"] + + +@pytest.mark.asyncio +async def test_complex(): + result = await Complex().go() + # 'c' should only be called once + assert tuple(result) in ( + # c and d could happen in either order + ("c", "d", "b", "a", "go"), + ("d", "c", "b", "a", "go"), + ) + + +@pytest.mark.asyncio +async def test_with_parameters(): + result = await WithParameters().go(param1=4) + assert result == 15 + + # Should throw an error if that parameter is missing + with pytest.raises(AssertionError) as e: + await WithParameters().go() + assert e.args[0] == ( + "The following DI parameters could not be " + "found in the registry: ['param1']" + ) From 22f41f798360bbc947d70577f4a733a42b22c31f Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 16 Nov 2021 14:01:16 -0800 Subject: [PATCH 2/4] @inject decorator or inject_all = True for AsyncBase, refs #878 --- datasette/utils/asyncdi.py | 12 +++++++++++- tests/test_asyncdi.py | 12 +++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/datasette/utils/asyncdi.py b/datasette/utils/asyncdi.py index 4c7f3a60..ba743450 100644 --- a/datasette/utils/asyncdi.py +++ b/datasette/utils/asyncdi.py @@ -8,13 +8,23 @@ except ImportError: from . import vendored_graphlib as graphlib +def inject(fn): + fn._inject = True + return fn + + class AsyncMeta(type): def __new__(cls, name, bases, attrs): # Decorate any items that are 'async def' methods _registry = {} new_attrs = {"_registry": _registry} + inject_all = attrs.get("inject_all") for key, value in attrs.items(): - if inspect.iscoroutinefunction(value) and not value.__name__ == "resolve": + if ( + inspect.iscoroutinefunction(value) + and not value.__name__ == "resolve" + and (inject_all or getattr(value, "_inject", None)) + ): new_attrs[key] = make_method(value) _registry[key] = new_attrs[key] else: diff --git a/tests/test_asyncdi.py b/tests/test_asyncdi.py index 68dcb2fd..baf8a469 100644 --- a/tests/test_asyncdi.py +++ b/tests/test_asyncdi.py @@ -1,5 +1,5 @@ import asyncio -from datasette.utils.asyncdi import AsyncBase +from datasette.utils.asyncdi import AsyncBase, inject import pytest from random import random @@ -8,15 +8,22 @@ class Simple(AsyncBase): def __init__(self): self.log = [] + @inject async def two(self): self.log.append("two") + @inject async def one(self, two): self.log.append("one") return self.log + async def not_inject(self, one, two): + return one + two + class Complex(AsyncBase): + inject_all = True + def __init__(self): self.log = [] @@ -40,6 +47,8 @@ class Complex(AsyncBase): class WithParameters(AsyncBase): + inject_all = True + async def go(self, calc1, calc2, param1): return param1 + calc1 + calc2 @@ -53,6 +62,7 @@ class WithParameters(AsyncBase): @pytest.mark.asyncio async def test_simple(): assert await Simple().one() == ["two", "one"] + assert await Simple().not_inject(6, 7) == 13 @pytest.mark.asyncio From f4c5f58887328323335f7a77d9a334774d439325 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 16 Nov 2021 14:03:16 -0800 Subject: [PATCH 3/4] Add vendored graphlib library from Python 3.10 --- datasette/utils/vendored_graphlib.py | 249 +++++++++++++++++++++++++++ 1 file changed, 249 insertions(+) create mode 100644 datasette/utils/vendored_graphlib.py diff --git a/datasette/utils/vendored_graphlib.py b/datasette/utils/vendored_graphlib.py new file mode 100644 index 00000000..bec0fdb8 --- /dev/null +++ b/datasette/utils/vendored_graphlib.py @@ -0,0 +1,249 @@ +# Vendored from https://raw.githubusercontent.com/python/cpython/3.10/Lib/graphlib.py +# License: https://github.com/python/cpython/blob/main/LICENSE + +__all__ = ["TopologicalSorter", "CycleError"] + +_NODE_OUT = -1 +_NODE_DONE = -2 + + +class _NodeInfo: + __slots__ = "node", "npredecessors", "successors" + + def __init__(self, node): + # The node this class is augmenting. + self.node = node + + # Number of predecessors, generally >= 0. When this value falls to 0, + # and is returned by get_ready(), this is set to _NODE_OUT and when the + # node is marked done by a call to done(), set to _NODE_DONE. + self.npredecessors = 0 + + # List of successor nodes. The list can contain duplicated elements as + # long as they're all reflected in the successor's npredecessors attribute. + self.successors = [] + + +class CycleError(ValueError): + """Subclass of ValueError raised by TopologicalSorter.prepare if cycles + exist in the working graph. + + If multiple cycles exist, only one undefined choice among them will be reported + and included in the exception. The detected cycle can be accessed via the second + element in the *args* attribute of the exception instance and consists in a list + of nodes, such that each node is, in the graph, an immediate predecessor of the + next node in the list. In the reported list, the first and the last node will be + the same, to make it clear that it is cyclic. + """ + + pass + + +class TopologicalSorter: + """Provides functionality to topologically sort a graph of hashable nodes""" + + def __init__(self, graph=None): + self._node2info = {} + self._ready_nodes = None + self._npassedout = 0 + self._nfinished = 0 + + if graph is not None: + for node, predecessors in graph.items(): + self.add(node, *predecessors) + + def _get_nodeinfo(self, node): + if (result := self._node2info.get(node)) is None: + self._node2info[node] = result = _NodeInfo(node) + return result + + def add(self, node, *predecessors): + """Add a new node and its predecessors to the graph. + + Both the *node* and all elements in *predecessors* must be hashable. + + If called multiple times with the same node argument, the set of dependencies + will be the union of all dependencies passed in. + + It is possible to add a node with no dependencies (*predecessors* is not provided) + as well as provide a dependency twice. If a node that has not been provided before + is included among *predecessors* it will be automatically added to the graph with + no predecessors of its own. + + Raises ValueError if called after "prepare". + """ + if self._ready_nodes is not None: + raise ValueError("Nodes cannot be added after a call to prepare()") + + # Create the node -> predecessor edges + nodeinfo = self._get_nodeinfo(node) + nodeinfo.npredecessors += len(predecessors) + + # Create the predecessor -> node edges + for pred in predecessors: + pred_info = self._get_nodeinfo(pred) + pred_info.successors.append(node) + + def prepare(self): + """Mark the graph as finished and check for cycles in the graph. + + If any cycle is detected, "CycleError" will be raised, but "get_ready" can + still be used to obtain as many nodes as possible until cycles block more + progress. After a call to this function, the graph cannot be modified and + therefore no more nodes can be added using "add". + """ + if self._ready_nodes is not None: + raise ValueError("cannot prepare() more than once") + + self._ready_nodes = [ + i.node for i in self._node2info.values() if i.npredecessors == 0 + ] + # ready_nodes is set before we look for cycles on purpose: + # if the user wants to catch the CycleError, that's fine, + # they can continue using the instance to grab as many + # nodes as possible before cycles block more progress + cycle = self._find_cycle() + if cycle: + raise CycleError(f"nodes are in a cycle", cycle) + + def get_ready(self): + """Return a tuple of all the nodes that are ready. + + Initially it returns all nodes with no predecessors; once those are marked + as processed by calling "done", further calls will return all new nodes that + have all their predecessors already processed. Once no more progress can be made, + empty tuples are returned. + + Raises ValueError if called without calling "prepare" previously. + """ + if self._ready_nodes is None: + raise ValueError("prepare() must be called first") + + # Get the nodes that are ready and mark them + result = tuple(self._ready_nodes) + n2i = self._node2info + for node in result: + n2i[node].npredecessors = _NODE_OUT + + # Clean the list of nodes that are ready and update + # the counter of nodes that we have returned. + self._ready_nodes.clear() + self._npassedout += len(result) + + return result + + def is_active(self): + """Return ``True`` if more progress can be made and ``False`` otherwise. + + Progress can be made if cycles do not block the resolution and either there + are still nodes ready that haven't yet been returned by "get_ready" or the + number of nodes marked "done" is less than the number that have been returned + by "get_ready". + + Raises ValueError if called without calling "prepare" previously. + """ + if self._ready_nodes is None: + raise ValueError("prepare() must be called first") + return self._nfinished < self._npassedout or bool(self._ready_nodes) + + def __bool__(self): + return self.is_active() + + def done(self, *nodes): + """Marks a set of nodes returned by "get_ready" as processed. + + This method unblocks any successor of each node in *nodes* for being returned + in the future by a call to "get_ready". + + Raises :exec:`ValueError` if any node in *nodes* has already been marked as + processed by a previous call to this method, if a node was not added to the + graph by using "add" or if called without calling "prepare" previously or if + node has not yet been returned by "get_ready". + """ + + if self._ready_nodes is None: + raise ValueError("prepare() must be called first") + + n2i = self._node2info + + for node in nodes: + + # Check if we know about this node (it was added previously using add() + if (nodeinfo := n2i.get(node)) is None: + raise ValueError(f"node {node!r} was not added using add()") + + # If the node has not being returned (marked as ready) previously, inform the user. + stat = nodeinfo.npredecessors + if stat != _NODE_OUT: + if stat >= 0: + raise ValueError( + f"node {node!r} was not passed out (still not ready)" + ) + elif stat == _NODE_DONE: + raise ValueError(f"node {node!r} was already marked done") + else: + assert False, f"node {node!r}: unknown status {stat}" + + # Mark the node as processed + nodeinfo.npredecessors = _NODE_DONE + + # Go to all the successors and reduce the number of predecessors, collecting all the ones + # that are ready to be returned in the next get_ready() call. + for successor in nodeinfo.successors: + successor_info = n2i[successor] + successor_info.npredecessors -= 1 + if successor_info.npredecessors == 0: + self._ready_nodes.append(successor) + self._nfinished += 1 + + def _find_cycle(self): + n2i = self._node2info + stack = [] + itstack = [] + seen = set() + node2stacki = {} + + for node in n2i: + if node in seen: + continue + + while True: + if node in seen: + # If we have seen already the node and is in the + # current stack we have found a cycle. + if node in node2stacki: + return stack[node2stacki[node] :] + [node] + # else go on to get next successor + else: + seen.add(node) + itstack.append(iter(n2i[node].successors).__next__) + node2stacki[node] = len(stack) + stack.append(node) + + # Backtrack to the topmost stack entry with + # at least another successor. + while stack: + try: + node = itstack[-1]() + break + except StopIteration: + del node2stacki[stack.pop()] + itstack.pop() + else: + break + return None + + def static_order(self): + """Returns an iterable of nodes in a topological order. + + The particular order that is returned may depend on the specific + order in which the items were inserted in the graph. + + Using this method does not require to call "prepare" or "done". If any + cycle is detected, :exc:`CycleError` will be raised. + """ + self.prepare() + while self.is_active(): + node_group = self.get_ready() + yield from node_group + self.done(*node_group) From 8f757da0750fe7f27b4ed3839bc3ef3650832ad9 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 16 Nov 2021 15:46:50 -0800 Subject: [PATCH 4/4] Backport graphlib to work on Python 3.6, refs #878 --- datasette/utils/vendored_graphlib.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/datasette/utils/vendored_graphlib.py b/datasette/utils/vendored_graphlib.py index bec0fdb8..142753f4 100644 --- a/datasette/utils/vendored_graphlib.py +++ b/datasette/utils/vendored_graphlib.py @@ -1,4 +1,5 @@ # Vendored from https://raw.githubusercontent.com/python/cpython/3.10/Lib/graphlib.py +# Modified to work on Python 3.6 (I removed := operator) # License: https://github.com/python/cpython/blob/main/LICENSE __all__ = ["TopologicalSorter", "CycleError"] @@ -53,7 +54,8 @@ class TopologicalSorter: self.add(node, *predecessors) def _get_nodeinfo(self, node): - if (result := self._node2info.get(node)) is None: + result = self._node2info.get(node) + if result is None: self._node2info[node] = result = _NodeInfo(node) return result @@ -169,7 +171,8 @@ class TopologicalSorter: for node in nodes: # Check if we know about this node (it was added previously using add() - if (nodeinfo := n2i.get(node)) is None: + nodeinfo = n2i.get(node) + if nodeinfo is None: raise ValueError(f"node {node!r} was not added using add()") # If the node has not being returned (marked as ready) previously, inform the user.