New plugin hook: register_output_renderer (#441)

Thanks @russss!

* Add register_output_renderer hook

This changeset refactors the JSON renderer out of the view code, then adds a
hook and dispatcher system that lets plugins register custom output renderers.

The CSV output renderer is untouched because supporting streaming
renderers through this system would be significantly more complex, and
probably not worthwhile.

We can't simply allow hooks to be called at request time because we need
a list of supported file extensions when the request is being routed in
order to resolve ambiguous database/table names. So, renderers need to
be registered at startup.
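
To make that registration step concrete, here is a minimal sketch of the
startup loop, assuming the pluggy plugin manager `pm` (imported in the diff
below) and a `renderers` dict on the application object; the function name
`register_renderers` is illustrative, not the exact Datasette internals:

    from datasette.plugins import pm

    # Illustrative startup step: collect every renderer declared via the
    # hook and index it by file extension, so the router knows all
    # supported extensions before any request arrives.
    def register_renderers(datasette):
        datasette.renderers = {}
        for renderer in pm.hook.register_output_renderer(datasette=datasette):
            datasette.renderers[renderer["extension"]] = renderer["callback"]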

I've tried to make this API independent of Sanic's request/response objects
so that it can remain stable across the switch to ASGI. I'm using plain
dictionaries to keep it simple and to make it easy to add options in the
future.
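
As an illustration of that dictionary-based API, a hypothetical plugin
registering an .ndjson renderer could look like the sketch below. The
{"extension": ..., "callback": ...} return shape and the
(args, data, view_name) callback signature mirror the dispatcher code in the
diff; the plugin itself is invented for this example:

    from datasette import hookimpl
    import json

    # Hypothetical renderer: newline-delimited JSON, one object per row.
    # `args` is the request's querystring dict, `data` is the data dict the
    # view built (including "columns" and "rows"), and `view_name` is e.g.
    # "table". Returning None signals a 404.
    def render_ndjson(args, data, view_name):
        return {
            "body": "\n".join(
                json.dumps(dict(zip(data["columns"], row)))
                for row in data["rows"]
            ),
            "content_type": "text/plain; charset=utf-8",
            "status_code": 200,
        }

    @hookimpl
    def register_output_renderer(datasette):
        return {"extension": "ndjson", "callback": render_ndjson}

With such a plugin installed, a request to /mydb/mytable.ndjson would be
dispatched to render_ndjson through the renderers dict, and the result's
body, content_type and status_code keys would become the HTTP response.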

Fixes #440
Russ Garrett 2019-05-02 00:01:56 +01:00 committed by Simon Willison
commit cf406c0754
13 changed files with 269 additions and 149 deletions


@@ -1,7 +1,6 @@
import asyncio
import csv
import itertools
import json
import re
import time
import urllib
@@ -15,21 +14,17 @@ from sanic.views import HTTPMethodView
from datasette import __version__
from datasette.plugins import pm
from datasette.utils import (
CustomJSONEncoder,
InterruptedError,
InvalidSql,
LimitedWriter,
format_bytes,
is_url,
path_from_row_pks,
path_with_added_args,
path_with_removed_args,
path_with_format,
remove_infinites,
resolve_table_and_format,
sqlite3,
to_css_class,
value_as_boolean,
)
ureg = pint.UnitRegistry()
@@ -127,6 +122,7 @@ class RenderMixin(HTTPMethodView):
class BaseView(RenderMixin):
name = ''
re_named_parameter = re.compile(":([a-zA-Z0-9_]+)")
def __init__(self, datasette):
@@ -184,7 +180,8 @@ class BaseView(RenderMixin):
table_and_format=urllib.parse.unquote_plus(
kwargs["table_and_format"]
),
table_exists=async_table_exists
table_exists=async_table_exists,
allowed_formats=self.ds.renderers.keys()
)
kwargs["table"] = table
if _format:
@@ -316,33 +313,43 @@ class BaseView(RenderMixin):
content_type=content_type
)
async def view_get(self, request, database, hash, correct_hash_provided, **kwargs):
async def get_format(self, request, database, args):
""" Determine the format of the response from the request, from URL
parameters or from a file extension.
`args` is a dict of the path components parsed from the URL by the router.
"""
# If ?_format= is provided, use that as the format
_format = request.args.get("_format", None)
if not _format:
_format = (kwargs.pop("as_format", None) or "").lstrip(".")
if "table_and_format" in kwargs:
_format = (args.pop("as_format", None) or "").lstrip(".")
if "table_and_format" in args:
async def async_table_exists(t):
return await self.ds.table_exists(database, t)
table, _ext_format = await resolve_table_and_format(
table_and_format=urllib.parse.unquote_plus(
kwargs["table_and_format"]
args["table_and_format"]
),
table_exists=async_table_exists
table_exists=async_table_exists,
allowed_formats=self.ds.renderers.keys()
)
_format = _format or _ext_format
kwargs["table"] = table
del kwargs["table_and_format"]
elif "table" in kwargs:
kwargs["table"] = urllib.parse.unquote_plus(
kwargs["table"]
args["table"] = table
del args["table_and_format"]
elif "table" in args:
args["table"] = urllib.parse.unquote_plus(
args["table"]
)
return _format, args
async def view_get(self, request, database, hash, correct_hash_provided, **kwargs):
_format, kwargs = await self.get_format(request, database, kwargs)
if _format == "csv":
return await self.as_csv(request, database, hash, **kwargs)
if _format is None:
# HTML views default to expanding all forign key labels
# HTML views default to expanding all foreign key labels
kwargs['default_labels'] = True
extra_template_data = {}
@@ -358,7 +365,7 @@ class BaseView(RenderMixin):
else:
data, extra_template_data, templates = response_or_template_contexts
except InterruptedError as e:
except InterruptedError:
raise DatasetteError("""
SQL query took too long. The time limit is controlled by the
<a href="https://datasette.readthedocs.io/en/stable/config.html#sql-time-limit-ms">sql_time_limit_ms</a>
@@ -379,92 +386,37 @@ class BaseView(RenderMixin):
value = self.ds.metadata(key)
if value:
data[key] = value
if _format in ("json", "jsono"):
# Special case for .jsono extension - redirect to _shape=objects
if _format == "jsono":
return self.redirect(
# Special case for .jsono extension - redirect to _shape=objects
if _format == "jsono":
return self.redirect(
request,
path_with_added_args(
request,
path_with_added_args(
request,
{"_shape": "objects"},
path=request.path.rsplit(".jsono", 1)[0] + ".json",
),
forward_querystring=False,
)
# Handle the _json= parameter which may modify data["rows"]
json_cols = []
if "_json" in request.args:
json_cols = request.args["_json"]
if json_cols and "rows" in data and "columns" in data:
data["rows"] = convert_specific_columns_to_json(
data["rows"], data["columns"], json_cols,
)
# unless _json_infinity=1 requested, replace infinity with None
if "rows" in data and not value_as_boolean(
request.args.get("_json_infinity", "0")
):
data["rows"] = [remove_infinites(row) for row in data["rows"]]
# Deal with the _shape option
shape = request.args.get("_shape", "arrays")
if shape == "arrayfirst":
data = [row[0] for row in data["rows"]]
elif shape in ("objects", "object", "array"):
columns = data.get("columns")
rows = data.get("rows")
if rows and columns:
data["rows"] = [dict(zip(columns, row)) for row in rows]
if shape == "object":
error = None
if "primary_keys" not in data:
error = "_shape=object is only available on tables"
else:
pks = data["primary_keys"]
if not pks:
error = "_shape=object not available for tables with no primary keys"
else:
object_rows = {}
for row in data["rows"]:
pk_string = path_from_row_pks(row, pks, not pks)
object_rows[pk_string] = row
data = object_rows
if error:
data = {
"ok": False,
"error": error,
"database": database,
}
elif shape == "array":
data = data["rows"]
elif shape == "arrays":
pass
else:
status_code = 400
data = {
"ok": False,
"error": "Invalid _shape: {}".format(shape),
"status": 400,
"title": None,
}
headers = {}
if self.ds.cors:
headers["Access-Control-Allow-Origin"] = "*"
# Handle _nl option for _shape=array
nl = request.args.get("_nl", "")
if nl and shape == "array":
body = "\n".join(json.dumps(item) for item in data)
content_type = "text/plain"
else:
body = json.dumps(data, cls=CustomJSONEncoder)
content_type = "application/json"
r = response.HTTPResponse(
body,
status=status_code,
content_type=content_type,
headers=headers,
{"_shape": "objects"},
path=request.path.rsplit(".jsono", 1)[0] + ".json",
),
forward_querystring=False,
)
if _format in self.ds.renderers.keys():
# Dispatch request to the correct output format renderer
# (CSV is not handled here due to streaming)
result = self.ds.renderers[_format](request.args, data, self.name)
if result is None:
raise NotFound("No data")
response_args = {
'content_type': result.get('content_type', 'text/plain'),
'status': result.get('status_code', 200)
}
if type(result.get('body')) == bytes:
response_args['body_bytes'] = result.get('body')
else:
response_args['body'] = result.get('body')
r = response.HTTPResponse(**response_args)
else:
extras = {}
if callable(extra_template_data):
@@ -476,6 +428,10 @@ class BaseView(RenderMixin):
url_labels_extra = {}
if data.get("expandable_columns"):
url_labels_extra = {"_labels": "on"}
renderers = {
key: path_with_format(request, key, {**url_labels_extra}) for key in self.ds.renderers.keys()
}
url_csv_args = {
"_size": "max",
**url_labels_extra
@@ -486,9 +442,7 @@ class BaseView(RenderMixin):
**data,
**extras,
**{
"url_json": path_with_format(request, "json", {
**url_labels_extra,
}),
"renderers": renderers,
"url_csv": url_csv,
"url_csv_path": url_csv_path,
"url_csv_hidden_args": [
@@ -504,23 +458,29 @@ class BaseView(RenderMixin):
context["metadata"] = self.ds.metadata
r = self.render(templates, **context)
r.status = status_code
# Set far-future cache expiry
if self.ds.cache_headers and r.status == 200:
ttl = request.args.get("_ttl", None)
if ttl is None or not ttl.isdigit():
if correct_hash_provided:
ttl = self.ds.config("default_cache_ttl_hashed")
else:
ttl = self.ds.config("default_cache_ttl")
ttl = request.args.get("_ttl", None)
if ttl is None or not ttl.isdigit():
if correct_hash_provided:
ttl = self.ds.config("default_cache_ttl_hashed")
else:
ttl = int(ttl)
ttl = self.ds.config("default_cache_ttl")
return self.set_response_headers(r, ttl)
def set_response_headers(self, response, ttl):
# Set far-future cache expiry
if self.ds.cache_headers and response.status == 200:
ttl = int(ttl)
if ttl == 0:
ttl_header = 'no-cache'
else:
ttl_header = 'max-age={}'.format(ttl)
r.headers["Cache-Control"] = ttl_header
r.headers["Referrer-Policy"] = "no-referrer"
return r
response.headers["Cache-Control"] = ttl_header
response.headers["Referrer-Policy"] = "no-referrer"
if self.ds.cors:
response.headers["Access-Control-Allow-Origin"] = "*"
return response
async def custom_sql(
self, request, database, hash, sql, editable=True, canned_query=None,
@@ -611,22 +571,3 @@ class BaseView(RenderMixin):
"columns": columns,
"query": {"sql": sql, "params": params},
}, extra_template, templates
def convert_specific_columns_to_json(rows, columns, json_cols):
json_cols = set(json_cols)
if not json_cols.intersection(columns):
return rows
new_rows = []
for row in rows:
new_row = []
for value, column in zip(row, columns):
if column in json_cols:
try:
value = json.loads(value)
except (TypeError, ValueError) as e:
print(e)
pass
new_row.append(value)
new_rows.append(new_row)
return new_rows