await_me_maybe utility function

This commit is contained in:
Simon Willison 2020-09-02 15:21:12 -07:00
commit 26b2922f17
3 changed files with 20 additions and 34 deletions

View file

@ -45,6 +45,7 @@ from .database import Database, QueryInterrupted
from .utils import ( from .utils import (
async_call_with_supported_arguments, async_call_with_supported_arguments,
await_me_maybe,
call_with_supported_arguments, call_with_supported_arguments,
display_actor, display_actor,
escape_css_string, escape_css_string,
@ -312,10 +313,7 @@ class Datasette:
async def invoke_startup(self): async def invoke_startup(self):
for hook in pm.hook.startup(datasette=self): for hook in pm.hook.startup(datasette=self):
if callable(hook): await await_me_maybe(hook)
hook = hook()
if asyncio.iscoroutine(hook):
hook = await hook
def sign(self, value, namespace="default"): def sign(self, value, namespace="default"):
return URLSafeSerializer(self._secret, namespace).dumps(value) return URLSafeSerializer(self._secret, namespace).dumps(value)
@ -400,10 +398,7 @@ class Datasette:
for more_queries in pm.hook.canned_queries( for more_queries in pm.hook.canned_queries(
datasette=self, database=database_name, actor=actor, datasette=self, database=database_name, actor=actor,
): ):
if callable(more_queries): more_queries = await await_me_maybe(more_queries)
more_queries = more_queries()
if asyncio.iscoroutine(more_queries):
more_queries = await more_queries
queries.update(more_queries or {}) queries.update(more_queries or {})
# Fix any {"name": "select ..."} queries to be {"name": {"sql": "select ..."}} # Fix any {"name": "select ..."} queries to be {"name": {"sql": "select ..."}}
for key in queries: for key in queries:
@ -475,10 +470,7 @@ class Datasette:
for check in pm.hook.permission_allowed( for check in pm.hook.permission_allowed(
datasette=self, actor=actor, action=action, resource=resource, datasette=self, actor=actor, action=action, resource=resource,
): ):
if callable(check): check = await await_me_maybe(check)
check = check()
if asyncio.iscoroutine(check):
check = await check
if check is not None: if check is not None:
result = check result = check
used_default = False used_default = False
@ -718,10 +710,7 @@ class Datasette:
request=request, request=request,
datasette=self, datasette=self,
): ):
if callable(extra_script): extra_script = await await_me_maybe(extra_script)
extra_script = extra_script()
if asyncio.iscoroutine(extra_script):
extra_script = await extra_script
body_scripts.append(Markup(extra_script)) body_scripts.append(Markup(extra_script))
extra_template_vars = {} extra_template_vars = {}
@ -735,10 +724,7 @@ class Datasette:
request=request, request=request,
datasette=self, datasette=self,
): ):
if callable(extra_vars): extra_vars = await await_me_maybe(extra_vars)
extra_vars = extra_vars()
if asyncio.iscoroutine(extra_vars):
extra_vars = await extra_vars
assert isinstance(extra_vars, dict), "extra_vars is of type {}".format( assert isinstance(extra_vars, dict), "extra_vars is of type {}".format(
type(extra_vars) type(extra_vars)
) )
@ -786,10 +772,7 @@ class Datasette:
request=request, request=request,
datasette=self, datasette=self,
): ):
if callable(hook): hook = await await_me_maybe(hook)
hook = hook()
if asyncio.iscoroutine(hook):
hook = await hook
collected.extend(hook) collected.extend(hook)
collected.extend(self.metadata(key) or []) collected.extend(self.metadata(key) or [])
output = [] output = []
@ -981,10 +964,7 @@ class DatasetteRouter:
default_actor = scope.get("actor") or None default_actor = scope.get("actor") or None
actor = None actor = None
for actor in pm.hook.actor_from_request(datasette=self.ds, request=request): for actor in pm.hook.actor_from_request(datasette=self.ds, request=request):
if callable(actor): actor = await await_me_maybe(actor)
actor = actor()
if asyncio.iscoroutine(actor):
actor = await actor
if actor: if actor:
break break
scope_modifications["actor"] = actor or default_actor scope_modifications["actor"] = actor or default_actor
@ -1079,10 +1059,7 @@ class DatasetteRouter:
for custom_response in pm.hook.forbidden( for custom_response in pm.hook.forbidden(
datasette=self.ds, request=request, message=message datasette=self.ds, request=request, message=message
): ):
if callable(custom_response): custom_response = await await_me_maybe(custom_response)
custom_response = custom_response()
if asyncio.iscoroutine(custom_response):
custom_response = await custom_response
if custom_response is not None: if custom_response is not None:
await custom_response.asgi_send(send) await custom_response.asgi_send(send)
return return

View file

@ -1,3 +1,4 @@
import asyncio
from contextlib import contextmanager from contextlib import contextmanager
from collections import OrderedDict from collections import OrderedDict
import base64 import base64
@ -51,6 +52,14 @@ ENV SQLITE_EXTENSIONS /usr/lib/x86_64-linux-gnu/mod_spatialite.so
""" """
async def await_me_maybe(value):
if callable(value):
value = value()
if asyncio.iscoroutine(value):
value = await value
return value
def urlsafe_components(token): def urlsafe_components(token):
"Splits token on commas and URL decodes each component" "Splits token on commas and URL decodes each component"
return [urllib.parse.unquote_plus(b) for b in token.split(",")] return [urllib.parse.unquote_plus(b) for b in token.split(",")]

View file

@ -12,6 +12,7 @@ from datasette import __version__
from datasette.plugins import pm from datasette.plugins import pm
from datasette.database import QueryInterrupted from datasette.database import QueryInterrupted
from datasette.utils import ( from datasette.utils import (
await_me_maybe,
InvalidSql, InvalidSql,
LimitedWriter, LimitedWriter,
call_with_supported_arguments, call_with_supported_arguments,
@ -492,8 +493,7 @@ class DataView(BaseView):
request=request, request=request,
view_name=self.name, view_name=self.name,
) )
if asyncio.iscoroutine(it_can_render): it_can_render = await await_me_maybe(it_can_render)
it_can_render = await it_can_render
if it_can_render: if it_can_render:
renderers[key] = path_with_format( renderers[key] = path_with_format(
request, key, {**url_labels_extra} request, key, {**url_labels_extra}