From 7ac4936cec87f5a591e5d2680f0acefc3d35a705 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 28 Jun 2020 17:25:35 -0700 Subject: [PATCH] .add_message() now works inside plugins, closes #864 Refs #870 --- datasette/app.py | 17 +++++++++++++++-- datasette/views/base.py | 17 ++--------------- tests/plugins/my_plugin.py | 7 +++++++ tests/test_plugins.py | 8 ++++++++ 4 files changed, 32 insertions(+), 17 deletions(-) diff --git a/datasette/app.py b/datasette/app.py index d4276af1..af3dcc8b 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -5,6 +5,7 @@ import datetime import hashlib import inspect import itertools +from itsdangerous import BadSignature import json import os import re @@ -926,6 +927,14 @@ class DatasetteRouter: if base_url != "/" and path.startswith(base_url): path = "/" + path[len(base_url) :] request = Request(scope, receive) + # Populate request_messages if ds_messages cookie is present + try: + request._messages = self.ds.unsign( + request.cookies.get("ds_messages", ""), "messages" + ) + except BadSignature: + pass + scope_modifications = {} # Apply force_https_urls, if set if ( @@ -952,7 +961,11 @@ class DatasetteRouter: new_scope = dict(scope, url_route={"kwargs": match.groupdict()}) request.scope = new_scope try: - return await view(request, send) + response = await view(request, send) + if response: + self.ds._write_messages_to_response(request, response) + await response.asgi_send(send) + return except NotFound as exception: return await self.handle_404(scope, receive, send, exception) except Exception as exception: @@ -1099,6 +1112,6 @@ def wrap_view(view_fn, datasette): datasette=datasette, ) if response is not None: - await response.asgi_send(send) + return response return async_view_fn diff --git a/datasette/views/base.py b/datasette/views/base.py index 280ae49d..208c3c96 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -1,7 +1,6 @@ import asyncio import csv import itertools -from itsdangerous import BadSignature import json import re import time @@ -82,19 +81,8 @@ class BaseView: return "ff0000" async def dispatch_request(self, request, *args, **kwargs): - # Populate request_messages if ds_messages cookie is present - if self.ds: - try: - request._messages = self.ds.unsign( - request.cookies.get("ds_messages", ""), "messages" - ) - except BadSignature: - pass handler = getattr(self, request.method.lower(), None) - response = await handler(request, *args, **kwargs) - if self.ds: - self.ds._write_messages_to_response(request, response) - return response + return await handler(request, *args, **kwargs) async def render(self, templates, request, context=None): context = context or {} @@ -123,10 +111,9 @@ class BaseView: def as_view(cls, *class_args, **class_kwargs): async def view(request, send): self = view.view_class(*class_args, **class_kwargs) - response = await self.dispatch_request( + return await self.dispatch_request( request, **request.scope["url_route"]["kwargs"] ) - await response.asgi_send(send) view.view_class = cls view.__doc__ = cls.__doc__ diff --git a/tests/plugins/my_plugin.py b/tests/plugins/my_plugin.py index e4e4153c..bf6340ce 100644 --- a/tests/plugins/my_plugin.py +++ b/tests/plugins/my_plugin.py @@ -190,6 +190,12 @@ def register_routes(): def not_async(): return Response.html("This was not async") + def add_message(datasette, request): + datasette.add_message(request, "Hello from messages") + print("Adding message") + print(request._messages) + return Response.html("Added message") + return [ (r"/one/$", one), (r"/two/(?P.*)$", two), @@ -197,6 +203,7 @@ def register_routes(): (r"/post/$", post), (r"/csrftoken-form/$", csrftoken_form), (r"/not-async/$", not_async), + (r"/add-message/$", add_message), ] diff --git a/tests/test_plugins.py b/tests/test_plugins.py index b798e52d..9468fde9 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -602,6 +602,14 @@ def test_register_routes_asgi(app_client): assert "1" == response.headers["x-three"] +def test_register_routes_add_message(app_client): + response = app_client.get("/add-message/") + assert 200 == response.status + assert "Added message" == response.text + decoded = app_client.ds.unsign(response.cookies["ds_messages"], "messages") + assert [["Hello from messages", 1]] == decoded + + @pytest.mark.asyncio async def test_startup(app_client): await app_client.ds.invoke_startup()