diff --git a/datasette/app.py b/datasette/app.py index e131ba46..941794f9 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -739,6 +739,7 @@ class Datasette: "extra_css_urls": self._asset_urls("extra_css_urls", template, context), "extra_js_urls": self._asset_urls("extra_js_urls", template, context), "base_url": self.config("base_url"), + "csrftoken": request.scope["csrftoken"] if request else lambda: "", }, **extra_template_vars, } diff --git a/datasette/views/base.py b/datasette/views/base.py index f327c6cd..f14e6d3a 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -103,7 +103,6 @@ class BaseView(AsgiView): **context, **{ "database_url": self.database_url, - "csrftoken": request.scope["csrftoken"], "database_color": self.database_color, "show_messages": lambda: self.ds._show_messages(request), "select_templates": [ diff --git a/tests/plugins/my_plugin.py b/tests/plugins/my_plugin.py index 7ed26908..35d7eb54 100644 --- a/tests/plugins/my_plugin.py +++ b/tests/plugins/my_plugin.py @@ -182,11 +182,17 @@ def register_routes(): else: return Response.json(await request.post_vars()) + async def csrftoken_form(request, datasette): + return Response.html( + await datasette.render_template("csrftoken_form.html", request=request) + ) + return [ (r"/one/$", one), (r"/two/(?P.*)$", two), (r"/three/$", three), (r"/post/$", post), + (r"/csrftoken-form/$", csrftoken_form), ] diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 4f44430e..2c114850 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -580,6 +580,18 @@ def test_register_routes_post(app_client): assert "post data" == response.json["this is"] +def test_register_routes_csrftoken(tmpdir): + templates = tmpdir / "templates" + templates.mkdir() + (templates / "csrftoken_form.html").write_text( + "CSRFTOKEN: {{ csrftoken() }}", "utf-8" + ) + with make_app_client(template_dir=templates) as client: + response = client.get("/csrftoken-form/") + expected_token = client.ds._last_request.scope["csrftoken"]() + assert "CSRFTOKEN: {}".format(expected_token) == response.text + + def test_register_routes_asgi(app_client): response = app_client.get("/three/") assert {"hello": "world"} == response.json