mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
Compare commits
34 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b794554a26 | ||
|
|
eba15fb5a8 | ||
|
|
5e12239402 | ||
|
|
d0fc117693 | ||
|
|
176dd4f12a | ||
|
|
3c4d4f3535 | ||
|
|
979ae4f916 | ||
|
|
1e0998ed2d | ||
|
|
79950c9643 | ||
|
|
620f0aa4f8 | ||
|
|
28c31b228d | ||
|
|
b1c6db4b8f | ||
|
|
1e8419bde4 | ||
|
|
1208bcbfe8 | ||
|
|
4b6b409d85 | ||
|
|
d60fbfcae2 | ||
|
|
cbd0c014ec | ||
|
|
3bd5e14bc1 | ||
|
|
b97cd53a48 | ||
|
|
5bd510b01a | ||
|
|
d2daa1b9f7 | ||
|
|
2b5a644dd7 | ||
|
|
b7a00dbde3 | ||
|
|
ff9efa668e | ||
|
|
eb06e59332 | ||
|
|
8a1a15d725 | ||
|
|
ca03940f6d | ||
|
|
d8dcc34e36 | ||
|
|
55fc993667 | ||
|
|
b53a75c460 | ||
|
|
180d5be811 | ||
|
|
39d66f17c1 | ||
|
|
d736411699 | ||
|
|
7cdc55c683 |
19 changed files with 1510 additions and 947 deletions
249
datasette/app.py
249
datasette/app.py
|
|
@ -1,11 +1,9 @@
|
|||
import asyncio
|
||||
import collections
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import urllib.parse
|
||||
from concurrent import futures
|
||||
|
|
@ -14,10 +12,8 @@ from pathlib import Path
|
|||
import click
|
||||
from markupsafe import Markup
|
||||
from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader
|
||||
from sanic import Sanic, response
|
||||
from sanic.exceptions import InvalidUsage, NotFound
|
||||
|
||||
from .views.base import DatasetteError, ureg
|
||||
from .views.base import DatasetteError, ureg, AsgiRouter
|
||||
from .views.database import DatabaseDownload, DatabaseView
|
||||
from .views.index import IndexView
|
||||
from .views.special import JsonDataView
|
||||
|
|
@ -36,7 +32,16 @@ from .utils import (
|
|||
sqlite_timelimit,
|
||||
to_css_class,
|
||||
)
|
||||
from .tracer import capture_traces, trace
|
||||
from .utils.asgi import (
|
||||
AsgiLifespan,
|
||||
NotFound,
|
||||
asgi_static,
|
||||
asgi_send,
|
||||
asgi_send_html,
|
||||
asgi_send_json,
|
||||
asgi_send_redirect,
|
||||
)
|
||||
from .tracer import trace, AsgiTracer
|
||||
from .plugins import pm, DEFAULT_PLUGINS
|
||||
from .version import __version__
|
||||
|
||||
|
|
@ -126,8 +131,8 @@ CONFIG_OPTIONS = (
|
|||
DEFAULT_CONFIG = {option.name: option.default for option in CONFIG_OPTIONS}
|
||||
|
||||
|
||||
async def favicon(request):
|
||||
return response.text("")
|
||||
async def favicon(scope, receive, send):
|
||||
await asgi_send(send, "", 200)
|
||||
|
||||
|
||||
class Datasette:
|
||||
|
|
@ -413,6 +418,7 @@ class Datasette:
|
|||
"full": sys.version,
|
||||
},
|
||||
"datasette": datasette_version,
|
||||
"asgi": "3.0",
|
||||
"sqlite": {
|
||||
"version": sqlite_version,
|
||||
"fts_versions": fts_versions,
|
||||
|
|
@ -543,21 +549,7 @@ class Datasette:
|
|||
self.renderers[renderer["extension"]] = renderer["callback"]
|
||||
|
||||
def app(self):
|
||||
class TracingSanic(Sanic):
|
||||
async def handle_request(self, request, write_callback, stream_callback):
|
||||
if request.args.get("_trace"):
|
||||
request["traces"] = []
|
||||
request["trace_start"] = time.time()
|
||||
with capture_traces(request["traces"]):
|
||||
await super().handle_request(
|
||||
request, write_callback, stream_callback
|
||||
)
|
||||
else:
|
||||
await super().handle_request(
|
||||
request, write_callback, stream_callback
|
||||
)
|
||||
|
||||
app = TracingSanic(__name__)
|
||||
"Returns an ASGI app function that serves the whole of Datasette"
|
||||
default_templates = str(app_root / "datasette" / "templates")
|
||||
template_paths = []
|
||||
if self.template_dir:
|
||||
|
|
@ -588,134 +580,127 @@ class Datasette:
|
|||
pm.hook.prepare_jinja2_environment(env=self.jinja_env)
|
||||
|
||||
self.register_renderers()
|
||||
|
||||
routes = []
|
||||
|
||||
def add_route(view, regex):
|
||||
routes.append((regex, view))
|
||||
|
||||
# Generate a regex snippet to match all registered renderer file extensions
|
||||
renderer_regex = "|".join(r"\." + key for key in self.renderers.keys())
|
||||
|
||||
app.add_route(IndexView.as_view(self), r"/<as_format:(\.jsono?)?$>")
|
||||
add_route(IndexView.as_asgi(self), r"/(?P<as_format>(\.jsono?)?$)")
|
||||
# TODO: /favicon.ico and /-/static/ deserve far-future cache expires
|
||||
app.add_route(favicon, "/favicon.ico")
|
||||
app.static("/-/static/", str(app_root / "datasette" / "static"))
|
||||
add_route(favicon, "/favicon.ico")
|
||||
|
||||
add_route(
|
||||
asgi_static(app_root / "datasette" / "static"), r"/-/static/(?P<path>.*)$"
|
||||
)
|
||||
for path, dirname in self.static_mounts:
|
||||
app.static(path, dirname)
|
||||
add_route(asgi_static(dirname), r"/" + path + "/(?P<path>.*)$")
|
||||
|
||||
# Mount any plugin static/ directories
|
||||
for plugin in get_plugins(pm):
|
||||
if plugin["static_path"]:
|
||||
modpath = "/-/static-plugins/{}/".format(plugin["name"])
|
||||
app.static(modpath, plugin["static_path"])
|
||||
app.add_route(
|
||||
JsonDataView.as_view(self, "metadata.json", lambda: self._metadata),
|
||||
r"/-/metadata<as_format:(\.json)?$>",
|
||||
modpath = "/-/static-plugins/{}/(?P<path>.*)$".format(plugin["name"])
|
||||
add_route(asgi_static(plugin["static_path"]), modpath)
|
||||
add_route(
|
||||
JsonDataView.as_asgi(self, "metadata.json", lambda: self._metadata),
|
||||
r"/-/metadata(?P<as_format>(\.json)?)$",
|
||||
)
|
||||
app.add_route(
|
||||
JsonDataView.as_view(self, "versions.json", self.versions),
|
||||
r"/-/versions<as_format:(\.json)?$>",
|
||||
add_route(
|
||||
JsonDataView.as_asgi(self, "versions.json", self.versions),
|
||||
r"/-/versions(?P<as_format>(\.json)?)$",
|
||||
)
|
||||
app.add_route(
|
||||
JsonDataView.as_view(self, "plugins.json", self.plugins),
|
||||
r"/-/plugins<as_format:(\.json)?$>",
|
||||
add_route(
|
||||
JsonDataView.as_asgi(self, "plugins.json", self.plugins),
|
||||
r"/-/plugins(?P<as_format>(\.json)?)$",
|
||||
)
|
||||
app.add_route(
|
||||
JsonDataView.as_view(self, "config.json", lambda: self._config),
|
||||
r"/-/config<as_format:(\.json)?$>",
|
||||
add_route(
|
||||
JsonDataView.as_asgi(self, "config.json", lambda: self._config),
|
||||
r"/-/config(?P<as_format>(\.json)?)$",
|
||||
)
|
||||
app.add_route(
|
||||
JsonDataView.as_view(self, "databases.json", self.connected_databases),
|
||||
r"/-/databases<as_format:(\.json)?$>",
|
||||
add_route(
|
||||
JsonDataView.as_asgi(self, "databases.json", self.connected_databases),
|
||||
r"/-/databases(?P<as_format>(\.json)?)$",
|
||||
)
|
||||
app.add_route(
|
||||
DatabaseDownload.as_view(self), r"/<db_name:[^/]+?><as_db:(\.db)$>"
|
||||
add_route(
|
||||
DatabaseDownload.as_asgi(self), r"/(?P<db_name>[^/]+?)(?P<as_db>\.db)$"
|
||||
)
|
||||
app.add_route(
|
||||
DatabaseView.as_view(self),
|
||||
r"/<db_name:[^/]+?><as_format:(" + renderer_regex + r"|.jsono|\.csv)?$>",
|
||||
)
|
||||
app.add_route(
|
||||
TableView.as_view(self), r"/<db_name:[^/]+>/<table_and_format:[^/]+?$>"
|
||||
)
|
||||
app.add_route(
|
||||
RowView.as_view(self),
|
||||
r"/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_format:("
|
||||
add_route(
|
||||
DatabaseView.as_asgi(self),
|
||||
r"/(?P<db_name>[^/]+?)(?P<as_format>"
|
||||
+ renderer_regex
|
||||
+ r")?$>",
|
||||
+ r"|.jsono|\.csv)?$",
|
||||
)
|
||||
add_route(
|
||||
TableView.as_asgi(self),
|
||||
r"/(?P<db_name>[^/]+)/(?P<table_and_format>[^/]+?$)",
|
||||
)
|
||||
add_route(
|
||||
RowView.as_asgi(self),
|
||||
r"/(?P<db_name>[^/]+)/(?P<table>[^/]+?)/(?P<pk_path>[^/]+?)(?P<as_format>"
|
||||
+ renderer_regex
|
||||
+ r")?$",
|
||||
)
|
||||
self.register_custom_units()
|
||||
|
||||
# On 404 with a trailing slash redirect to path without that slash:
|
||||
# pylint: disable=unused-variable
|
||||
@app.middleware("response")
|
||||
def redirect_on_404_with_trailing_slash(request, original_response):
|
||||
if original_response.status == 404 and request.path.endswith("/"):
|
||||
path = request.path.rstrip("/")
|
||||
if request.query_string:
|
||||
path = "{}?{}".format(path, request.query_string)
|
||||
return response.redirect(path)
|
||||
|
||||
@app.middleware("response")
|
||||
async def add_traces_to_response(request, response):
|
||||
if request.get("traces") is None:
|
||||
return
|
||||
traces = request["traces"]
|
||||
trace_info = {
|
||||
"request_duration_ms": 1000 * (time.time() - request["trace_start"]),
|
||||
"sum_trace_duration_ms": sum(t["duration_ms"] for t in traces),
|
||||
"num_traces": len(traces),
|
||||
"traces": traces,
|
||||
}
|
||||
if "text/html" in response.content_type and b"</body>" in response.body:
|
||||
extra = json.dumps(trace_info, indent=2)
|
||||
extra_html = "<pre>{}</pre></body>".format(extra).encode("utf8")
|
||||
response.body = response.body.replace(b"</body>", extra_html)
|
||||
elif "json" in response.content_type and response.body.startswith(b"{"):
|
||||
data = json.loads(response.body.decode("utf8"))
|
||||
if "_trace" not in data:
|
||||
data["_trace"] = trace_info
|
||||
response.body = json.dumps(data).encode("utf8")
|
||||
|
||||
@app.exception(Exception)
|
||||
def on_exception(request, exception):
|
||||
title = None
|
||||
help = None
|
||||
if isinstance(exception, NotFound):
|
||||
status = 404
|
||||
info = {}
|
||||
message = exception.args[0]
|
||||
elif isinstance(exception, InvalidUsage):
|
||||
status = 405
|
||||
info = {}
|
||||
message = exception.args[0]
|
||||
elif isinstance(exception, DatasetteError):
|
||||
status = exception.status
|
||||
info = exception.error_dict
|
||||
message = exception.message
|
||||
if exception.messagge_is_html:
|
||||
message = Markup(message)
|
||||
title = exception.title
|
||||
else:
|
||||
status = 500
|
||||
info = {}
|
||||
message = str(exception)
|
||||
traceback.print_exc()
|
||||
templates = ["500.html"]
|
||||
if status != 500:
|
||||
templates = ["{}.html".format(status)] + templates
|
||||
info.update(
|
||||
{"ok": False, "error": message, "status": status, "title": title}
|
||||
)
|
||||
if request is not None and request.path.split("?")[0].endswith(".json"):
|
||||
r = response.json(info, status=status)
|
||||
|
||||
else:
|
||||
template = self.jinja_env.select_template(templates)
|
||||
r = response.html(template.render(info), status=status)
|
||||
if self.cors:
|
||||
r.headers["Access-Control-Allow-Origin"] = "*"
|
||||
return r
|
||||
|
||||
# First time server starts up, calculate table counts for immutable databases
|
||||
@app.listener("before_server_start")
|
||||
async def setup_db(app, loop):
|
||||
async def setup_db():
|
||||
# First time server starts up, calculate table counts for immutable databases
|
||||
for dbname, database in self.databases.items():
|
||||
if not database.is_mutable:
|
||||
await database.table_counts(limit=60 * 60 * 1000)
|
||||
|
||||
return app
|
||||
return AsgiLifespan(
|
||||
AsgiTracer(DatasetteRouter(self, routes)), on_startup=setup_db
|
||||
)
|
||||
|
||||
|
||||
class DatasetteRouter(AsgiRouter):
|
||||
def __init__(self, datasette, routes):
|
||||
self.ds = datasette
|
||||
super().__init__(routes)
|
||||
|
||||
async def handle_404(self, scope, receive, send):
|
||||
# If URL has a trailing slash, redirect to URL without it
|
||||
path = scope.get("raw_path", scope["path"].encode("utf8"))
|
||||
if path.endswith(b"/"):
|
||||
path = path.rstrip(b"/")
|
||||
if scope["query_string"]:
|
||||
path += b"?" + scope["query_string"]
|
||||
await asgi_send_redirect(send, path.decode("latin1"))
|
||||
else:
|
||||
await super().handle_404(scope, receive, send)
|
||||
|
||||
async def handle_500(self, scope, receive, send, exception):
|
||||
title = None
|
||||
if isinstance(exception, NotFound):
|
||||
status = 404
|
||||
info = {}
|
||||
message = exception.args[0]
|
||||
elif isinstance(exception, DatasetteError):
|
||||
status = exception.status
|
||||
info = exception.error_dict
|
||||
message = exception.message
|
||||
if exception.messagge_is_html:
|
||||
message = Markup(message)
|
||||
title = exception.title
|
||||
else:
|
||||
status = 500
|
||||
info = {}
|
||||
message = str(exception)
|
||||
traceback.print_exc()
|
||||
templates = ["500.html"]
|
||||
if status != 500:
|
||||
templates = ["{}.html".format(status)] + templates
|
||||
info.update({"ok": False, "error": message, "status": status, "title": title})
|
||||
headers = {}
|
||||
if self.ds.cors:
|
||||
headers["Access-Control-Allow-Origin"] = "*"
|
||||
if scope["path"].split("?")[0].endswith(".json"):
|
||||
await asgi_send_json(send, info, status=status, headers=headers)
|
||||
else:
|
||||
template = self.ds.jinja_env.select_template(templates)
|
||||
await asgi_send_html(
|
||||
send, template.render(info), status=status, headers=headers
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import uvicorn
|
||||
import click
|
||||
from click import formatting
|
||||
from click_default_group import DefaultGroup
|
||||
|
|
@ -354,4 +355,4 @@ def serve(
|
|||
asyncio.get_event_loop().run_until_complete(ds.run_sanity_checks())
|
||||
|
||||
# Start the server
|
||||
ds.app().run(host=host, port=port, debug=debug)
|
||||
uvicorn.run(ds.app(), host=host, port=port, log_level="info")
|
||||
|
|
|
|||
|
|
@ -88,5 +88,5 @@ def json_renderer(args, data, view_name):
|
|||
content_type = "text/plain"
|
||||
else:
|
||||
body = json.dumps(data, cls=CustomJSONEncoder)
|
||||
content_type = "application/json"
|
||||
content_type = "application/json; charset=utf-8"
|
||||
return {"body": body, "status_code": status_code, "content_type": content_type}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
from contextlib import contextmanager
|
||||
import time
|
||||
import json
|
||||
import traceback
|
||||
|
||||
tracers = {}
|
||||
|
|
@ -32,15 +33,15 @@ def trace(type, **kwargs):
|
|||
start = time.time()
|
||||
yield
|
||||
end = time.time()
|
||||
trace = {
|
||||
trace_info = {
|
||||
"type": type,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"duration_ms": (end - start) * 1000,
|
||||
"traceback": traceback.format_list(traceback.extract_stack(limit=6)[:-3]),
|
||||
}
|
||||
trace.update(kwargs)
|
||||
tracer.append(trace)
|
||||
trace_info.update(kwargs)
|
||||
tracer.append(trace_info)
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
|
@ -53,3 +54,77 @@ def capture_traces(tracer):
|
|||
tracers[task_id] = tracer
|
||||
yield
|
||||
del tracers[task_id]
|
||||
|
||||
|
||||
class AsgiTracer:
|
||||
# If the body is larger than this we don't attempt to append the trace
|
||||
max_body_bytes = 1024 * 256 # 256 KB
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if b"_trace=1" not in scope.get("query_string", b"").split(b"&"):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
trace_start = time.time()
|
||||
traces = []
|
||||
|
||||
accumulated_body = b""
|
||||
size_limit_exceeded = False
|
||||
response_headers = []
|
||||
|
||||
async def wrapped_send(message):
|
||||
nonlocal accumulated_body, size_limit_exceeded, response_headers
|
||||
if message["type"] == "http.response.start":
|
||||
response_headers = message["headers"]
|
||||
await send(message)
|
||||
return
|
||||
|
||||
if message["type"] != "http.response.body" or size_limit_exceeded:
|
||||
await send(message)
|
||||
return
|
||||
|
||||
# Accumulate body until the end or until size is exceeded
|
||||
accumulated_body += message["body"]
|
||||
if len(accumulated_body) > self.max_body_bytes:
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": accumulated_body,
|
||||
"more_body": True,
|
||||
}
|
||||
)
|
||||
size_limit_exceeded = True
|
||||
return
|
||||
|
||||
if not message.get("more_body"):
|
||||
# We have all the body - modify it and send the result
|
||||
# TODO: What to do about Content-Type or other cases?
|
||||
trace_info = {
|
||||
"request_duration_ms": 1000 * (time.time() - trace_start),
|
||||
"sum_trace_duration_ms": sum(t["duration_ms"] for t in traces),
|
||||
"num_traces": len(traces),
|
||||
"traces": traces,
|
||||
}
|
||||
try:
|
||||
content_type = [
|
||||
v.decode("utf8")
|
||||
for k, v in response_headers
|
||||
if k.lower() == b"content-type"
|
||||
][0]
|
||||
except IndexError:
|
||||
content_type = ""
|
||||
if "text/html" in content_type and b"</body>" in accumulated_body:
|
||||
extra = json.dumps(trace_info, indent=2)
|
||||
extra_html = "<pre>{}</pre></body>".format(extra).encode("utf8")
|
||||
accumulated_body = accumulated_body.replace(b"</body>", extra_html)
|
||||
elif "json" in content_type and accumulated_body.startswith(b"{"):
|
||||
data = json.loads(accumulated_body.decode("utf8"))
|
||||
if "_trace" not in data:
|
||||
data["_trace"] = trace_info
|
||||
accumulated_body = json.dumps(data).encode("utf8")
|
||||
await send({"type": "http.response.body", "body": accumulated_body})
|
||||
|
||||
with capture_traces(traces):
|
||||
await self.app(scope, receive, wrapped_send)
|
||||
|
|
|
|||
|
|
@ -697,13 +697,13 @@ class LimitedWriter:
|
|||
self.limit_bytes = limit_mb * 1024 * 1024
|
||||
self.bytes_count = 0
|
||||
|
||||
def write(self, bytes):
|
||||
async def write(self, bytes):
|
||||
self.bytes_count += len(bytes)
|
||||
if self.limit_bytes and (self.bytes_count > self.limit_bytes):
|
||||
raise WriteLimitExceeded(
|
||||
"CSV contains more than {} bytes".format(self.limit_bytes)
|
||||
)
|
||||
self.writer.write(bytes)
|
||||
await self.writer.write(bytes)
|
||||
|
||||
|
||||
_infinities = {float("inf"), float("-inf")}
|
||||
|
|
@ -741,3 +741,16 @@ def format_bytes(bytes):
|
|||
return "{} {}".format(int(current), unit)
|
||||
else:
|
||||
return "{:.1f} {}".format(current, unit)
|
||||
|
||||
|
||||
class RequestParameters(dict):
|
||||
def get(self, name, default=None):
|
||||
"Return first value in the list, if available"
|
||||
try:
|
||||
return super().get(name)[0]
|
||||
except (KeyError, TypeError):
|
||||
return default
|
||||
|
||||
def getlist(self, name, default=None):
|
||||
"Return full list"
|
||||
return super().get(name, default)
|
||||
377
datasette/utils/asgi.py
Normal file
377
datasette/utils/asgi.py
Normal file
|
|
@ -0,0 +1,377 @@
|
|||
import json
|
||||
from datasette.utils import RequestParameters
|
||||
from mimetypes import guess_type
|
||||
from urllib.parse import parse_qs, urlunparse
|
||||
from pathlib import Path
|
||||
from html import escape
|
||||
import re
|
||||
import aiofiles
|
||||
|
||||
|
||||
class NotFound(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Request:
|
||||
def __init__(self, scope):
|
||||
self.scope = scope
|
||||
|
||||
@property
|
||||
def method(self):
|
||||
return self.scope["method"]
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
return urlunparse(
|
||||
(self.scheme, self.host, self.path, None, self.query_string, None)
|
||||
)
|
||||
|
||||
@property
|
||||
def scheme(self):
|
||||
return self.scope.get("scheme") or "http"
|
||||
|
||||
@property
|
||||
def headers(self):
|
||||
return dict(
|
||||
[
|
||||
(k.decode("latin-1").lower(), v.decode("latin-1"))
|
||||
for k, v in self.scope.get("headers") or []
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def host(self):
|
||||
return self.headers.get("host") or "localhost"
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return (
|
||||
self.scope.get("raw_path", self.scope["path"].encode("latin-1"))
|
||||
).decode("latin-1")
|
||||
|
||||
@property
|
||||
def query_string(self):
|
||||
return (self.scope.get("query_string") or b"").decode("latin-1")
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
return RequestParameters(parse_qs(qs=self.query_string))
|
||||
|
||||
@property
|
||||
def raw_args(self):
|
||||
return {key: value[0] for key, value in self.args.items()}
|
||||
|
||||
@classmethod
|
||||
def fake(cls, path_with_query_string, method="GET", scheme="http"):
|
||||
"Useful for constructing Request objects for tests"
|
||||
path, _, query_string = path_with_query_string.partition("?")
|
||||
scope = {
|
||||
"http_version": "1.1",
|
||||
"method": method,
|
||||
"path": path,
|
||||
"raw_path": path.encode("latin-1"),
|
||||
"query_string": query_string.encode("latin-1"),
|
||||
"scheme": scheme,
|
||||
"type": "http",
|
||||
}
|
||||
return cls(scope)
|
||||
|
||||
|
||||
class AsgiRouter:
|
||||
def __init__(self, routes=None):
|
||||
routes = routes or []
|
||||
self.routes = [
|
||||
# Compile any strings to regular expressions
|
||||
((re.compile(pattern) if isinstance(pattern, str) else pattern), view)
|
||||
for pattern, view in routes
|
||||
]
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
# Because we care about "foo/bar" v.s. "foo%2Fbar" we decode raw_path ourselves
|
||||
path = scope["raw_path"].decode("ascii")
|
||||
for regex, view in self.routes:
|
||||
match = regex.match(path)
|
||||
if match is not None:
|
||||
new_scope = dict(scope, url_route={"kwargs": match.groupdict()})
|
||||
try:
|
||||
return await view(new_scope, receive, send)
|
||||
except Exception as exception:
|
||||
return await self.handle_500(scope, receive, send, exception)
|
||||
return await self.handle_404(scope, receive, send)
|
||||
|
||||
async def handle_404(self, scope, receive, send):
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 404,
|
||||
"headers": [[b"content-type", b"text/html"]],
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": b"<h1>404</h1>"})
|
||||
|
||||
async def handle_500(self, scope, receive, send, exception):
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 404,
|
||||
"headers": [[b"content-type", b"text/html"]],
|
||||
}
|
||||
)
|
||||
html = "<h1>500</h1><pre{}></pre>".format(escape(repr(exception)))
|
||||
await send({"type": "http.response.body", "body": html.encode("latin-1")})
|
||||
|
||||
|
||||
class AsgiLifespan:
|
||||
def __init__(self, app, on_startup=None, on_shutdown=None):
|
||||
self.app = app
|
||||
on_startup = on_startup or []
|
||||
on_shutdown = on_shutdown or []
|
||||
if not isinstance(on_startup or [], list):
|
||||
on_startup = [on_startup]
|
||||
if not isinstance(on_shutdown or [], list):
|
||||
on_shutdown = [on_shutdown]
|
||||
self.on_startup = on_startup
|
||||
self.on_shutdown = on_shutdown
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "lifespan":
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.startup":
|
||||
for fn in self.on_startup:
|
||||
await fn()
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
elif message["type"] == "lifespan.shutdown":
|
||||
for fn in self.on_shutdown:
|
||||
await fn()
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
return
|
||||
else:
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
|
||||
class AsgiView:
|
||||
def dispatch_request(self, request, *args, **kwargs):
|
||||
handler = getattr(self, request.method.lower(), None)
|
||||
return handler(request, *args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def as_asgi(cls, *class_args, **class_kwargs):
|
||||
async def view(scope, receive, send):
|
||||
# Uses scope to create a request object, then dispatches that to
|
||||
# self.get(...) or self.options(...) along with keyword arguments
|
||||
# that were already tucked into scope["url_route"]["kwargs"] by
|
||||
# the router, similar to how Django Channels works:
|
||||
# https://channels.readthedocs.io/en/latest/topics/routing.html#urlrouter
|
||||
request = Request(scope)
|
||||
self = view.view_class(*class_args, **class_kwargs)
|
||||
response = await self.dispatch_request(
|
||||
request, **scope["url_route"]["kwargs"]
|
||||
)
|
||||
await response.asgi_send(send)
|
||||
|
||||
view.view_class = cls
|
||||
view.__doc__ = cls.__doc__
|
||||
view.__module__ = cls.__module__
|
||||
view.__name__ = cls.__name__
|
||||
return view
|
||||
|
||||
|
||||
class AsgiStream:
|
||||
def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"):
|
||||
self.stream_fn = stream_fn
|
||||
self.status = status
|
||||
self.headers = headers or {}
|
||||
self.content_type = content_type
|
||||
|
||||
async def asgi_send(self, send):
|
||||
# Remove any existing content-type header
|
||||
headers = dict(
|
||||
[(k, v) for k, v in self.headers.items() if k.lower() != "content-type"]
|
||||
)
|
||||
headers["content-type"] = self.content_type
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": self.status,
|
||||
"headers": [
|
||||
[key.encode("utf-8"), value.encode("utf-8")]
|
||||
for key, value in headers.items()
|
||||
],
|
||||
}
|
||||
)
|
||||
w = AsgiWriter(send)
|
||||
await self.stream_fn(w)
|
||||
await send({"type": "http.response.body", "body": b""})
|
||||
|
||||
|
||||
class AsgiWriter:
|
||||
def __init__(self, send):
|
||||
self.send = send
|
||||
|
||||
async def write(self, chunk):
|
||||
await self.send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": chunk.encode("latin-1"),
|
||||
"more_body": True,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def asgi_send_json(send, info, status=200, headers=None):
|
||||
headers = headers or {}
|
||||
await asgi_send(
|
||||
send,
|
||||
json.dumps(info),
|
||||
status=status,
|
||||
headers=headers,
|
||||
content_type="application/json; charset=utf-8",
|
||||
)
|
||||
|
||||
|
||||
async def asgi_send_html(send, html, status=200, headers=None):
|
||||
headers = headers or {}
|
||||
await asgi_send(
|
||||
send, html, status=status, headers=headers, content_type="text/html"
|
||||
)
|
||||
|
||||
|
||||
async def asgi_send_redirect(send, location, status=302):
|
||||
await asgi_send(
|
||||
send,
|
||||
"",
|
||||
status=status,
|
||||
headers={"Location": location},
|
||||
content_type="text/html",
|
||||
)
|
||||
|
||||
|
||||
async def asgi_send(send, content, status, headers=None, content_type="text/plain"):
|
||||
await asgi_start(send, status, headers, content_type)
|
||||
await send({"type": "http.response.body", "body": content.encode("latin-1")})
|
||||
|
||||
|
||||
async def asgi_start(send, status, headers=None, content_type="text/plain"):
|
||||
headers = headers or {}
|
||||
# Remove any existing content-type header
|
||||
headers = dict([(k, v) for k, v in headers.items() if k.lower() != "content-type"])
|
||||
headers["content-type"] = content_type
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": status,
|
||||
"headers": [
|
||||
[key.encode("latin1"), value.encode("latin1")]
|
||||
for key, value in headers.items()
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def asgi_send_file(
|
||||
send, filepath, filename=None, content_type=None, chunk_size=4096
|
||||
):
|
||||
headers = {}
|
||||
if filename:
|
||||
headers["Content-Disposition"] = 'attachment; filename="{}"'.format(filename)
|
||||
first = True
|
||||
async with aiofiles.open(str(filepath), mode="rb") as fp:
|
||||
if first:
|
||||
await asgi_start(
|
||||
send,
|
||||
200,
|
||||
headers,
|
||||
content_type or guess_type(str(filepath))[0] or "text/plain",
|
||||
)
|
||||
first = False
|
||||
more_body = True
|
||||
while more_body:
|
||||
chunk = await fp.read(chunk_size)
|
||||
more_body = len(chunk) == chunk_size
|
||||
await send(
|
||||
{"type": "http.response.body", "body": chunk, "more_body": more_body}
|
||||
)
|
||||
|
||||
|
||||
def asgi_static(root_path, chunk_size=4096, headers=None, content_type=None):
|
||||
async def inner_static(scope, receive, send):
|
||||
path = scope["url_route"]["kwargs"]["path"]
|
||||
full_path = (Path(root_path) / path).absolute()
|
||||
# Ensure full_path is within root_path to avoid weird "../" tricks
|
||||
try:
|
||||
full_path.relative_to(root_path)
|
||||
except ValueError:
|
||||
await asgi_send_html(send, "404", 404)
|
||||
return
|
||||
first = True
|
||||
try:
|
||||
await asgi_send_file(send, full_path, chunk_size=chunk_size)
|
||||
except FileNotFoundError:
|
||||
await asgi_send_html(send, "404", 404)
|
||||
return
|
||||
|
||||
return inner_static
|
||||
|
||||
|
||||
class Response:
|
||||
def __init__(self, body=None, status=200, headers=None, content_type="text/plain"):
|
||||
self.body = body
|
||||
self.status = status
|
||||
self.headers = headers or {}
|
||||
self.content_type = content_type
|
||||
|
||||
async def asgi_send(self, send):
|
||||
headers = {}
|
||||
headers.update(self.headers)
|
||||
headers["content-type"] = self.content_type
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": self.status,
|
||||
"headers": [
|
||||
[key.encode("utf-8"), value.encode("utf-8")]
|
||||
for key, value in headers.items()
|
||||
],
|
||||
}
|
||||
)
|
||||
body = self.body
|
||||
if not isinstance(body, bytes):
|
||||
body = body.encode("utf-8")
|
||||
await send({"type": "http.response.body", "body": body})
|
||||
|
||||
@classmethod
|
||||
def html(cls, body, status=200, headers=None):
|
||||
return cls(
|
||||
body,
|
||||
status=status,
|
||||
headers=headers,
|
||||
content_type="text/html; charset=utf-8",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def text(cls, body, status=200, headers=None):
|
||||
return cls(
|
||||
body,
|
||||
status=status,
|
||||
headers=headers,
|
||||
content_type="text/plain; charset=utf-8",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def redirect(cls, path, status=302, headers=None):
|
||||
headers = headers or {}
|
||||
headers["Location"] = path
|
||||
return cls("", status=status, headers=headers)
|
||||
|
||||
|
||||
class AsgiFileDownload:
|
||||
def __init__(
|
||||
self, filepath, filename=None, content_type="application/octet-stream"
|
||||
):
|
||||
self.filepath = filepath
|
||||
self.filename = filename
|
||||
self.content_type = content_type
|
||||
|
||||
async def asgi_send(self, send):
|
||||
return await asgi_send_file(send, self.filepath, content_type=self.content_type)
|
||||
|
|
@ -7,9 +7,8 @@ import urllib
|
|||
|
||||
import jinja2
|
||||
import pint
|
||||
from sanic import response
|
||||
from sanic.exceptions import NotFound
|
||||
from sanic.views import HTTPMethodView
|
||||
|
||||
from html import escape
|
||||
|
||||
from datasette import __version__
|
||||
from datasette.plugins import pm
|
||||
|
|
@ -26,6 +25,14 @@ from datasette.utils import (
|
|||
sqlite3,
|
||||
to_css_class,
|
||||
)
|
||||
from datasette.utils.asgi import (
|
||||
AsgiStream,
|
||||
AsgiWriter,
|
||||
AsgiRouter,
|
||||
AsgiView,
|
||||
NotFound,
|
||||
Response,
|
||||
)
|
||||
|
||||
ureg = pint.UnitRegistry()
|
||||
|
||||
|
|
@ -49,7 +56,14 @@ class DatasetteError(Exception):
|
|||
self.messagge_is_html = messagge_is_html
|
||||
|
||||
|
||||
class BaseView(HTTPMethodView):
|
||||
class BaseView(AsgiView):
|
||||
ds = None
|
||||
|
||||
async def head(self, *args, **kwargs):
|
||||
response = await self.get(*args, **kwargs)
|
||||
response.body = b""
|
||||
return response
|
||||
|
||||
def _asset_urls(self, key, template, context):
|
||||
# Flatten list-of-lists from plugins:
|
||||
seen_urls = set()
|
||||
|
|
@ -104,7 +118,7 @@ class BaseView(HTTPMethodView):
|
|||
datasette=self.ds,
|
||||
):
|
||||
body_scripts.append(jinja2.Markup(script))
|
||||
return response.html(
|
||||
return Response.html(
|
||||
template.render(
|
||||
{
|
||||
**context,
|
||||
|
|
@ -136,7 +150,7 @@ class DataView(BaseView):
|
|||
self.ds = datasette
|
||||
|
||||
def options(self, request, *args, **kwargs):
|
||||
r = response.text("ok")
|
||||
r = Response.text("ok")
|
||||
if self.ds.cors:
|
||||
r.headers["Access-Control-Allow-Origin"] = "*"
|
||||
return r
|
||||
|
|
@ -146,7 +160,7 @@ class DataView(BaseView):
|
|||
path = "{}?{}".format(path, request.query_string)
|
||||
if remove_args:
|
||||
path = path_with_removed_args(request, remove_args, path=path)
|
||||
r = response.redirect(path)
|
||||
r = Response.redirect(path)
|
||||
r.headers["Link"] = "<{}>; rel=preload".format(path)
|
||||
if self.ds.cors:
|
||||
r.headers["Access-Control-Allow-Origin"] = "*"
|
||||
|
|
@ -195,17 +209,17 @@ class DataView(BaseView):
|
|||
kwargs["table"] = table
|
||||
if _format:
|
||||
kwargs["as_format"] = ".{}".format(_format)
|
||||
elif "table" in kwargs:
|
||||
elif kwargs.get("table"):
|
||||
kwargs["table"] = urllib.parse.unquote_plus(kwargs["table"])
|
||||
|
||||
should_redirect = "/{}-{}".format(name, expected)
|
||||
if "table" in kwargs:
|
||||
if kwargs.get("table"):
|
||||
should_redirect += "/" + urllib.parse.quote_plus(kwargs["table"])
|
||||
if "pk_path" in kwargs:
|
||||
if kwargs.get("pk_path"):
|
||||
should_redirect += "/" + kwargs["pk_path"]
|
||||
if "as_format" in kwargs:
|
||||
if kwargs.get("as_format"):
|
||||
should_redirect += kwargs["as_format"]
|
||||
if "as_db" in kwargs:
|
||||
if kwargs.get("as_db"):
|
||||
should_redirect += kwargs["as_db"]
|
||||
|
||||
if (
|
||||
|
|
@ -246,7 +260,7 @@ class DataView(BaseView):
|
|||
response_or_template_contexts = await self.data(
|
||||
request, database, hash, **kwargs
|
||||
)
|
||||
if isinstance(response_or_template_contexts, response.HTTPResponse):
|
||||
if isinstance(response_or_template_contexts, Response):
|
||||
return response_or_template_contexts
|
||||
else:
|
||||
data, _, _ = response_or_template_contexts
|
||||
|
|
@ -282,13 +296,13 @@ class DataView(BaseView):
|
|||
if not first:
|
||||
data, _, _ = await self.data(request, database, hash, **kwargs)
|
||||
if first:
|
||||
writer.writerow(headings)
|
||||
await writer.writerow(headings)
|
||||
first = False
|
||||
next = data.get("next")
|
||||
for row in data["rows"]:
|
||||
if not expanded_columns:
|
||||
# Simple path
|
||||
writer.writerow(row)
|
||||
await writer.writerow(row)
|
||||
else:
|
||||
# Look for {"value": "label": } dicts and expand
|
||||
new_row = []
|
||||
|
|
@ -298,10 +312,10 @@ class DataView(BaseView):
|
|||
new_row.append(cell["label"])
|
||||
else:
|
||||
new_row.append(cell)
|
||||
writer.writerow(new_row)
|
||||
await writer.writerow(new_row)
|
||||
except Exception as e:
|
||||
print("caught this", e)
|
||||
r.write(str(e))
|
||||
await r.write(str(e))
|
||||
return
|
||||
|
||||
content_type = "text/plain; charset=utf-8"
|
||||
|
|
@ -315,7 +329,7 @@ class DataView(BaseView):
|
|||
)
|
||||
headers["Content-Disposition"] = disposition
|
||||
|
||||
return response.stream(stream_fn, headers=headers, content_type=content_type)
|
||||
return AsgiStream(stream_fn, headers=headers, content_type=content_type)
|
||||
|
||||
async def get_format(self, request, database, args):
|
||||
""" Determine the format of the response from the request, from URL
|
||||
|
|
@ -363,7 +377,7 @@ class DataView(BaseView):
|
|||
response_or_template_contexts = await self.data(
|
||||
request, database, hash, **kwargs
|
||||
)
|
||||
if isinstance(response_or_template_contexts, response.HTTPResponse):
|
||||
if isinstance(response_or_template_contexts, Response):
|
||||
return response_or_template_contexts
|
||||
|
||||
else:
|
||||
|
|
@ -414,17 +428,11 @@ class DataView(BaseView):
|
|||
if result is None:
|
||||
raise NotFound("No data")
|
||||
|
||||
response_args = {
|
||||
"content_type": result.get("content_type", "text/plain"),
|
||||
"status": result.get("status_code", 200),
|
||||
}
|
||||
|
||||
if type(result.get("body")) == bytes:
|
||||
response_args["body_bytes"] = result.get("body")
|
||||
else:
|
||||
response_args["body"] = result.get("body")
|
||||
|
||||
r = response.HTTPResponse(**response_args)
|
||||
r = Response(
|
||||
body=result.get("body"),
|
||||
status=result.get("status_code", 200),
|
||||
content_type=result.get("content_type", "text/plain"),
|
||||
)
|
||||
else:
|
||||
extras = {}
|
||||
if callable(extra_template_data):
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
import os
|
||||
|
||||
from sanic import response
|
||||
|
||||
from datasette.utils import to_css_class, validate_sql_select
|
||||
from datasette.utils.asgi import AsgiFileDownload
|
||||
|
||||
from .base import DataView, DatasetteError
|
||||
from .base import DatasetteError, DataView
|
||||
|
||||
|
||||
class DatabaseView(DataView):
|
||||
|
|
@ -79,8 +78,8 @@ class DatabaseDownload(DataView):
|
|||
if not db.path:
|
||||
raise DatasetteError("Cannot download database", status=404)
|
||||
filepath = db.path
|
||||
return await response.file_stream(
|
||||
return AsgiFileDownload(
|
||||
filepath,
|
||||
filename=os.path.basename(filepath),
|
||||
mime_type="application/octet-stream",
|
||||
content_type="application/octet-stream",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
import hashlib
|
||||
import json
|
||||
|
||||
from sanic import response
|
||||
|
||||
from datasette.utils import CustomJSONEncoder
|
||||
from datasette.utils.asgi import Response
|
||||
from datasette.version import __version__
|
||||
|
||||
from .base import BaseView
|
||||
|
|
@ -104,9 +103,9 @@ class IndexView(BaseView):
|
|||
headers = {}
|
||||
if self.ds.cors:
|
||||
headers["Access-Control-Allow-Origin"] = "*"
|
||||
return response.HTTPResponse(
|
||||
return Response(
|
||||
json.dumps({db["name"]: db for db in databases}, cls=CustomJSONEncoder),
|
||||
content_type="application/json",
|
||||
content_type="application/json; charset=utf-8",
|
||||
headers=headers,
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
from sanic import response
|
||||
from datasette.utils.asgi import Response
|
||||
from .base import BaseView
|
||||
|
||||
|
||||
|
|
@ -17,8 +17,10 @@ class JsonDataView(BaseView):
|
|||
headers = {}
|
||||
if self.ds.cors:
|
||||
headers["Access-Control-Allow-Origin"] = "*"
|
||||
return response.HTTPResponse(
|
||||
json.dumps(data), content_type="application/json", headers=headers
|
||||
return Response(
|
||||
json.dumps(data),
|
||||
content_type="application/json; charset=utf-8",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -3,13 +3,12 @@ import itertools
|
|||
import json
|
||||
|
||||
import jinja2
|
||||
from sanic.exceptions import NotFound
|
||||
from sanic.request import RequestParameters
|
||||
|
||||
from datasette.plugins import pm
|
||||
from datasette.utils import (
|
||||
CustomRow,
|
||||
QueryInterrupted,
|
||||
RequestParameters,
|
||||
append_querystring,
|
||||
compound_keys_after_sql,
|
||||
escape_sqlite,
|
||||
|
|
@ -24,6 +23,7 @@ from datasette.utils import (
|
|||
urlsafe_components,
|
||||
value_as_boolean,
|
||||
)
|
||||
from datasette.utils.asgi import NotFound
|
||||
from datasette.filters import Filters
|
||||
from .base import DataView, DatasetteError, ureg
|
||||
|
||||
|
|
@ -219,8 +219,7 @@ class TableView(RowTableShared):
|
|||
if is_view:
|
||||
order_by = ""
|
||||
|
||||
# We roll our own query_string decoder because by default Sanic
|
||||
# drops anything with an empty value e.g. ?name__exact=
|
||||
# Ensure we don't drop anything with an empty value e.g. ?name__exact=
|
||||
args = RequestParameters(
|
||||
urllib.parse.parse_qs(request.query_string, keep_blank_values=True)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,5 @@ filterwarnings=
|
|||
ignore:Using or importing the ABCs::jinja2
|
||||
# https://bugs.launchpad.net/beautifulsoup/+bug/1778909
|
||||
ignore:Using or importing the ABCs::bs4.element
|
||||
# Sanic verify_ssl=True
|
||||
ignore:verify_ssl is deprecated::sanic
|
||||
# Python 3.7 PendingDeprecationWarning: Task.current_task()
|
||||
ignore:.*current_task.*:PendingDeprecationWarning
|
||||
|
|
|
|||
6
setup.py
6
setup.py
|
|
@ -37,17 +37,18 @@ setup(
|
|||
author="Simon Willison",
|
||||
license="Apache License, Version 2.0",
|
||||
url="https://github.com/simonw/datasette",
|
||||
packages=find_packages(exclude='tests'),
|
||||
packages=find_packages(exclude="tests"),
|
||||
package_data={"datasette": ["templates/*.html"]},
|
||||
include_package_data=True,
|
||||
install_requires=[
|
||||
"click>=6.7",
|
||||
"click-default-group==1.2",
|
||||
"Sanic==0.7.0",
|
||||
"Jinja2==2.10.1",
|
||||
"hupper==1.0",
|
||||
"pint==0.8.1",
|
||||
"pluggy>=0.12.0",
|
||||
"uvicorn>=0.8.1",
|
||||
"aiofiles==0.4.0",
|
||||
],
|
||||
entry_points="""
|
||||
[console_scripts]
|
||||
|
|
@ -60,6 +61,7 @@ setup(
|
|||
"pytest-asyncio==0.10.0",
|
||||
"aiohttp==3.5.3",
|
||||
"beautifulsoup4==4.6.1",
|
||||
"asgiref==3.1.2",
|
||||
]
|
||||
+ maybe_black
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from datasette.app import Datasette
|
||||
from datasette.utils import sqlite3
|
||||
from asgiref.testing import ApplicationCommunicator
|
||||
from asgiref.sync import async_to_sync
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
|
|
@ -10,16 +12,82 @@ import sys
|
|||
import string
|
||||
import tempfile
|
||||
import time
|
||||
from urllib.parse import unquote
|
||||
|
||||
|
||||
class TestResponse:
|
||||
def __init__(self, status, headers, body):
|
||||
self.status = status
|
||||
self.headers = headers
|
||||
self.body = body
|
||||
|
||||
@property
|
||||
def json(self):
|
||||
return json.loads(self.text)
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
return self.body.decode("utf8")
|
||||
|
||||
|
||||
class TestClient:
|
||||
def __init__(self, sanic_test_client):
|
||||
self.sanic_test_client = sanic_test_client
|
||||
max_redirects = 5
|
||||
|
||||
def get(self, path, allow_redirects=True):
|
||||
return self.sanic_test_client.get(
|
||||
path, allow_redirects=allow_redirects, gather_request=False
|
||||
def __init__(self, asgi_app):
|
||||
self.asgi_app = asgi_app
|
||||
|
||||
@async_to_sync
|
||||
async def get(self, path, allow_redirects=True, redirect_count=0, method="GET"):
|
||||
return await self._get(path, allow_redirects, redirect_count, method)
|
||||
|
||||
async def _get(self, path, allow_redirects=True, redirect_count=0, method="GET"):
|
||||
query_string = b""
|
||||
if "?" in path:
|
||||
path, _, query_string = path.partition("?")
|
||||
query_string = query_string.encode("utf8")
|
||||
instance = ApplicationCommunicator(
|
||||
self.asgi_app,
|
||||
{
|
||||
"type": "http",
|
||||
"http_version": "1.0",
|
||||
"method": method,
|
||||
"path": unquote(path),
|
||||
"raw_path": path.encode("ascii"),
|
||||
"query_string": query_string,
|
||||
"headers": [[b"host", b"localhost"]],
|
||||
},
|
||||
)
|
||||
await instance.send_input({"type": "http.request"})
|
||||
# First message back should be response.start with headers and status
|
||||
messages = []
|
||||
start = await instance.receive_output(2)
|
||||
messages.append(start)
|
||||
assert start["type"] == "http.response.start"
|
||||
headers = dict(
|
||||
[(k.decode("utf8"), v.decode("utf8")) for k, v in start["headers"]]
|
||||
)
|
||||
status = start["status"]
|
||||
# Now loop until we run out of response.body
|
||||
body = b""
|
||||
while True:
|
||||
message = await instance.receive_output(2)
|
||||
messages.append(message)
|
||||
assert message["type"] == "http.response.body"
|
||||
body += message["body"]
|
||||
if not message.get("more_body"):
|
||||
break
|
||||
response = TestResponse(status, headers, body)
|
||||
if allow_redirects and response.status in (301, 302):
|
||||
assert (
|
||||
redirect_count < self.max_redirects
|
||||
), "Redirected {} times, max_redirects={}".format(
|
||||
redirect_count, self.max_redirects
|
||||
)
|
||||
location = response.headers["Location"]
|
||||
return await self._get(
|
||||
location, allow_redirects=True, redirect_count=redirect_count + 1
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
def make_app_client(
|
||||
|
|
@ -32,6 +100,7 @@ def make_app_client(
|
|||
is_immutable=False,
|
||||
extra_databases=None,
|
||||
inspect_data=None,
|
||||
static_mounts=None,
|
||||
):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
filepath = os.path.join(tmpdir, filename)
|
||||
|
|
@ -73,9 +142,10 @@ def make_app_client(
|
|||
plugins_dir=plugins_dir,
|
||||
config=config,
|
||||
inspect_data=inspect_data,
|
||||
static_mounts=static_mounts,
|
||||
)
|
||||
ds.sqlite_functions.append(("sleep", 1, lambda n: time.sleep(float(n))))
|
||||
client = TestClient(ds.app().test_client)
|
||||
client = TestClient(ds.app())
|
||||
client.ds = ds
|
||||
yield client
|
||||
|
||||
|
|
@ -88,7 +158,7 @@ def app_client():
|
|||
@pytest.fixture(scope="session")
|
||||
def app_client_no_files():
|
||||
ds = Datasette([])
|
||||
client = TestClient(ds.app().test_client)
|
||||
client = TestClient(ds.app())
|
||||
client.ds = ds
|
||||
yield client
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import urllib
|
|||
def test_homepage(app_client):
|
||||
response = app_client.get("/.json")
|
||||
assert response.status == 200
|
||||
assert "application/json; charset=utf-8" == response.headers["content-type"]
|
||||
assert response.json.keys() == {"fixtures": 0}.keys()
|
||||
d = response.json["fixtures"]
|
||||
assert d["name"] == "fixtures"
|
||||
|
|
@ -771,8 +772,8 @@ def test_paginate_tables_and_views(app_client, path, expected_rows, expected_pag
|
|||
fetched.extend(response.json["rows"])
|
||||
path = response.json["next_url"]
|
||||
if path:
|
||||
assert response.json["next"]
|
||||
assert urllib.parse.urlencode({"_next": response.json["next"]}) in path
|
||||
path = path.replace("http://localhost", "")
|
||||
assert count < 30, "Possible infinite loop detected"
|
||||
|
||||
assert expected_rows == len(fetched)
|
||||
|
|
@ -812,6 +813,8 @@ def test_paginate_compound_keys(app_client):
|
|||
response = app_client.get(path)
|
||||
fetched.extend(response.json["rows"])
|
||||
path = response.json["next_url"]
|
||||
if path:
|
||||
path = path.replace("http://localhost", "")
|
||||
assert page < 100
|
||||
assert 1001 == len(fetched)
|
||||
assert 21 == page
|
||||
|
|
@ -833,6 +836,8 @@ def test_paginate_compound_keys_with_extra_filters(app_client):
|
|||
response = app_client.get(path)
|
||||
fetched.extend(response.json["rows"])
|
||||
path = response.json["next_url"]
|
||||
if path:
|
||||
path = path.replace("http://localhost", "")
|
||||
assert 2 == page
|
||||
expected = [r[3] for r in generate_compound_rows(1001) if "d" in r[3]]
|
||||
assert expected == [f["content"] for f in fetched]
|
||||
|
|
@ -881,6 +886,8 @@ def test_sortable(app_client, query_string, sort_key, human_description_en):
|
|||
assert human_description_en == response.json["human_description_en"]
|
||||
fetched.extend(response.json["rows"])
|
||||
path = response.json["next_url"]
|
||||
if path:
|
||||
path = path.replace("http://localhost", "")
|
||||
assert 5 == page
|
||||
expected = list(generate_sortable_rows(201))
|
||||
expected.sort(key=sort_key)
|
||||
|
|
@ -1191,6 +1198,7 @@ def test_plugins_json(app_client):
|
|||
def test_versions_json(app_client):
|
||||
response = app_client.get("/-/versions.json")
|
||||
assert "python" in response.json
|
||||
assert "3.0" == response.json.get("asgi")
|
||||
assert "version" in response.json["python"]
|
||||
assert "full" in response.json["python"]
|
||||
assert "datasette" in response.json
|
||||
|
|
@ -1236,6 +1244,8 @@ def test_page_size_matching_max_returned_rows(
|
|||
fetched.extend(response.json["rows"])
|
||||
assert len(response.json["rows"]) in (1, 50)
|
||||
path = response.json["next_url"]
|
||||
if path:
|
||||
path = path.replace("http://localhost", "")
|
||||
assert 201 == len(fetched)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ def test_table_csv(app_client):
|
|||
response = app_client.get("/fixtures/simple_primary_key.csv")
|
||||
assert response.status == 200
|
||||
assert not response.headers.get("Access-Control-Allow-Origin")
|
||||
assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
|
||||
assert "text/plain; charset=utf-8" == response.headers["content-type"]
|
||||
assert EXPECTED_TABLE_CSV == response.text
|
||||
|
||||
|
||||
|
|
@ -59,7 +59,7 @@ def test_table_csv_cors_headers(app_client_with_cors):
|
|||
def test_table_csv_with_labels(app_client):
|
||||
response = app_client.get("/fixtures/facetable.csv?_labels=1")
|
||||
assert response.status == 200
|
||||
assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
|
||||
assert "text/plain; charset=utf-8" == response.headers["content-type"]
|
||||
assert EXPECTED_TABLE_WITH_LABELS_CSV == response.text
|
||||
|
||||
|
||||
|
|
@ -68,14 +68,14 @@ def test_custom_sql_csv(app_client):
|
|||
"/fixtures.csv?sql=select+content+from+simple_primary_key+limit+2"
|
||||
)
|
||||
assert response.status == 200
|
||||
assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
|
||||
assert "text/plain; charset=utf-8" == response.headers["content-type"]
|
||||
assert EXPECTED_CUSTOM_CSV == response.text
|
||||
|
||||
|
||||
def test_table_csv_download(app_client):
|
||||
response = app_client.get("/fixtures/simple_primary_key.csv?_dl=1")
|
||||
assert response.status == 200
|
||||
assert "text/csv; charset=utf-8" == response.headers["Content-Type"]
|
||||
assert "text/csv; charset=utf-8" == response.headers["content-type"]
|
||||
expected_disposition = 'attachment; filename="simple_primary_key.csv"'
|
||||
assert expected_disposition == response.headers["Content-Disposition"]
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from .fixtures import ( # noqa
|
|||
METADATA,
|
||||
)
|
||||
import json
|
||||
import pathlib
|
||||
import pytest
|
||||
import re
|
||||
import urllib.parse
|
||||
|
|
@ -16,6 +17,7 @@ import urllib.parse
|
|||
def test_homepage(app_client_two_attached_databases):
|
||||
response = app_client_two_attached_databases.get("/")
|
||||
assert response.status == 200
|
||||
assert "text/html; charset=utf-8" == response.headers["content-type"]
|
||||
soup = Soup(response.body, "html.parser")
|
||||
assert "Datasette Fixtures" == soup.find("h1").text
|
||||
assert (
|
||||
|
|
@ -44,6 +46,29 @@ def test_homepage(app_client_two_attached_databases):
|
|||
] == table_links
|
||||
|
||||
|
||||
def test_http_head(app_client):
|
||||
response = app_client.get("/", method="HEAD")
|
||||
assert response.status == 200
|
||||
|
||||
|
||||
def test_static(app_client):
|
||||
response = app_client.get("/-/static/app2.css")
|
||||
assert response.status == 404
|
||||
response = app_client.get("/-/static/app.css")
|
||||
assert response.status == 200
|
||||
assert "text/css" == response.headers["content-type"]
|
||||
|
||||
|
||||
def test_static_mounts():
|
||||
for client in make_app_client(
|
||||
static_mounts=[("custom-static", str(pathlib.Path(__file__).parent))]
|
||||
):
|
||||
response = client.get("/custom-static/test_html.py")
|
||||
assert response.status == 200
|
||||
response = client.get("/custom-static/not_exists.py")
|
||||
assert response.status == 404
|
||||
|
||||
|
||||
def test_memory_database_page():
|
||||
for client in make_app_client(memory=True):
|
||||
response = client.get("/:memory:")
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@ Tests for various datasette helper functions.
|
|||
"""
|
||||
|
||||
from datasette import utils
|
||||
from datasette.utils.asgi import Request
|
||||
from datasette.filters import Filters
|
||||
import json
|
||||
import os
|
||||
import pytest
|
||||
from sanic.request import Request
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
|
@ -53,7 +53,7 @@ def test_urlsafe_components(path, expected):
|
|||
],
|
||||
)
|
||||
def test_path_with_added_args(path, added_args, expected):
|
||||
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
|
||||
request = Request.fake(path)
|
||||
actual = utils.path_with_added_args(request, added_args)
|
||||
assert expected == actual
|
||||
|
||||
|
|
@ -67,11 +67,11 @@ def test_path_with_added_args(path, added_args, expected):
|
|||
],
|
||||
)
|
||||
def test_path_with_removed_args(path, args, expected):
|
||||
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
|
||||
request = Request.fake(path)
|
||||
actual = utils.path_with_removed_args(request, args)
|
||||
assert expected == actual
|
||||
# Run the test again but this time use the path= argument
|
||||
request = Request("/".encode("utf8"), {}, "1.1", "GET", None)
|
||||
request = Request.fake("/")
|
||||
actual = utils.path_with_removed_args(request, args, path=path)
|
||||
assert expected == actual
|
||||
|
||||
|
|
@ -84,7 +84,7 @@ def test_path_with_removed_args(path, args, expected):
|
|||
],
|
||||
)
|
||||
def test_path_with_replaced_args(path, args, expected):
|
||||
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
|
||||
request = Request.fake(path)
|
||||
actual = utils.path_with_replaced_args(request, args)
|
||||
assert expected == actual
|
||||
|
||||
|
|
@ -363,7 +363,7 @@ def test_table_columns():
|
|||
],
|
||||
)
|
||||
def test_path_with_format(path, format, extra_qs, expected):
|
||||
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
|
||||
request = Request.fake(path)
|
||||
actual = utils.path_with_format(request, format, extra_qs)
|
||||
assert expected == actual
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue