Re-implemented redirect on 404 with trailing slash, refs #272

All of the tests now pass
This commit is contained in:
Simon Willison 2019-06-23 07:55:55 -07:00
commit 5bd510b01a
3 changed files with 23 additions and 13 deletions

View file

@ -36,7 +36,7 @@ from .utils import (
sqlite_timelimit, sqlite_timelimit,
to_css_class, to_css_class,
) )
from .utils.asgi import asgi_static, asgi_send_html, asgi_send_json from .utils.asgi import asgi_static, asgi_send_html, asgi_send_json, asgi_send_redirect
from .tracer import capture_traces, trace from .tracer import capture_traces, trace
from .plugins import pm, DEFAULT_PLUGINS from .plugins import pm, DEFAULT_PLUGINS
from .version import __version__ from .version import __version__
@ -652,6 +652,17 @@ class Datasette:
outer_self = self outer_self = self
class DatasetteRouter(AsgiRouter): class DatasetteRouter(AsgiRouter):
async def handle_404(self, scope, receive, send):
# If URL has a trailing slash, redirect to URL without it
path = scope.get("raw_path", scope["path"].encode("utf8"))
if path.endswith(b"/"):
path = path.rstrip(b"/")
if scope["query_string"]:
path += b"?" + scope["query_string"]
await asgi_send_redirect(send, path.decode("latin1"))
else:
await super().handle_404(scope, receive, send)
async def handle_500(self, scope, receive, send, exception): async def handle_500(self, scope, receive, send, exception):
title = None title = None
help = None help = None
@ -693,17 +704,6 @@ class Datasette:
) )
app = DatasetteRouter(routes) app = DatasetteRouter(routes)
# On 404 with a trailing slash redirect to path without that slash:
# pylint: disable=unused-variable
# TODO: re-enable this
# @app.middleware("response")
# def redirect_on_404_with_trailing_slash(request, original_response):
# if original_response.status == 404 and request.path.endswith("/"):
# path = request.path.rstrip("/")
# if request.query_string:
# path = "{}?{}".format(path, request.query_string)
# return response.redirect(path)
# First time server starts up, calculate table counts for immutable databases # First time server starts up, calculate table counts for immutable databases
# TODO: re-enable this mechanism # TODO: re-enable this mechanism
# @app.listener("before_server_start") # @app.listener("before_server_start")

View file

@ -1,6 +1,7 @@
import json import json
from mimetypes import guess_type from mimetypes import guess_type
from sanic.views import HTTPMethodView from sanic.views import HTTPMethodView
from sanic.request import Request as SanicRequest
from pathlib import Path from pathlib import Path
import re import re
import aiofiles import aiofiles
@ -166,6 +167,16 @@ async def asgi_send_html(send, html, status=200, headers=None):
) )
async def asgi_send_redirect(send, location, status=302):
await asgi_send(
send,
"",
status=status,
headers={"Location": location},
content_type="text/html",
)
async def asgi_send(send, content, status, headers, content_type="text/plain"): async def asgi_send(send, content, status, headers, content_type="text/plain"):
await asgi_start(send, status, headers, content_type) await asgi_start(send, status, headers, content_type)
await send({"type": "http.response.body", "body": content.encode("utf8")}) await send({"type": "http.response.body", "body": content.encode("utf8")})

View file

@ -9,7 +9,6 @@ import jinja2
import pint import pint
from sanic import response from sanic import response
from sanic.exceptions import NotFound from sanic.exceptions import NotFound
from sanic.request import Request as SanicRequest
from html import escape from html import escape