diff --git a/datasette/views/base.py b/datasette/views/base.py index a3a207bd..30026f4b 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -1,32 +1,19 @@ -import asyncio import csv -import dataclasses import hashlib import sys -import textwrap -import time -import urllib -from markupsafe import escape - -from datasette.database import QueryInterrupted from datasette.utils.asgi import Request from datasette.utils import ( add_cors_headers, - await_me_maybe, EscapeHtmlWriter, InvalidSql, LimitedWriter, - call_with_supported_arguments, path_from_row_pks, - path_with_added_args, - path_with_removed_args, path_with_format, sqlite3, ) from datasette.utils.asgi import ( AsgiStream, - NotFound, Response, BadRequest, ) @@ -89,9 +76,6 @@ class View: class BaseView: ds = None has_json_alternate = True - # Set to a Context subclass to render a documented template context - - # keys not declared on the class are dropped before rendering - context_class = None def __init__(self, datasette): self.ds = datasette @@ -173,11 +157,6 @@ class BaseView: ) } ) - if self.context_class is not None: - declared = {f.name for f in dataclasses.fields(self.context_class)} - template_context = self.context_class( - **{k: v for k, v in template_context.items() if k in declared} - ) return Response.html( await self.ds.render_template( template, @@ -201,227 +180,6 @@ class BaseView: return view -class DataView(BaseView): - name = "" - - def redirect(self, request, path, forward_querystring=True, remove_args=None): - if request.query_string and "?" not in path and forward_querystring: - path = f"{path}?{request.query_string}" - if remove_args: - path = path_with_removed_args(request, remove_args, path=path) - r = Response.redirect(path) - r.headers["Link"] = f"<{path}>; rel=preload" - if self.ds.cors: - add_cors_headers(r.headers) - return r - - async def data(self, request): - raise NotImplementedError - - async def as_csv(self, request, database): - return await stream_csv(self.ds, self.data, request, database) - - async def get(self, request): - db = await self.ds.resolve_database(request) - database = db.name - database_route = db.route - - _format = request.url_vars["format"] - data_kwargs = {} - - if _format == "csv": - return await self.as_csv(request, database_route) - - if _format is None: - # HTML views default to expanding all foreign key labels - data_kwargs["default_labels"] = True - - extra_template_data = {} - start = time.perf_counter() - status_code = None - templates = [] - try: - response_or_template_contexts = await self.data(request, **data_kwargs) - if isinstance(response_or_template_contexts, Response): - return response_or_template_contexts - # If it has four items, it includes an HTTP status code - if len(response_or_template_contexts) == 4: - ( - data, - extra_template_data, - templates, - status_code, - ) = response_or_template_contexts - else: - data, extra_template_data, templates = response_or_template_contexts - except QueryInterrupted as ex: - raise DatasetteError( - textwrap.dedent(""" -

SQL query took too long. The time limit is controlled by the - sql_time_limit_ms - configuration option.

