Implemented HEAD requests, removed Sanic InvalidUsage

This commit is contained in:
Simon Willison 2019-06-23 15:06:43 -07:00
commit 79950c9643
4 changed files with 15 additions and 9 deletions

View file

@ -12,7 +12,7 @@ from pathlib import Path
import click import click
from markupsafe import Markup from markupsafe import Markup
from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader
from sanic.exceptions import InvalidUsage, NotFound from sanic.exceptions import NotFound
from .views.base import DatasetteError, ureg, AsgiRouter from .views.base import DatasetteError, ureg, AsgiRouter
from .views.database import DatabaseDownload, DatabaseView from .views.database import DatabaseDownload, DatabaseView
@ -665,10 +665,6 @@ class Datasette:
status = 404 status = 404
info = {} info = {}
message = exception.args[0] message = exception.args[0]
elif isinstance(exception, InvalidUsage):
status = 405
info = {}
message = exception.args[0]
elif isinstance(exception, DatasetteError): elif isinstance(exception, DatasetteError):
status = exception.status status = exception.status
info = exception.error_dict info = exception.error_dict

View file

@ -54,6 +54,11 @@ class DatasetteError(Exception):
class BaseView(AsgiView): class BaseView(AsgiView):
ds = None ds = None
async def head(self, *args, **kwargs):
response = await self.get(*args, **kwargs)
response.body = b""
return response
def _asset_urls(self, key, template, context): def _asset_urls(self, key, template, context):
# Flatten list-of-lists from plugins: # Flatten list-of-lists from plugins:
seen_urls = set() seen_urls = set()

View file

@ -37,10 +37,10 @@ class TestClient:
self.asgi_app = asgi_app self.asgi_app = asgi_app
@async_to_sync @async_to_sync
async def get(self, path, allow_redirects=True, redirect_count=0): async def get(self, path, allow_redirects=True, redirect_count=0, method="GET"):
return await self._get(path, allow_redirects, redirect_count) return await self._get(path, allow_redirects, redirect_count, method)
async def _get(self, path, allow_redirects=True, redirect_count=0): async def _get(self, path, allow_redirects=True, redirect_count=0, method="GET"):
query_string = b"" query_string = b""
if "?" in path: if "?" in path:
path, _, query_string = path.partition("?") path, _, query_string = path.partition("?")
@ -50,7 +50,7 @@ class TestClient:
{ {
"type": "http", "type": "http",
"http_version": "1.0", "http_version": "1.0",
"method": "GET", "method": method,
"path": unquote(path), "path": unquote(path),
"raw_path": path.encode("ascii"), "raw_path": path.encode("ascii"),
"query_string": query_string, "query_string": query_string,

View file

@ -46,6 +46,11 @@ def test_homepage(app_client_two_attached_databases):
] == table_links ] == table_links
def test_http_head(app_client):
response = app_client.get("/", method="HEAD")
assert response.status == 200
def test_static(app_client): def test_static(app_client):
response = app_client.get("/-/static/app2.css") response = app_client.get("/-/static/app2.css")
assert response.status == 404 assert response.status == 404