From cf406c075433882b656e340870adf7757976fa4c Mon Sep 17 00:00:00 2001
From: Russ Garrett
Date: Thu, 2 May 2019 00:01:56 +0100
Subject: [PATCH] New plugin hook: register_output_renderer hook (#441)

Thanks @russss!

* Add register_output_renderer hook

This changeset refactors out the JSON renderer and then adds a hook and
dispatcher system to allow custom output renderers to be registered.

The CSV output renderer is untouched, because supporting streaming
renderers through this system would be significantly more complex and
probably not worthwhile.

We can't simply allow hooks to be called at request time, because we
need a list of supported file extensions while the request is being
routed in order to resolve ambiguous database/table names. So renderers
need to be registered at startup.

I've tried to make this API independent of Sanic's request/response
objects so that it can remain stable during the switch to ASGI. I'm
using dictionaries to keep it simple and to make adding additional
options in the future easy.

Fixes #440
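As an illustration only (this sketch is not part of the patch; the
"demo" extension and render_demo callback are hypothetical names), a
plugin would use the new hook roughly like this:

    from datasette import hookimpl


    def render_demo(args, data, view_name):
        # args: the request's querystring arguments
        # data: the dict of rows/columns produced by the view
        # view_name: the name of the view that produced the data
        return {
            'body': repr(data.get('rows', [])),
            'content_type': 'text/plain; charset=utf-8',
            'status_code': 200,
        }


    @hookimpl
    def register_output_renderer(datasette):
        return {
            'extension': 'demo',
            'callback': render_demo
        }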
---
 datasette/app.py               |  28 ++++-
 datasette/hookspecs.py         |   5 +
 datasette/renderer.py          |  96 +++++++++++++++
 datasette/templates/query.html |   2 +-
 datasette/templates/row.html   |   2 +-
 datasette/templates/table.html |  10 +-
 datasette/utils.py             |   9 +-
 datasette/views/base.py        | 213 ++++++++++++---------------------
 datasette/views/database.py    |   2 +
 datasette/views/index.py       |   1 +
 datasette/views/table.py       |   2 +
 docs/plugins.rst               |  50 ++++++++
 tests/test_utils.py            |   2 +-
 13 files changed, 271 insertions(+), 151 deletions(-)
 create mode 100644 datasette/renderer.py

diff --git a/datasette/app.py b/datasette/app.py
index 793bc931..d3a8168e 100644
--- a/datasette/app.py
+++ b/datasette/app.py
@@ -24,6 +24,7 @@ from .views.database import DatabaseDownload, DatabaseView
 from .views.index import IndexView
 from .views.special import JsonDataView
 from .views.table import RowView, TableView
+from .renderer import json_renderer

 from .utils import (
     InterruptedError,
@@ -207,6 +208,7 @@ class Datasette:
         self.plugins_dir = plugins_dir
         self.static_mounts = static_mounts or []
         self._config = dict(DEFAULT_CONFIG, **(config or {}))
+        self.renderers = {}  # File extension -> renderer function
         self.version_note = version_note
         self.executor = futures.ThreadPoolExecutor(
             max_workers=self.config("num_sql_threads")
@@ -630,6 +632,22 @@ class Datasette:
             )
         return results

+    def register_renderers(self):
+        """ Register output renderers which output data in custom formats. """
+        # Built-in renderers
+        self.renderers['json'] = json_renderer
+
+        # Hooks
+        hook_renderers = []
+        for hook in pm.hook.register_output_renderer(datasette=self):
+            if type(hook) == list:
+                hook_renderers += hook
+            else:
+                hook_renderers.append(hook)
+
+        for renderer in hook_renderers:
+            self.renderers[renderer['extension']] = renderer['callback']
+
     def app(self):
         class TracingSanic(Sanic):
@@ -671,6 +689,11 @@ class Datasette:
         self.jinja_env.filters["to_css_class"] = to_css_class
         # pylint: disable=no-member
         pm.hook.prepare_jinja2_environment(env=self.jinja_env)
+
+        self.register_renderers()
+        # Generate a regex snippet to match all registered renderer file extensions
+        renderer_regex = "|".join(r"\." + key
+                                  for key in self.renderers.keys())
         app.add_route(IndexView.as_view(self), r"/<as_format:(\.jsono?)?$>")
         # TODO: /favicon.ico and /-/static/ deserve far-future cache expires
         app.add_route(favicon, "/favicon.ico")
@@ -706,7 +729,8 @@ class Datasette:
             DatabaseDownload.as_view(self), r"/<db_name:[^/]+?><as_db:(\.db)$>"
         )
         app.add_route(
-            DatabaseView.as_view(self), r"/<db_name:[^/]+?><as_format:(\.jsono?|\.csv)?$>"
+            DatabaseView.as_view(self),
+            r"/<db_name:[^/]+?><as_format:(" + renderer_regex + r"|\.jsono|\.csv)?$>",
         )
         app.add_route(
             TableView.as_view(self),
@@ -714,7 +738,7 @@ class Datasette:
         )
         app.add_route(
             RowView.as_view(self),
-            r"/<db_name:[^/]+?>/<table:[^/]+?>/<pk_path:[^/]+?><as_format:(\.jsono?)?$>",
+            r"/<db_name:[^/]+?>/<table:[^/]+?>/<pk_path:[^/]+?><as_format:(" + renderer_regex + r")?$>",
         )

         self.register_custom_units()
diff --git a/datasette/hookspecs.py b/datasette/hookspecs.py
index 6db95344..576e05b9 100644
--- a/datasette/hookspecs.py
+++ b/datasette/hookspecs.py
@@ -38,3 +38,8 @@ def publish_subcommand(publish):
 @hookspec(firstresult=True)
 def render_cell(value, column, table, database, datasette):
     "Customize rendering of HTML table cell values"
+
+
+@hookspec
+def register_output_renderer(datasette):
+    "Register a renderer to output data in a different format"
diff --git a/datasette/renderer.py b/datasette/renderer.py
new file mode 100644
index 00000000..dc9011ce
--- /dev/null
+++ b/datasette/renderer.py
@@ -0,0 +1,96 @@
+import json
+from datasette.utils import (
+    value_as_boolean,
+    remove_infinites,
+    CustomJSONEncoder,
+    path_from_row_pks,
+)
+
+
+def convert_specific_columns_to_json(rows, columns, json_cols):
+    json_cols = set(json_cols)
+    if not json_cols.intersection(columns):
+        return rows
+    new_rows = []
+    for row in rows:
+        new_row = []
+        for value, column in zip(row, columns):
+            if column in json_cols:
+                try:
+                    value = json.loads(value)
+                except (TypeError, ValueError) as e:
+                    print(e)
+                    pass
+            new_row.append(value)
+        new_rows.append(new_row)
+    return new_rows
+
+
+def json_renderer(args, data, view_name):
+    """ Render a response as JSON """
+    status_code = 200
+    # Handle the _json= parameter which may modify data["rows"]
+    json_cols = []
+    if "_json" in args:
+        json_cols = args["_json"]
+    if json_cols and "rows" in data and "columns" in data:
+        data["rows"] = convert_specific_columns_to_json(
+            data["rows"], data["columns"], json_cols
+        )
+
+    # unless _json_infinity=1 requested, replace infinity with None
+    if "rows" in data and not value_as_boolean(args.get("_json_infinity", "0")):
+        data["rows"] = [remove_infinites(row) for row in data["rows"]]
+
+    # Deal with the _shape option
+    shape = args.get("_shape", "arrays")
+    if shape == "arrayfirst":
+        data = [row[0] for row in data["rows"]]
+    elif shape in ("objects", "object", "array"):
+        columns = data.get("columns")
+        rows = data.get("rows")
+        if rows and columns:
+            data["rows"] = [dict(zip(columns, row)) for row in rows]
+        if shape == "object":
+            error = None
+            if "primary_keys" not in data:
+                error = "_shape=object is only available on tables"
+            else:
+                pks = data["primary_keys"]
+                if not pks:
+                    error = (
+                        "_shape=object not available for tables with no primary keys"
+                    )
+                else:
+                    object_rows = {}
+                    for row in data["rows"]:
+                        pk_string = path_from_row_pks(row, pks, not pks)
+                        object_rows[pk_string] = row
+                    data = object_rows
+            if error:
+                data = {"ok": False, "error": error}
+    elif shape == "array":
+        data = data["rows"]
+    elif shape == "arrays":
+        pass
+    else:
+        status_code = 400
+        data = {
+            "ok": False,
+            "error": "Invalid _shape: {}".format(shape),
+            "status": 400,
+            "title": None,
+        }
+    # Handle _nl option for _shape=array
+    nl = args.get("_nl", "")
+    if nl and shape == "array":
+        body = "\n".join(json.dumps(item) for item in data)
+        content_type = "text/plain"
+    else:
+        body = json.dumps(data, cls=CustomJSONEncoder)
+        content_type = "application/json"
+    return {
+        "body": body,
+        "status_code": status_code,
+        "content_type": content_type
+    }
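To make the extracted renderer concrete, here is a sketch of calling
json_renderer directly with the _shape=objects option (a plain dict of
strings stands in for Sanic's request arguments, whose .get() likewise
returns a single string):

    from datasette.renderer import json_renderer

    result = json_renderer(
        {"_shape": "objects"},
        {"columns": ["id", "name"], "rows": [[1, "hello"]]},
        "table",
    )
    # result["status_code"] == 200
    # result["content_type"] == "application/json"
    # result["body"] == '{"columns": ["id", "name"], "rows": [{"id": 1, "name": "hello"}]}'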
diff --git a/datasette/templates/query.html b/datasette/templates/query.html
index 81901660..b4b759a5 100644
--- a/datasette/templates/query.html
+++ b/datasette/templates/query.html
@@ -47,7 +47,7 @@

     {% if display_rows %}
-    <p class="export-links">This data as <a href="{{ url_json }}">.json</a></p>
+    <p class="export-links">This data as {% for name, url in renderers.items() %}<a href="{{ url }}">{{ name }}</a>{{ ", " if not loop.last }}{% endfor %}</p>
diff --git a/datasette/templates/row.html b/datasette/templates/row.html
index 389b16b2..baffaf96 100644
--- a/datasette/templates/row.html
+++ b/datasette/templates/row.html
@@ -22,7 +22,7 @@

 {% block description_source_license %}{% include "_description_source_license.html" %}{% endblock %}

-<p class="export-links">This data as <a href="{{ url_json }}">.json</a></p>
+<p class="export-links">This data as {% for name, url in renderers.items() %}<a href="{{ url }}">{{ name }}</a>{{ ", " if not loop.last }}{% endfor %}</p>
 {% include custom_rows_and_columns_templates %}
diff --git a/datasette/templates/table.html b/datasette/templates/table.html
index 1c65aa10..d28f41d6 100644
--- a/datasette/templates/table.html
+++ b/datasette/templates/table.html
@@ -106,7 +106,7 @@
         View and edit SQL
     {% endif %}

-    <p class="export-links">This data as <a href="{{ url_json }}">.json</a>, <a href="{{ url_csv }}">CSV</a> (<a href="#export">advanced</a>)</p>
+    <p class="export-links">This data as {% for name, url in renderers.items() %}<a href="{{ url }}">{{ name }}</a>, {% endfor %}<a href="{{ url_csv }}">CSV</a> (<a href="#export">advanced</a>)</p>

     {% if suggested_facets %}
@@ -155,10 +155,10 @@

     <h3 id="export">Advanced export</h3>

     <p>JSON shape:
-        <a href="{{ url_json }}">default</a>,
-        <a href="{{ append_querystring(url_json, '_shape=array') }}">array</a>,
-        <a href="{{ append_querystring(url_json, '_shape=array&amp;_nl=on') }}">newline-delimited</a>{% if primary_keys %},
-        <a href="{{ append_querystring(url_json, '_shape=object') }}">object</a>
+        <a href="{{ renderers['json'] }}">default</a>,
+        <a href="{{ append_querystring(renderers['json'], '_shape=array') }}">array</a>,
+        <a href="{{ append_querystring(renderers['json'], '_shape=array&amp;_nl=on') }}">newline-delimited</a>{% if primary_keys %},
+        <a href="{{ append_querystring(renderers['json'], '_shape=object') }}">object</a>
     {% endif %}
     </p>
diff --git a/datasette/utils.py b/datasette/utils.py
index 7ebf4f23..2f5c633e 100644
--- a/datasette/utils.py
+++ b/datasette/utils.py
@@ -718,17 +718,16 @@ def get_plugins(pm):
     return plugins


-FORMATS = ('csv', 'json', 'jsono')
-
-
-async def resolve_table_and_format(table_and_format, table_exists):
+async def resolve_table_and_format(table_and_format, table_exists, allowed_formats=[]):
     if '.' in table_and_format:
         # Check if a table exists with this exact name
        it_exists = await table_exists(table_and_format)
         if it_exists:
             return table_and_format, None
+
     # Check if table ends with a known format
-    for _format in FORMATS:
+    formats = list(allowed_formats) + ['csv', 'jsono']
+    for _format in formats:
         if table_and_format.endswith(".{}".format(_format)):
             table = table_and_format[:-(len(_format) + 1)]
             return table, _format
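The effect of the new allowed_formats parameter can be sketched as
follows (the table name and the exists-check are illustrative):

    import asyncio
    from datasette.utils import resolve_table_and_format

    async def demo():
        async def table_exists(name):
            # Pretend only a table named "mytable" exists
            return name == "mytable"

        # ".atom" resolves as a format only because it appears in
        # allowed_formats, i.e. because some renderer registered it
        table, _format = await resolve_table_and_format(
            table_and_format="mytable.atom",
            table_exists=table_exists,
            allowed_formats=["json", "atom"],
        )
        assert (table, _format) == ("mytable", "atom")

    asyncio.run(demo())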
+ """ # If ?_format= is provided, use that as the format _format = request.args.get("_format", None) if not _format: - _format = (kwargs.pop("as_format", None) or "").lstrip(".") - if "table_and_format" in kwargs: + _format = (args.pop("as_format", None) or "").lstrip(".") + if "table_and_format" in args: async def async_table_exists(t): return await self.ds.table_exists(database, t) table, _ext_format = await resolve_table_and_format( table_and_format=urllib.parse.unquote_plus( - kwargs["table_and_format"] + args["table_and_format"] ), - table_exists=async_table_exists + table_exists=async_table_exists, + allowed_formats=self.ds.renderers.keys() ) _format = _format or _ext_format - kwargs["table"] = table - del kwargs["table_and_format"] - elif "table" in kwargs: - kwargs["table"] = urllib.parse.unquote_plus( - kwargs["table"] + args["table"] = table + del args["table_and_format"] + elif "table" in args: + args["table"] = urllib.parse.unquote_plus( + args["table"] ) + return _format, args + + async def view_get(self, request, database, hash, correct_hash_provided, **kwargs): + _format, kwargs = await self.get_format(request, database, kwargs) if _format == "csv": return await self.as_csv(request, database, hash, **kwargs) if _format is None: - # HTML views default to expanding all forign key labels + # HTML views default to expanding all foriegn key labels kwargs['default_labels'] = True extra_template_data = {} @@ -358,7 +365,7 @@ class BaseView(RenderMixin): else: data, extra_template_data, templates = response_or_template_contexts - except InterruptedError as e: + except InterruptedError: raise DatasetteError(""" SQL query took too long. The time limit is controlled by the sql_time_limit_ms @@ -379,92 +386,37 @@ class BaseView(RenderMixin): value = self.ds.metadata(key) if value: data[key] = value - if _format in ("json", "jsono"): - # Special case for .jsono extension - redirect to _shape=objects - if _format == "jsono": - return self.redirect( + + # Special case for .jsono extension - redirect to _shape=objects + if _format == "jsono": + return self.redirect( + request, + path_with_added_args( request, - path_with_added_args( - request, - {"_shape": "objects"}, - path=request.path.rsplit(".jsono", 1)[0] + ".json", - ), - forward_querystring=False, - ) - - # Handle the _json= parameter which may modify data["rows"] - json_cols = [] - if "_json" in request.args: - json_cols = request.args["_json"] - if json_cols and "rows" in data and "columns" in data: - data["rows"] = convert_specific_columns_to_json( - data["rows"], data["columns"], json_cols, - ) - - # unless _json_infinity=1 requested, replace infinity with None - if "rows" in data and not value_as_boolean( - request.args.get("_json_infinity", "0") - ): - data["rows"] = [remove_infinites(row) for row in data["rows"]] - - # Deal with the _shape option - shape = request.args.get("_shape", "arrays") - if shape == "arrayfirst": - data = [row[0] for row in data["rows"]] - elif shape in ("objects", "object", "array"): - columns = data.get("columns") - rows = data.get("rows") - if rows and columns: - data["rows"] = [dict(zip(columns, row)) for row in rows] - if shape == "object": - error = None - if "primary_keys" not in data: - error = "_shape=object is only available on tables" - else: - pks = data["primary_keys"] - if not pks: - error = "_shape=object not available for tables with no primary keys" - else: - object_rows = {} - for row in data["rows"]: - pk_string = path_from_row_pks(row, pks, not pks) - object_rows[pk_string] = row - 
data = object_rows - if error: - data = { - "ok": False, - "error": error, - "database": database, - } - elif shape == "array": - data = data["rows"] - elif shape == "arrays": - pass - else: - status_code = 400 - data = { - "ok": False, - "error": "Invalid _shape: {}".format(shape), - "status": 400, - "title": None, - } - headers = {} - if self.ds.cors: - headers["Access-Control-Allow-Origin"] = "*" - # Handle _nl option for _shape=array - nl = request.args.get("_nl", "") - if nl and shape == "array": - body = "\n".join(json.dumps(item) for item in data) - content_type = "text/plain" - else: - body = json.dumps(data, cls=CustomJSONEncoder) - content_type = "application/json" - r = response.HTTPResponse( - body, - status=status_code, - content_type=content_type, - headers=headers, + {"_shape": "objects"}, + path=request.path.rsplit(".jsono", 1)[0] + ".json", + ), + forward_querystring=False, ) + + if _format in self.ds.renderers.keys(): + # Dispatch request to the correct output format renderer + # (CSV is not handled here due to streaming) + result = self.ds.renderers[_format](request.args, data, self.name) + if result is None: + raise NotFound("No data") + + response_args = { + 'content_type': result.get('content_type', 'text/plain'), + 'status': result.get('status_code', 200) + } + + if type(result.get('body')) == bytes: + response_args['body_bytes'] = result.get('body') + else: + response_args['body'] = result.get('body') + + r = response.HTTPResponse(**response_args) else: extras = {} if callable(extra_template_data): @@ -476,6 +428,10 @@ class BaseView(RenderMixin): url_labels_extra = {} if data.get("expandable_columns"): url_labels_extra = {"_labels": "on"} + + renderers = { + key: path_with_format(request, key, {**url_labels_extra}) for key in self.ds.renderers.keys() + } url_csv_args = { "_size": "max", **url_labels_extra @@ -486,9 +442,7 @@ class BaseView(RenderMixin): **data, **extras, **{ - "url_json": path_with_format(request, "json", { - **url_labels_extra, - }), + "renderers": renderers, "url_csv": url_csv, "url_csv_path": url_csv_path, "url_csv_hidden_args": [ @@ -504,23 +458,29 @@ class BaseView(RenderMixin): context["metadata"] = self.ds.metadata r = self.render(templates, **context) r.status = status_code - # Set far-future cache expiry - if self.ds.cache_headers and r.status == 200: - ttl = request.args.get("_ttl", None) - if ttl is None or not ttl.isdigit(): - if correct_hash_provided: - ttl = self.ds.config("default_cache_ttl_hashed") - else: - ttl = self.ds.config("default_cache_ttl") + + ttl = request.args.get("_ttl", None) + if ttl is None or not ttl.isdigit(): + if correct_hash_provided: + ttl = self.ds.config("default_cache_ttl_hashed") else: - ttl = int(ttl) + ttl = self.ds.config("default_cache_ttl") + + return self.set_response_headers(r, ttl) + + def set_response_headers(self, response, ttl): + # Set far-future cache expiry + if self.ds.cache_headers and response.status == 200: + ttl = int(ttl) if ttl == 0: ttl_header = 'no-cache' else: ttl_header = 'max-age={}'.format(ttl) - r.headers["Cache-Control"] = ttl_header - r.headers["Referrer-Policy"] = "no-referrer" - return r + response.headers["Cache-Control"] = ttl_header + response.headers["Referrer-Policy"] = "no-referrer" + if self.ds.cors: + response.headers["Access-Control-Allow-Origin"] = "*" + return response async def custom_sql( self, request, database, hash, sql, editable=True, canned_query=None, @@ -611,22 +571,3 @@ class BaseView(RenderMixin): "columns": columns, "query": {"sql": sql, "params": 
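Note how the dispatcher above defaults content_type to text/plain and
the status to 200, and sends a bytes body via Sanic's body_bytes, so a
renderer can produce binary output. A sketch (gzip is chosen purely for
illustration, and the row values are assumed to be JSON-serializable):

    import gzip
    import json

    def render_gzipped_json(args, data, view_name):
        rows = [list(row) for row in data.get("rows", [])]
        # Returning bytes makes BaseView pass the body as body_bytes;
        # status_code is omitted, so the dispatcher defaults it to 200
        return {
            "body": gzip.compress(json.dumps(rows).encode("utf-8")),
            "content_type": "application/gzip",
        }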
diff --git a/datasette/views/database.py b/datasette/views/database.py
index 9f43980b..d1185eef 100644
--- a/datasette/views/database.py
+++ b/datasette/views/database.py
@@ -8,6 +8,7 @@ from .base import BaseView, DatasetteError


 class DatabaseView(BaseView):
+    name = 'database'

     async def data(self, request, database, hash, default_labels=False, _size=None):
         if request.args.get("sql"):
@@ -39,6 +40,7 @@ class DatabaseView(BaseView):


 class DatabaseDownload(BaseView):
+    name = 'database_download'

     async def view_get(self, request, database, hash, correct_hash_present, **kwargs):
         if not self.ds.config("allow_download"):
diff --git a/datasette/views/index.py b/datasette/views/index.py
index 70f7e943..4eb116f3 100644
--- a/datasette/views/index.py
+++ b/datasette/views/index.py
@@ -9,6 +9,7 @@ from .base import HASH_LENGTH, RenderMixin


 class IndexView(RenderMixin):
+    name = 'index'

     def __init__(self, datasette):
         self.ds = datasette
diff --git a/datasette/views/table.py b/datasette/views/table.py
index bc5e775e..87f6d2c6 100644
--- a/datasette/views/table.py
+++ b/datasette/views/table.py
@@ -174,6 +174,7 @@ class RowTableShared(BaseView):


 class TableView(RowTableShared):
+    name = 'table'

     async def data(self, request, database, hash, table, default_labels=False, _next=None, _size=None):
         canned_query = self.ds.get_canned_query(database, table)
@@ -778,6 +779,7 @@ class TableView(RowTableShared):


 class RowView(RowTableShared):
+    name = 'row'

     async def data(self, request, database, hash, table, pk_path, default_labels=False):
         pk_values = urlsafe_components(pk_path)
diff --git a/docs/plugins.rst b/docs/plugins.rst
index 984e5c95..ae3c4607 100644
--- a/docs/plugins.rst
+++ b/docs/plugins.rst
@@ -551,3 +551,53 @@ The ``template``, ``database`` and ``table`` options can be used to return diffe
 The ``datasette`` instance is provided primarily so that you can consult any plugin configuration options that may have been set, using the ``datasette.plugin_config(plugin_name)`` method documented above.

 The string that you return from this function will be treated as "safe" for inclusion in a ``
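Based on the hookspec and the dispatcher added in
``datasette/views/base.py``, a registered callback can also decline to
render by returning ``None``, which Datasette turns into a 404 response.
A sketch (the ``firstrow`` extension and callback name are illustrative):

.. code-block:: python

    from datasette import hookimpl


    def render_first_row(args, data, view_name):
        rows = data.get('rows') or []
        if not rows:
            # None tells Datasette to respond with 404 Not Found
            return None
        return {
            'body': str(rows[0]),
            'content_type': 'text/plain'
        }


    @hookimpl
    def register_output_renderer(datasette):
        return {
            'extension': 'firstrow',
            'callback': render_first_row
        }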