- - - """.format(escape(ex.sql))).strip(), - title="SQL Interrupted", - status=400, - message_is_html=True, - ) - except (sqlite3.OperationalError, InvalidSql) as e: - raise DatasetteError(str(e), title="Invalid SQL", status=400) - - except sqlite3.OperationalError as e: - raise DatasetteError(str(e)) - - except DatasetteError: - raise - - end = time.perf_counter() - data["query_ms"] = (end - start) * 1000 - - # Special case for .jsono extension - redirect to _shape=objects - if _format == "jsono": - return self.redirect( - request, - path_with_added_args( - request, - {"_shape": "objects"}, - path=request.path.rsplit(".jsono", 1)[0] + ".json", - ), - forward_querystring=False, - ) - - if _format in self.ds.renderers.keys(): - # Dispatch request to the correct output format renderer - # (CSV is not handled here due to streaming) - result = call_with_supported_arguments( - self.ds.renderers[_format][0], - datasette=self.ds, - columns=data.get("columns") or [], - rows=data.get("rows") or [], - sql=data.get("query", {}).get("sql", None), - query_name=data.get("query_name"), - database=database, - table=data.get("table"), - request=request, - view_name=self.name, - truncated=False, # TODO: support this - error=data.get("error"), - # These will be deprecated in Datasette 1.0: - args=request.args, - data=data, - ) - if asyncio.iscoroutine(result): - result = await result - if result is None: - raise NotFound("No data") - if isinstance(result, dict): - r = Response( - body=result.get("body"), - status=result.get("status_code", status_code or 200), - content_type=result.get("content_type", "text/plain"), - headers=result.get("headers"), - ) - elif isinstance(result, Response): - r = result - if status_code is not None: - # Over-ride the status code - r.status = status_code - else: - assert False, f"{result} should be dict or Response" - else: - extras = {} - if callable(extra_template_data): - extras = extra_template_data() - if asyncio.iscoroutine(extras): - extras = await extras - else: - extras = extra_template_data - url_labels_extra = {} - if data.get("expandable_columns"): - url_labels_extra = {"_labels": "on"} - - renderers = {} - for key, (_, can_render) in self.ds.renderers.items(): - it_can_render = call_with_supported_arguments( - can_render, - datasette=self.ds, - columns=data.get("columns") or [], - rows=data.get("rows") or [], - sql=data.get("query", {}).get("sql", None), - query_name=data.get("query_name"), - database=database, - table=data.get("table"), - request=request, - view_name=self.name, - ) - it_can_render = await await_me_maybe(it_can_render) - if it_can_render: - renderers[key] = self.ds.urls.path( - path_with_format( - request=request, - path=request.scope.get("route_path"), - format=key, - extra_qs={**url_labels_extra}, - ) - ) - - url_csv_args = {"_size": "max", **url_labels_extra} - url_csv = self.ds.urls.path( - path_with_format( - request=request, - path=request.scope.get("route_path"), - format="csv", - extra_qs=url_csv_args, - ) - ) - url_csv_path = url_csv.split("?")[0] - context = { - **data, - **extras, - **{ - "renderers": renderers, - "url_csv": url_csv, - "url_csv_path": url_csv_path, - "url_csv_hidden_args": [ - (key, value) - for key, value in urllib.parse.parse_qsl(request.query_string) - if key not in ("_labels", "_facet", "_size") - ] - + [("_size", "max")], - "settings": self.ds.settings_dict(), - }, - } - if "metadata" not in context: - context["metadata"] = await self.ds.get_instance_metadata() - r = await self.render(templates, request=request, context=context) - if status_code is not None: - r.status = status_code - - ttl = request.args.get("_ttl", None) - if ttl is None or not ttl.isdigit(): - ttl = self.ds.setting("default_cache_ttl") - - return self.set_response_headers(r, ttl) - - def set_response_headers(self, response, ttl): - # Set far-future cache expiry - if self.ds.cache_headers and response.status == 200: - ttl = int(ttl) - if ttl == 0: - ttl_header = "no-cache" - else: - ttl_header = f"max-age={ttl}" - response.headers["Cache-Control"] = ttl_header - response.headers["Referrer-Policy"] = "no-referrer" - if self.ds.cors: - add_cors_headers(response.headers) - return response - - def _error(messages, status=400): return Response.json({"ok": False, "errors": messages}, status=status) diff --git a/datasette/views/row.py b/datasette/views/row.py index 7802f45e..3e3e52a9 100644 --- a/datasette/views/row.py +++ b/datasette/views/row.py @@ -1,21 +1,34 @@ +import asyncio +import json +import textwrap +import time +import urllib.parse +from dataclasses import dataclass, field, fields + +import markupsafe +import sqlite_utils + from datasette.utils.asgi import NotFound, Forbidden, Response from datasette.database import QueryInterrupted from datasette.events import UpdateRowEvent, DeleteRowEvent from datasette.resources import TableResource -from .base import DataView, BaseView, _error +from .base import BaseView, DatasetteError, _error, stream_csv from datasette.utils import ( + add_cors_headers, await_me_maybe, + call_with_supported_arguments, CustomRow, + InvalidSql, make_slot_function, path_from_row_pks, + path_with_added_args, + path_with_format, + path_with_removed_args, to_css_class, escape_sqlite, + sqlite3, ) from datasette.plugins import pm -from dataclasses import dataclass, field -import json -import markupsafe -import sqlite_utils from datasette.extras import extra_names_from_request, ExtraScope from . import Context, extra_field from .table import ( @@ -121,9 +134,259 @@ class RowContext(Context): ) -class RowView(DataView): +class RowView(BaseView): name = "row" - context_class = RowContext + + def redirect(self, request, path, forward_querystring=True, remove_args=None): + if request.query_string and "?" not in path and forward_querystring: + path = f"{path}?{request.query_string}" + if remove_args: + path = path_with_removed_args(request, remove_args, path=path) + response = Response.redirect(path) + response.headers["Link"] = f"<{path}>; rel=preload" + if self.ds.cors: + add_cors_headers(response.headers) + return response + + async def as_csv(self, request, database): + return await stream_csv(self.ds, self.data, request, database) + + async def get(self, request): + db = await self.ds.resolve_database(request) + database = db.name + database_route = db.route + format_ = request.url_vars.get("format") or "html" + data_kwargs = {} + + if format_ == "csv": + return await self.as_csv(request, database_route) + + if format_ == "html": + # HTML views default to expanding all foreign key labels + data_kwargs["default_labels"] = True + + extra_template_data = {} + start = time.perf_counter() + status_code = None + templates = () + try: + response_or_template_contexts = await self.data(request, **data_kwargs) + if isinstance(response_or_template_contexts, Response): + return response_or_template_contexts + # If it has four items, it includes an HTTP status code + if len(response_or_template_contexts) == 4: + ( + data, + extra_template_data, + templates, + status_code, + ) = response_or_template_contexts + else: + data, extra_template_data, templates = response_or_template_contexts + except QueryInterrupted as ex: + raise DatasetteError( + textwrap.dedent(""" +

