mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
@inject decorator or inject_all = True for AsyncBase, refs #878
This commit is contained in:
parent
86aaa7c7b2
commit
22f41f7983
2 changed files with 22 additions and 2 deletions
|
|
@ -8,13 +8,23 @@ except ImportError:
|
||||||
from . import vendored_graphlib as graphlib
|
from . import vendored_graphlib as graphlib
|
||||||
|
|
||||||
|
|
||||||
|
def inject(fn):
|
||||||
|
fn._inject = True
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
class AsyncMeta(type):
|
class AsyncMeta(type):
|
||||||
def __new__(cls, name, bases, attrs):
|
def __new__(cls, name, bases, attrs):
|
||||||
# Decorate any items that are 'async def' methods
|
# Decorate any items that are 'async def' methods
|
||||||
_registry = {}
|
_registry = {}
|
||||||
new_attrs = {"_registry": _registry}
|
new_attrs = {"_registry": _registry}
|
||||||
|
inject_all = attrs.get("inject_all")
|
||||||
for key, value in attrs.items():
|
for key, value in attrs.items():
|
||||||
if inspect.iscoroutinefunction(value) and not value.__name__ == "resolve":
|
if (
|
||||||
|
inspect.iscoroutinefunction(value)
|
||||||
|
and not value.__name__ == "resolve"
|
||||||
|
and (inject_all or getattr(value, "_inject", None))
|
||||||
|
):
|
||||||
new_attrs[key] = make_method(value)
|
new_attrs[key] = make_method(value)
|
||||||
_registry[key] = new_attrs[key]
|
_registry[key] = new_attrs[key]
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from datasette.utils.asyncdi import AsyncBase
|
from datasette.utils.asyncdi import AsyncBase, inject
|
||||||
import pytest
|
import pytest
|
||||||
from random import random
|
from random import random
|
||||||
|
|
||||||
|
|
@ -8,15 +8,22 @@ class Simple(AsyncBase):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.log = []
|
self.log = []
|
||||||
|
|
||||||
|
@inject
|
||||||
async def two(self):
|
async def two(self):
|
||||||
self.log.append("two")
|
self.log.append("two")
|
||||||
|
|
||||||
|
@inject
|
||||||
async def one(self, two):
|
async def one(self, two):
|
||||||
self.log.append("one")
|
self.log.append("one")
|
||||||
return self.log
|
return self.log
|
||||||
|
|
||||||
|
async def not_inject(self, one, two):
|
||||||
|
return one + two
|
||||||
|
|
||||||
|
|
||||||
class Complex(AsyncBase):
|
class Complex(AsyncBase):
|
||||||
|
inject_all = True
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.log = []
|
self.log = []
|
||||||
|
|
||||||
|
|
@ -40,6 +47,8 @@ class Complex(AsyncBase):
|
||||||
|
|
||||||
|
|
||||||
class WithParameters(AsyncBase):
|
class WithParameters(AsyncBase):
|
||||||
|
inject_all = True
|
||||||
|
|
||||||
async def go(self, calc1, calc2, param1):
|
async def go(self, calc1, calc2, param1):
|
||||||
return param1 + calc1 + calc2
|
return param1 + calc1 + calc2
|
||||||
|
|
||||||
|
|
@ -53,6 +62,7 @@ class WithParameters(AsyncBase):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_simple():
|
async def test_simple():
|
||||||
assert await Simple().one() == ["two", "one"]
|
assert await Simple().one() == ["two", "one"]
|
||||||
|
assert await Simple().not_inject(6, 7) == 13
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue