Datasette.render_template() method, closes #577

Pull request #664.
This commit is contained in:
Simon Willison 2020-02-04 12:26:17 -08:00 committed by GitHub
commit 70b915fb4b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 107 additions and 83 deletions

View file

@ -1,6 +1,8 @@
import asyncio import asyncio
import collections import collections
import hashlib import hashlib
import itertools
import json
import os import os
import re import re
import sys import sys
@ -12,7 +14,8 @@ from pathlib import Path
import click import click
from markupsafe import Markup from markupsafe import Markup
from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader, escape
from jinja2.environment import Template
import uvicorn import uvicorn
from .views.base import DatasetteError, ureg, AsgiRouter from .views.base import DatasetteError, ureg, AsgiRouter
@ -27,6 +30,7 @@ from .utils import (
QueryInterrupted, QueryInterrupted,
escape_css_string, escape_css_string,
escape_sqlite, escape_sqlite,
format_bytes,
get_plugins, get_plugins,
module_from_path, module_from_path,
sqlite3, sqlite3,
@ -35,6 +39,7 @@ from .utils import (
from .utils.asgi import ( from .utils.asgi import (
AsgiLifespan, AsgiLifespan,
NotFound, NotFound,
Response,
asgi_static, asgi_static,
asgi_send, asgi_send,
asgi_send_html, asgi_send_html,
@ -526,6 +531,96 @@ class Datasette:
for renderer in hook_renderers: for renderer in hook_renderers:
self.renderers[renderer["extension"]] = renderer["callback"] self.renderers[renderer["extension"]] = renderer["callback"]
async def render_template(
self, templates, context=None, request=None, view_name=None
):
context = context or {}
if isinstance(templates, Template):
template = templates
select_templates = []
else:
if isinstance(templates, str):
templates = [templates]
template = self.jinja_env.select_template(templates)
select_templates = [
"{}{}".format(
"*" if template_name == template.name else "", template_name
)
for template_name in templates
]
body_scripts = []
# pylint: disable=no-member
for script in pm.hook.extra_body_script(
template=template.name,
database=context.get("database"),
table=context.get("table"),
view_name=view_name,
datasette=self,
):
body_scripts.append(Markup(script))
extra_template_vars = {}
# pylint: disable=no-member
for extra_vars in pm.hook.extra_template_vars(
template=template.name,
database=context.get("database"),
table=context.get("table"),
view_name=view_name,
request=request,
datasette=self,
):
if callable(extra_vars):
extra_vars = extra_vars()
if asyncio.iscoroutine(extra_vars):
extra_vars = await extra_vars
assert isinstance(extra_vars, dict), "extra_vars is of type {}".format(
type(extra_vars)
)
extra_template_vars.update(extra_vars)
template_context = {
**context,
**{
"app_css_hash": self.app_css_hash(),
"select_templates": select_templates,
"zip": zip,
"body_scripts": body_scripts,
"format_bytes": format_bytes,
"extra_css_urls": self._asset_urls("extra_css_urls", template, context),
"extra_js_urls": self._asset_urls("extra_js_urls", template, context),
},
**extra_template_vars,
}
return await template.render_async(template_context)
def _asset_urls(self, key, template, context):
# Flatten list-of-lists from plugins:
seen_urls = set()
for url_or_dict in itertools.chain(
itertools.chain.from_iterable(
getattr(pm.hook, key)(
template=template.name,
database=context.get("database"),
table=context.get("table"),
datasette=self,
)
),
(self.metadata(key) or []),
):
if isinstance(url_or_dict, dict):
url = url_or_dict["url"]
sri = url_or_dict.get("sri")
else:
url = url_or_dict
sri = None
if url in seen_urls:
continue
seen_urls.add(url)
if sri:
yield {"url": url, "sri": sri}
else:
yield {"url": url}
def app(self): def app(self):
"Returns an ASGI app function that serves the whole of Datasette" "Returns an ASGI app function that serves the whole of Datasette"
default_templates = str(app_root / "datasette" / "templates") default_templates = str(app_root / "datasette" / "templates")

View file

@ -9,15 +9,12 @@ import urllib
import jinja2 import jinja2
import pint import pint
from html import escape
from datasette import __version__ from datasette import __version__
from datasette.plugins import pm from datasette.plugins import pm
from datasette.utils import ( from datasette.utils import (
QueryInterrupted, QueryInterrupted,
InvalidSql, InvalidSql,
LimitedWriter, LimitedWriter,
format_bytes,
is_url, is_url,
path_with_added_args, path_with_added_args,
path_with_removed_args, path_with_removed_args,
@ -65,34 +62,6 @@ class BaseView(AsgiView):
response.body = b"" response.body = b""
return response return response
def _asset_urls(self, key, template, context):
# Flatten list-of-lists from plugins:
seen_urls = set()
for url_or_dict in itertools.chain(
itertools.chain.from_iterable(
getattr(pm.hook, key)(
template=template.name,
database=context.get("database"),
table=context.get("table"),
datasette=self.ds,
)
),
(self.ds.metadata(key) or []),
):
if isinstance(url_or_dict, dict):
url = url_or_dict["url"]
sri = url_or_dict.get("sri")
else:
url = url_or_dict
sri = None
if url in seen_urls:
continue
seen_urls.add(url)
if sri:
yield {"url": url, "sri": sri}
else:
yield {"url": url}
def database_url(self, database): def database_url(self, database):
db = self.ds.databases[database] db = self.ds.databases[database]
if self.ds.config("hash_urls") and db.hash: if self.ds.config("hash_urls") and db.hash:
@ -105,62 +74,22 @@ class BaseView(AsgiView):
async def render(self, templates, request, context): async def render(self, templates, request, context):
template = self.ds.jinja_env.select_template(templates) template = self.ds.jinja_env.select_template(templates)
select_templates = [
"{}{}".format("*" if template_name == template.name else "", template_name)
for template_name in templates
]
body_scripts = []
# pylint: disable=no-member
for script in pm.hook.extra_body_script(
template=template.name,
database=context.get("database"),
table=context.get("table"),
view_name=self.name,
datasette=self.ds,
):
body_scripts.append(jinja2.Markup(script))
extra_template_vars = {}
# pylint: disable=no-member
for extra_vars in pm.hook.extra_template_vars(
template=template.name,
database=context.get("database"),
table=context.get("table"),
view_name=self.name,
request=request,
datasette=self.ds,
):
if callable(extra_vars):
extra_vars = extra_vars()
if asyncio.iscoroutine(extra_vars):
extra_vars = await extra_vars
assert isinstance(extra_vars, dict), "extra_vars is of type {}".format(
type(extra_vars)
)
extra_template_vars.update(extra_vars)
template_context = { template_context = {
**context, **context,
**{ **{
"app_css_hash": self.ds.app_css_hash(), "database_url": self.database_url,
"select_templates": select_templates, "database_color": self.database_color,
"zip": zip, },
"body_scripts": body_scripts, }
"extra_css_urls": self._asset_urls("extra_css_urls", template, context), if request and request.args.get("_context") and self.ds.config("template_debug"):
"extra_js_urls": self._asset_urls("extra_js_urls", template, context),
"format_bytes": format_bytes,
"database_url": self.database_url,
"database_color": self.database_color,
},
**extra_template_vars,
}
if request.args.get("_context") and self.ds.config("template_debug"):
return Response.html( return Response.html(
"<pre>{}</pre>".format( "<pre>{}</pre>".format(
escape(json.dumps(template_context, default=repr, indent=4)) jinja2.escape(json.dumps(template_context, default=repr, indent=4))
) )
) )
return Response.html(await template.render_async(template_context)) return Response.html(await self.ds.render_template(
template, template_context, request=request
))
class DataView(BaseView): class DataView(BaseView):