SQL query took too long. The time limit is controlled by the + sql_time_limit_ms + configuration option.

+ + + """.format(markupsafe.escape(ex.sql))).strip(), + title="SQL Interrupted", + status=400, + message_is_html=True, + ) + except (sqlite3.OperationalError, InvalidSql) as e: + raise DatasetteError(str(e), title="Invalid SQL", status=400) + except sqlite3.OperationalError as e: + raise DatasetteError(str(e)) + except DatasetteError: + raise + + end = time.perf_counter() + data["query_ms"] = (end - start) * 1000 + + # Special case for .jsono extension - redirect to _shape=objects + if format_ == "jsono": + return self.redirect( + request, + path_with_added_args( + request, + {"_shape": "objects"}, + path=request.path.rsplit(".jsono", 1)[0] + ".json", + ), + forward_querystring=False, + ) + + if format_ in self.ds.renderers.keys(): + # Dispatch request to the correct output format renderer + # (CSV is not handled here due to streaming) + result = call_with_supported_arguments( + self.ds.renderers[format_][0], + datasette=self.ds, + columns=data.get("columns") or [], + rows=data.get("rows") or [], + sql=data.get("query", {}).get("sql", None), + query_name=data.get("query_name"), + database=database, + table=data.get("table"), + request=request, + view_name=self.name, + truncated=False, # TODO: support this + error=data.get("error"), + # These will be deprecated in Datasette 1.0: + args=request.args, + data=data, + ) + if asyncio.iscoroutine(result): + result = await result + if result is None: + raise NotFound("No data") + if isinstance(result, dict): + response = Response( + body=result.get("body"), + status=result.get("status_code", status_code or 200), + content_type=result.get("content_type", "text/plain"), + headers=result.get("headers"), + ) + elif isinstance(result, Response): + response = result + if status_code is not None: + # Over-ride the status code + response.status = status_code + else: + assert False, f"{result} should be dict or Response" + elif format_ == "html": + response = await self.html(request, data, extra_template_data, templates) + if status_code is not None: + response.status = status_code + else: + raise NotFound("Invalid format: {}".format(format_)) + + ttl = request.args.get("_ttl", None) + if ttl is None or not ttl.isdigit(): + ttl = self.ds.setting("default_cache_ttl") + + return self.set_response_headers(response, ttl) + + async def html(self, request, data, extra_template_data, templates): + extras = {} + if callable(extra_template_data): + extras = extra_template_data() + if asyncio.iscoroutine(extras): + extras = await extras + else: + extras = extra_template_data + + url_labels_extra = {} + if data.get("expandable_columns"): + url_labels_extra = {"_labels": "on"} + + renderers = {} + for key, (_, can_render) in self.ds.renderers.items(): + it_can_render = call_with_supported_arguments( + can_render, + datasette=self.ds, + columns=data.get("columns") or [], + rows=data.get("rows") or [], + sql=data.get("query", {}).get("sql", None), + query_name=data.get("query_name"), + database=data.get("database"), + table=data.get("table"), + request=request, + view_name=self.name, + ) + it_can_render = await await_me_maybe(it_can_render) + if it_can_render: + renderers[key] = self.ds.urls.path( + path_with_format( + request=request, + path=request.scope.get("route_path"), + format=key, + extra_qs={**url_labels_extra}, + ) + ) + + url_csv_args = {"_size": "max", **url_labels_extra} + url_csv = self.ds.urls.path( + path_with_format( + request=request, + path=request.scope.get("route_path"), + format="csv", + extra_qs=url_csv_args, + ) + ) + url_csv_path = url_csv.split("?")[0] + context = {**data, **extras} + if "metadata" not in context: + context["metadata"] = await self.ds.get_instance_metadata() + + environment = self.ds.get_jinja_environment(request) + template = environment.select_template(templates) + alternate_url_json = self.ds.absolute_url( + request, + self.ds.urls.path( + path_with_format( + request=request, + path=request.scope.get("route_path"), + format="json", + ) + ), + ) + explicit_context = { + "renderers": renderers, + "url_csv": url_csv, + "url_csv_path": url_csv_path, + "url_csv_hidden_args": [ + (key, value) + for key, value in urllib.parse.parse_qsl(request.query_string) + if key not in ("_labels", "_facet", "_size") + ] + + [("_size", "max")], + "settings": self.ds.settings_dict(), + "alternate_url_json": alternate_url_json, + "select_templates": [ + f"{'*' if template_name == template.name else ''}{template_name}" + for template_name in templates + ], + } + declared_fields = {f.name for f in fields(RowContext)} + context = { + key: value + for key, value in context.items() + if key in declared_fields and key not in explicit_context + } + + return Response.html( + await self.ds.render_template( + template, + RowContext(**context, **explicit_context), + request=request, + view_name=self.name, + ), + headers={ + "Link": '<{}>; rel="alternate"; type="application/json+datasette"'.format( + alternate_url_json + ) + }, + ) + + def set_response_headers(self, response, ttl): + # Set far-future cache expiry + if self.ds.cache_headers and response.status == 200: + ttl = int(ttl) + if ttl == 0: + ttl_header = "no-cache" + else: + ttl_header = f"max-age={ttl}" + response.headers["Cache-Control"] = ttl_header + response.headers["Referrer-Policy"] = "no-referrer" + if self.ds.cors: + add_cors_headers(response.headers) + return response async def data(self, request, default_labels=False): resolved = await self.ds.resolve_row(request)