.add_message() now works inside plugins, closes #864

Refs #870
This commit is contained in:
Simon Willison 2020-06-28 17:25:35 -07:00
commit 7ac4936cec
4 changed files with 32 additions and 17 deletions

View file

@ -5,6 +5,7 @@ import datetime
import hashlib import hashlib
import inspect import inspect
import itertools import itertools
from itsdangerous import BadSignature
import json import json
import os import os
import re import re
@ -926,6 +927,14 @@ class DatasetteRouter:
if base_url != "/" and path.startswith(base_url): if base_url != "/" and path.startswith(base_url):
path = "/" + path[len(base_url) :] path = "/" + path[len(base_url) :]
request = Request(scope, receive) request = Request(scope, receive)
# Populate request_messages if ds_messages cookie is present
try:
request._messages = self.ds.unsign(
request.cookies.get("ds_messages", ""), "messages"
)
except BadSignature:
pass
scope_modifications = {} scope_modifications = {}
# Apply force_https_urls, if set # Apply force_https_urls, if set
if ( if (
@ -952,7 +961,11 @@ class DatasetteRouter:
new_scope = dict(scope, url_route={"kwargs": match.groupdict()}) new_scope = dict(scope, url_route={"kwargs": match.groupdict()})
request.scope = new_scope request.scope = new_scope
try: try:
return await view(request, send) response = await view(request, send)
if response:
self.ds._write_messages_to_response(request, response)
await response.asgi_send(send)
return
except NotFound as exception: except NotFound as exception:
return await self.handle_404(scope, receive, send, exception) return await self.handle_404(scope, receive, send, exception)
except Exception as exception: except Exception as exception:
@ -1099,6 +1112,6 @@ def wrap_view(view_fn, datasette):
datasette=datasette, datasette=datasette,
) )
if response is not None: if response is not None:
await response.asgi_send(send) return response
return async_view_fn return async_view_fn

View file

@ -1,7 +1,6 @@
import asyncio import asyncio
import csv import csv
import itertools import itertools
from itsdangerous import BadSignature
import json import json
import re import re
import time import time
@ -82,19 +81,8 @@ class BaseView:
return "ff0000" return "ff0000"
async def dispatch_request(self, request, *args, **kwargs): async def dispatch_request(self, request, *args, **kwargs):
# Populate request_messages if ds_messages cookie is present
if self.ds:
try:
request._messages = self.ds.unsign(
request.cookies.get("ds_messages", ""), "messages"
)
except BadSignature:
pass
handler = getattr(self, request.method.lower(), None) handler = getattr(self, request.method.lower(), None)
response = await handler(request, *args, **kwargs) return await handler(request, *args, **kwargs)
if self.ds:
self.ds._write_messages_to_response(request, response)
return response
async def render(self, templates, request, context=None): async def render(self, templates, request, context=None):
context = context or {} context = context or {}
@ -123,10 +111,9 @@ class BaseView:
def as_view(cls, *class_args, **class_kwargs): def as_view(cls, *class_args, **class_kwargs):
async def view(request, send): async def view(request, send):
self = view.view_class(*class_args, **class_kwargs) self = view.view_class(*class_args, **class_kwargs)
response = await self.dispatch_request( return await self.dispatch_request(
request, **request.scope["url_route"]["kwargs"] request, **request.scope["url_route"]["kwargs"]
) )
await response.asgi_send(send)
view.view_class = cls view.view_class = cls
view.__doc__ = cls.__doc__ view.__doc__ = cls.__doc__

View file

@ -190,6 +190,12 @@ def register_routes():
def not_async(): def not_async():
return Response.html("This was not async") return Response.html("This was not async")
def add_message(datasette, request):
datasette.add_message(request, "Hello from messages")
print("Adding message")
print(request._messages)
return Response.html("Added message")
return [ return [
(r"/one/$", one), (r"/one/$", one),
(r"/two/(?P<name>.*)$", two), (r"/two/(?P<name>.*)$", two),
@ -197,6 +203,7 @@ def register_routes():
(r"/post/$", post), (r"/post/$", post),
(r"/csrftoken-form/$", csrftoken_form), (r"/csrftoken-form/$", csrftoken_form),
(r"/not-async/$", not_async), (r"/not-async/$", not_async),
(r"/add-message/$", add_message),
] ]

View file

@ -602,6 +602,14 @@ def test_register_routes_asgi(app_client):
assert "1" == response.headers["x-three"] assert "1" == response.headers["x-three"]
def test_register_routes_add_message(app_client):
response = app_client.get("/add-message/")
assert 200 == response.status
assert "Added message" == response.text
decoded = app_client.ds.unsign(response.cookies["ds_messages"], "messages")
assert [["Hello from messages", 1]] == decoded
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_startup(app_client): async def test_startup(app_client):
await app_client.ds.invoke_startup() await app_client.ds.invoke_startup()