Refactored canned query code, replaced old QueryView, closes #2114

This commit is contained in:
Simon Willison 2023-08-09 08:26:52 -07:00
commit 26be9f0445
4 changed files with 343 additions and 571 deletions

View file

@ -24,7 +24,7 @@
{% block content %} {% block content %}
{% if canned_write and db_is_immutable %} {% if canned_query_write and db_is_immutable %}
<p class="message-error">This query cannot be executed because the database is immutable.</p> <p class="message-error">This query cannot be executed because the database is immutable.</p>
{% endif %} {% endif %}
@ -32,7 +32,7 @@
{% block description_source_license %}{% include "_description_source_license.html" %}{% endblock %} {% block description_source_license %}{% include "_description_source_license.html" %}{% endblock %}
<form class="sql" action="{{ urls.database(database) }}{% if canned_query %}/{{ canned_query }}{% endif %}" method="{% if canned_write %}post{% else %}get{% endif %}"> <form class="sql" action="{{ urls.database(database) }}{% if canned_query %}/{{ canned_query }}{% endif %}" method="{% if canned_query_write %}post{% else %}get{% endif %}">
<h3>Custom SQL query{% if display_rows %} returning {% if truncated %}more than {% endif %}{{ "{:,}".format(display_rows|length) }} row{% if display_rows|length == 1 %}{% else %}s{% endif %}{% endif %}{% if not query_error %} <h3>Custom SQL query{% if display_rows %} returning {% if truncated %}more than {% endif %}{{ "{:,}".format(display_rows|length) }} row{% if display_rows|length == 1 %}{% else %}s{% endif %}{% endif %}{% if not query_error %}
<span class="show-hide-sql">(<a href="{{ show_hide_link }}">{{ show_hide_text }}</a>)</span> <span class="show-hide-sql">(<a href="{{ show_hide_link }}">{{ show_hide_text }}</a>)</span>
{% endif %}</h3> {% endif %}</h3>
@ -61,8 +61,8 @@
{% endif %} {% endif %}
<p> <p>
{% if not hide_sql %}<button id="sql-format" type="button" hidden>Format SQL</button>{% endif %} {% if not hide_sql %}<button id="sql-format" type="button" hidden>Format SQL</button>{% endif %}
{% if canned_write %}<input type="hidden" name="csrftoken" value="{{ csrftoken() }}">{% endif %} {% if canned_query_write %}<input type="hidden" name="csrftoken" value="{{ csrftoken() }}">{% endif %}
<input type="submit" value="Run SQL"{% if canned_write and db_is_immutable %} disabled{% endif %}> <input type="submit" value="Run SQL"{% if canned_query_write and db_is_immutable %} disabled{% endif %}>
{{ show_hide_hidden }} {{ show_hide_hidden }}
{% if canned_query and edit_sql_url %}<a href="{{ edit_sql_url }}" class="canned-query-edit-sql">Edit SQL</a>{% endif %} {% if canned_query and edit_sql_url %}<a href="{{ edit_sql_url }}" class="canned-query-edit-sql">Edit SQL</a>{% endif %}
</p> </p>
@ -87,7 +87,7 @@
</tbody> </tbody>
</table></div> </table></div>
{% else %} {% else %}
{% if not canned_write and not error %} {% if not canned_query_write and not error %}
<p class="zero-results">0 results</p> <p class="zero-results">0 results</p>
{% endif %} {% endif %}
{% endif %} {% endif %}

View file

@ -1,4 +1,3 @@
from asyncinject import Registry
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Callable from typing import Callable
from urllib.parse import parse_qsl, urlencode from urllib.parse import parse_qsl, urlencode
@ -33,7 +32,7 @@ from datasette.utils import (
from datasette.utils.asgi import AsgiFileDownload, NotFound, Response, Forbidden from datasette.utils.asgi import AsgiFileDownload, NotFound, Response, Forbidden
from datasette.plugins import pm from datasette.plugins import pm
from .base import BaseView, DatasetteError, DataView, View, _error, stream_csv from .base import BaseView, DatasetteError, View, _error, stream_csv
class DatabaseView(View): class DatabaseView(View):
@ -57,7 +56,7 @@ class DatabaseView(View):
sql = (request.args.get("sql") or "").strip() sql = (request.args.get("sql") or "").strip()
if sql: if sql:
return await query_view(request, datasette) return await QueryView()(request, datasette)
if format_ not in ("html", "json"): if format_ not in ("html", "json"):
raise NotFound("Invalid format: {}".format(format_)) raise NotFound("Invalid format: {}".format(format_))
@ -65,10 +64,6 @@ class DatabaseView(View):
metadata = (datasette.metadata("databases") or {}).get(database, {}) metadata = (datasette.metadata("databases") or {}).get(database, {})
datasette.update_with_inherited_metadata(metadata) datasette.update_with_inherited_metadata(metadata)
table_counts = await db.table_counts(5)
hidden_table_names = set(await db.hidden_table_names())
all_foreign_keys = await db.get_all_foreign_keys()
sql_views = [] sql_views = []
for view_name in await db.view_names(): for view_name in await db.view_names():
view_visible, view_private = await datasette.check_visibility( view_visible, view_private = await datasette.check_visibility(
@ -196,8 +191,13 @@ class QueryContext:
# urls: dict = field( # urls: dict = field(
# metadata={"help": "Object containing URL helpers like `database()`"} # metadata={"help": "Object containing URL helpers like `database()`"}
# ) # )
canned_write: bool = field( canned_query_write: bool = field(
metadata={"help": "Boolean indicating if this canned query allows writes"} metadata={
"help": "Boolean indicating if this is a canned query that allows writes"
}
)
metadata: dict = field(
metadata={"help": "Metadata about the database or the canned query"}
) )
db_is_immutable: bool = field( db_is_immutable: bool = field(
metadata={"help": "Boolean indicating if this database is immutable"} metadata={"help": "Boolean indicating if this database is immutable"}
@ -232,7 +232,6 @@ class QueryContext:
show_hide_hidden: str = field( show_hide_hidden: str = field(
metadata={"help": "Hidden input field for the _show_sql parameter"} metadata={"help": "Hidden input field for the _show_sql parameter"}
) )
metadata: dict = field(metadata={"help": "Metadata about the query/database"})
database_color: Callable = field( database_color: Callable = field(
metadata={"help": "Function that returns a color for a given database name"} metadata={"help": "Function that returns a color for a given database name"}
) )
@ -242,6 +241,12 @@ class QueryContext:
alternate_url_json: str = field( alternate_url_json: str = field(
metadata={"help": "URL for alternate JSON version of this page"} metadata={"help": "URL for alternate JSON version of this page"}
) )
# TODO: refactor this to somewhere else, probably ds.render_template()
select_templates: list = field(
metadata={
"help": "List of templates that were considered for rendering this page"
}
)
async def get_tables(datasette, request, db): async def get_tables(datasette, request, db):
@ -320,48 +325,167 @@ async def database_download(request, datasette):
) )
async def query_view( class QueryView(View):
request, async def post(self, request, datasette):
datasette, from datasette.app import TableNotFound
# canned_query=None,
# _size=None, db = await datasette.resolve_database(request)
# named_parameters=None,
# write=False, # We must be a canned query
): table_found = False
try:
await datasette.resolve_table(request)
table_found = True
except TableNotFound as table_not_found:
canned_query = await datasette.get_canned_query(
table_not_found.database_name, table_not_found.table, request.actor
)
if canned_query is None:
raise
if table_found:
# That should not have happened
raise DatasetteError("Unexpected table found on POST", status=404)
# If database is immutable, return an error
if not db.is_mutable:
raise Forbidden("Database is immutable")
# Process the POST
body = await request.post_body()
body = body.decode("utf-8").strip()
if body.startswith("{") and body.endswith("}"):
params = json.loads(body)
# But we want key=value strings
for key, value in params.items():
params[key] = str(value)
else:
params = dict(parse_qsl(body, keep_blank_values=True))
# Should we return JSON?
should_return_json = (
request.headers.get("accept") == "application/json"
or request.args.get("_json")
or params.get("_json")
)
params_for_query = MagicParameters(params, request, datasette)
ok = None
redirect_url = None
try:
cursor = await db.execute_write(canned_query["sql"], params_for_query)
message = canned_query.get(
"on_success_message"
) or "Query executed, {} row{} affected".format(
cursor.rowcount, "" if cursor.rowcount == 1 else "s"
)
message_type = datasette.INFO
redirect_url = canned_query.get("on_success_redirect")
ok = True
except Exception as ex:
message = canned_query.get("on_error_message") or str(ex)
message_type = datasette.ERROR
redirect_url = canned_query.get("on_error_redirect")
ok = False
if should_return_json:
return Response.json(
{
"ok": ok,
"message": message,
"redirect": redirect_url,
}
)
else:
datasette.add_message(request, message, message_type)
return Response.redirect(redirect_url or request.path)
async def get(self, request, datasette):
from datasette.app import TableNotFound
db = await datasette.resolve_database(request) db = await datasette.resolve_database(request)
database = db.name database = db.name
# Flattened because of ?sql=&name1=value1&name2=value2 feature
params = {key: request.args.get(key) for key in request.args}
sql = None
if "sql" in params:
sql = params.pop("sql")
if "_shape" in params:
params.pop("_shape")
# extras come from original request.args to avoid being flattened # Are we a canned query?
extras = request.args.getlist("_extra") canned_query = None
canned_query_write = False
if "table" in request.url_vars:
try:
await datasette.resolve_table(request)
except TableNotFound as table_not_found:
# Was this actually a canned query?
canned_query = await datasette.get_canned_query(
table_not_found.database_name, table_not_found.table, request.actor
)
if canned_query is None:
raise
canned_query_write = bool(canned_query.get("write"))
# TODO: Behave differently for canned query here: private = False
await datasette.ensure_permissions(request.actor, [("execute-sql", database)]) if canned_query:
# Respect canned query permissions
_, private = await datasette.check_visibility( visible, private = await datasette.check_visibility(
request.actor, request.actor,
permissions=[ permissions=[
("view-query", (database, canned_query["name"])),
("view-database", database), ("view-database", database),
"view-instance", "view-instance",
], ],
) )
if not visible:
raise Forbidden("You do not have permission to view this query")
else:
await datasette.ensure_permissions(
request.actor, [("execute-sql", database)]
)
# Flattened because of ?sql=&name1=value1&name2=value2 feature
params = {key: request.args.get(key) for key in request.args}
sql = None
if canned_query:
sql = canned_query["sql"]
elif "sql" in params:
sql = params.pop("sql")
# Extract any :named parameters
named_parameters = []
if canned_query and canned_query.get("params"):
named_parameters = canned_query["params"]
if not named_parameters:
named_parameters = await derive_named_parameters(
datasette.get_database(database), sql
)
named_parameter_values = {
named_parameter: params.get(named_parameter) or ""
for named_parameter in named_parameters
if not named_parameter.startswith("_")
}
# Set to blank string if missing from params
for named_parameter in named_parameters:
if named_parameter not in params and not named_parameter.startswith("_"):
params[named_parameter] = ""
extra_args = {} extra_args = {}
if params.get("_timelimit"): if params.get("_timelimit"):
extra_args["custom_time_limit"] = int(params["_timelimit"]) extra_args["custom_time_limit"] = int(params["_timelimit"])
format_ = request.url_vars.get("format") or "html" format_ = request.url_vars.get("format") or "html"
query_error = None query_error = None
results = None
rows = []
columns = []
params_for_query = params
if not canned_query_write:
try: try:
if not canned_query:
# For regular queries we only allow SELECT, plus other rules
validate_sql_select(sql) validate_sql_select(sql)
else:
# Canned queries can run magic parameters
params_for_query = MagicParameters(params, request, datasette)
results = await datasette.execute( results = await datasette.execute(
database, sql, params, truncate=True, **extra_args database, sql, params_for_query, truncate=True, **extra_args
) )
columns = results.columns columns = results.columns
rows = results.rows rows = results.rows
@ -415,7 +539,7 @@ async def query_view(
columns=columns, columns=columns,
rows=rows, rows=rows,
sql=sql, sql=sql,
query_name=None, query_name=canned_query["name"] if canned_query else None,
database=database, database=database,
table=None, table=None,
request=request, request=request,
@ -447,6 +571,12 @@ async def query_view(
elif format_ == "html": elif format_ == "html":
headers = {} headers = {}
templates = [f"query-{to_css_class(database)}.html", "query.html"] templates = [f"query-{to_css_class(database)}.html", "query.html"]
if canned_query:
templates.insert(
0,
f"query-{to_css_class(database)}-{to_css_class(canned_query['name'])}.html",
)
template = datasette.jinja_env.select_template(templates) template = datasette.jinja_env.select_template(templates)
alternate_url_json = datasette.absolute_url( alternate_url_json = datasette.absolute_url(
request, request,
@ -488,353 +618,7 @@ async def query_view(
) )
show_hide_hidden = "" show_hide_hidden = ""
if metadata.get("hide_sql"): if canned_query and canned_query.get("hide_sql"):
if bool(params.get("_show_sql")):
show_hide_link = path_with_removed_args(request, {"_show_sql"})
show_hide_text = "hide"
show_hide_hidden = '<input type="hidden" name="_show_sql" value="1">'
else:
show_hide_link = path_with_added_args(request, {"_show_sql": 1})
show_hide_text = "show"
else:
if bool(params.get("_hide_sql")):
show_hide_link = path_with_removed_args(request, {"_hide_sql"})
show_hide_text = "show"
show_hide_hidden = '<input type="hidden" name="_hide_sql" value="1">'
else:
show_hide_link = path_with_added_args(request, {"_hide_sql": 1})
show_hide_text = "hide"
hide_sql = show_hide_text == "show"
# Extract any :named parameters
named_parameters = await derive_named_parameters(
datasette.get_database(database), sql
)
named_parameter_values = {
named_parameter: params.get(named_parameter) or ""
for named_parameter in named_parameters
if not named_parameter.startswith("_")
}
# Set to blank string if missing from params
for named_parameter in named_parameters:
if named_parameter not in params and not named_parameter.startswith("_"):
params[named_parameter] = ""
r = Response.html(
await datasette.render_template(
template,
QueryContext(
database=database,
query={
"sql": sql,
"params": params,
},
canned_query=None,
private=private,
canned_write=False,
db_is_immutable=not db.is_mutable,
error=query_error,
hide_sql=hide_sql,
show_hide_link=datasette.urls.path(show_hide_link),
show_hide_text=show_hide_text,
editable=True, # TODO
allow_execute_sql=allow_execute_sql,
tables=await get_tables(datasette, request, db),
named_parameter_values=named_parameter_values,
edit_sql_url="todo",
display_rows=await display_rows(
datasette, database, request, rows, columns
),
table_columns=await _table_columns(datasette, database)
if allow_execute_sql
else {},
columns=columns,
renderers=renderers,
url_csv=datasette.urls.path(
path_with_format(
request=request, format="csv", extra_qs={"_size": "max"}
)
),
show_hide_hidden=markupsafe.Markup(show_hide_hidden),
metadata=metadata,
database_color=lambda _: "#ff0000",
alternate_url_json=alternate_url_json,
),
request=request,
view_name="database",
),
headers=headers,
)
else:
assert False, "Invalid format: {}".format(format_)
if datasette.cors:
add_cors_headers(r.headers)
return r
class QueryView(DataView):
async def data(
self,
request,
sql,
editable=True,
canned_query=None,
metadata=None,
_size=None,
named_parameters=None,
write=False,
default_labels=None,
):
db = await self.ds.resolve_database(request)
database = db.name
params = {key: request.args.get(key) for key in request.args}
if "sql" in params:
params.pop("sql")
if "_shape" in params:
params.pop("_shape")
private = False
if canned_query:
# Respect canned query permissions
visible, private = await self.ds.check_visibility(
request.actor,
permissions=[
("view-query", (database, canned_query)),
("view-database", database),
"view-instance",
],
)
if not visible:
raise Forbidden("You do not have permission to view this query")
else:
await self.ds.ensure_permissions(request.actor, [("execute-sql", database)])
# Extract any :named parameters
named_parameters = named_parameters or await derive_named_parameters(
self.ds.get_database(database), sql
)
named_parameter_values = {
named_parameter: params.get(named_parameter) or ""
for named_parameter in named_parameters
if not named_parameter.startswith("_")
}
# Set to blank string if missing from params
for named_parameter in named_parameters:
if named_parameter not in params and not named_parameter.startswith("_"):
params[named_parameter] = ""
extra_args = {}
if params.get("_timelimit"):
extra_args["custom_time_limit"] = int(params["_timelimit"])
if _size:
extra_args["page_size"] = _size
templates = [f"query-{to_css_class(database)}.html", "query.html"]
if canned_query:
templates.insert(
0,
f"query-{to_css_class(database)}-{to_css_class(canned_query)}.html",
)
query_error = None
# Execute query - as write or as read
if write:
if request.method == "POST":
# If database is immutable, return an error
if not db.is_mutable:
raise Forbidden("Database is immutable")
body = await request.post_body()
body = body.decode("utf-8").strip()
if body.startswith("{") and body.endswith("}"):
params = json.loads(body)
# But we want key=value strings
for key, value in params.items():
params[key] = str(value)
else:
params = dict(parse_qsl(body, keep_blank_values=True))
# Should we return JSON?
should_return_json = (
request.headers.get("accept") == "application/json"
or request.args.get("_json")
or params.get("_json")
)
if canned_query:
params_for_query = MagicParameters(params, request, self.ds)
else:
params_for_query = params
ok = None
try:
cursor = await self.ds.databases[database].execute_write(
sql, params_for_query
)
message = metadata.get(
"on_success_message"
) or "Query executed, {} row{} affected".format(
cursor.rowcount, "" if cursor.rowcount == 1 else "s"
)
message_type = self.ds.INFO
redirect_url = metadata.get("on_success_redirect")
ok = True
except Exception as e:
message = metadata.get("on_error_message") or str(e)
message_type = self.ds.ERROR
redirect_url = metadata.get("on_error_redirect")
ok = False
if should_return_json:
return Response.json(
{
"ok": ok,
"message": message,
"redirect": redirect_url,
}
)
else:
self.ds.add_message(request, message, message_type)
return self.redirect(request, redirect_url or request.path)
else:
async def extra_template():
return {
"request": request,
"db_is_immutable": not db.is_mutable,
"path_with_added_args": path_with_added_args,
"path_with_removed_args": path_with_removed_args,
"named_parameter_values": named_parameter_values,
"canned_query": canned_query,
"success_message": request.args.get("_success") or "",
"canned_write": True,
}
return (
{
"database": database,
"rows": [],
"truncated": False,
"columns": [],
"query": {"sql": sql, "params": params},
"private": private,
},
extra_template,
templates,
)
else: # Not a write
if canned_query:
params_for_query = MagicParameters(params, request, self.ds)
else:
params_for_query = params
try:
results = await self.ds.execute(
database, sql, params_for_query, truncate=True, **extra_args
)
columns = [r[0] for r in results.description]
except sqlite3.DatabaseError as e:
query_error = e
results = None
columns = []
allow_execute_sql = await self.ds.permission_allowed(
request.actor, "execute-sql", database
)
async def extra_template():
display_rows = []
truncate_cells = self.ds.setting("truncate_cells_html")
for row in results.rows if results else []:
display_row = []
for column, value in zip(results.columns, row):
display_value = value
# Let the plugins have a go
# pylint: disable=no-member
plugin_display_value = None
for candidate in pm.hook.render_cell(
row=row,
value=value,
column=column,
table=None,
database=database,
datasette=self.ds,
request=request,
):
candidate = await await_me_maybe(candidate)
if candidate is not None:
plugin_display_value = candidate
break
if plugin_display_value is not None:
display_value = plugin_display_value
else:
if value in ("", None):
display_value = markupsafe.Markup("&nbsp;")
elif is_url(str(display_value).strip()):
display_value = markupsafe.Markup(
'<a href="{url}">{truncated_url}</a>'.format(
url=markupsafe.escape(value.strip()),
truncated_url=markupsafe.escape(
truncate_url(value.strip(), truncate_cells)
),
)
)
elif isinstance(display_value, bytes):
blob_url = path_with_format(
request=request,
format="blob",
extra_qs={
"_blob_column": column,
"_blob_hash": hashlib.sha256(
display_value
).hexdigest(),
},
)
formatted = format_bytes(len(value))
display_value = markupsafe.Markup(
'<a class="blob-download" href="{}"{}>&lt;Binary:&nbsp;{:,}&nbsp;byte{}&gt;</a>'.format(
blob_url,
' title="{}"'.format(formatted)
if "bytes" not in formatted
else "",
len(value),
"" if len(value) == 1 else "s",
)
)
else:
display_value = str(value)
if truncate_cells and len(display_value) > truncate_cells:
display_value = (
display_value[:truncate_cells] + "\u2026"
)
display_row.append(display_value)
display_rows.append(display_row)
# Show 'Edit SQL' button only if:
# - User is allowed to execute SQL
# - SQL is an approved SELECT statement
# - No magic parameters, so no :_ in the SQL string
edit_sql_url = None
is_validated_sql = False
try:
validate_sql_select(sql)
is_validated_sql = True
except InvalidSql:
pass
if allow_execute_sql and is_validated_sql and ":_" not in sql:
edit_sql_url = (
self.ds.urls.database(database)
+ "?"
+ urlencode(
{
**{
"sql": sql,
},
**named_parameter_values,
}
)
)
show_hide_hidden = ""
if metadata.get("hide_sql"):
if bool(params.get("_show_sql")): if bool(params.get("_show_sql")):
show_hide_link = path_with_removed_args(request, {"_show_sql"}) show_hide_link = path_with_removed_args(request, {"_show_sql"})
show_hide_text = "hide" show_hide_text = "hide"
@ -855,42 +639,86 @@ class QueryView(DataView):
show_hide_link = path_with_added_args(request, {"_hide_sql": 1}) show_hide_link = path_with_added_args(request, {"_hide_sql": 1})
show_hide_text = "hide" show_hide_text = "hide"
hide_sql = show_hide_text == "show" hide_sql = show_hide_text == "show"
return {
"display_rows": display_rows, # Show 'Edit SQL' button only if:
"custom_sql": True, # - User is allowed to execute SQL
"named_parameter_values": named_parameter_values, # - SQL is an approved SELECT statement
"editable": editable, # - No magic parameters, so no :_ in the SQL string
"canned_query": canned_query, edit_sql_url = None
"edit_sql_url": edit_sql_url, is_validated_sql = False
"metadata": metadata, try:
"settings": self.ds.settings_dict(), validate_sql_select(sql)
"request": request, is_validated_sql = True
"show_hide_link": self.ds.urls.path(show_hide_link), except InvalidSql:
"show_hide_text": show_hide_text, pass
"show_hide_hidden": markupsafe.Markup(show_hide_hidden), if allow_execute_sql and is_validated_sql and ":_" not in sql:
"hide_sql": hide_sql, edit_sql_url = (
"table_columns": await _table_columns(self.ds, database) datasette.urls.database(database)
+ "?"
+ urlencode(
{
**{
"sql": sql,
},
**named_parameter_values,
}
)
)
r = Response.html(
await datasette.render_template(
template,
QueryContext(
database=database,
query={
"sql": sql,
"params": params,
},
canned_query=canned_query["name"] if canned_query else None,
private=private,
canned_query_write=canned_query_write,
db_is_immutable=not db.is_mutable,
error=query_error,
hide_sql=hide_sql,
show_hide_link=datasette.urls.path(show_hide_link),
show_hide_text=show_hide_text,
editable=not canned_query,
allow_execute_sql=allow_execute_sql,
tables=await get_tables(datasette, request, db),
named_parameter_values=named_parameter_values,
edit_sql_url=edit_sql_url,
display_rows=await display_rows(
datasette, database, request, rows, columns
),
table_columns=await _table_columns(datasette, database)
if allow_execute_sql if allow_execute_sql
else {}, else {},
} columns=columns,
renderers=renderers,
return ( url_csv=datasette.urls.path(
{ path_with_format(
"ok": not query_error, request=request, format="csv", extra_qs={"_size": "max"}
"database": database,
"query_name": canned_query,
"rows": results.rows if results else [],
"truncated": results.truncated if results else False,
"columns": columns,
"query": {"sql": sql, "params": params},
"error": str(query_error) if query_error else None,
"private": private,
"allow_execute_sql": allow_execute_sql,
},
extra_template,
templates,
400 if query_error else 200,
) )
),
show_hide_hidden=markupsafe.Markup(show_hide_hidden),
metadata=canned_query or metadata,
database_color=lambda _: "#ff0000",
alternate_url_json=alternate_url_json,
select_templates=[
f"{'*' if template_name == template.name else ''}{template_name}"
for template_name in templates
],
),
request=request,
view_name="database",
),
headers=headers,
)
else:
assert False, "Invalid format: {}".format(format_)
if datasette.cors:
add_cors_headers(r.headers)
return r
class MagicParameters(dict): class MagicParameters(dict):

View file

@ -9,7 +9,6 @@ import markupsafe
from datasette.plugins import pm from datasette.plugins import pm
from datasette.database import QueryInterrupted from datasette.database import QueryInterrupted
from datasette import tracer from datasette import tracer
from datasette.renderer import json_renderer
from datasette.utils import ( from datasette.utils import (
add_cors_headers, add_cors_headers,
await_me_maybe, await_me_maybe,
@ -21,7 +20,6 @@ from datasette.utils import (
tilde_encode, tilde_encode,
escape_sqlite, escape_sqlite,
filters_should_redirect, filters_should_redirect,
format_bytes,
is_url, is_url,
path_from_row_pks, path_from_row_pks,
path_with_added_args, path_with_added_args,
@ -38,7 +36,7 @@ from datasette.utils import (
from datasette.utils.asgi import BadRequest, Forbidden, NotFound, Response from datasette.utils.asgi import BadRequest, Forbidden, NotFound, Response
from datasette.filters import Filters from datasette.filters import Filters
import sqlite_utils import sqlite_utils
from .base import BaseView, DataView, DatasetteError, ureg, _error, stream_csv from .base import BaseView, DatasetteError, ureg, _error, stream_csv
from .database import QueryView from .database import QueryView
LINK_WITH_LABEL = ( LINK_WITH_LABEL = (
@ -698,57 +696,6 @@ async def table_view(datasette, request):
return response return response
class CannedQueryView(DataView):
def __init__(self, datasette):
self.ds = datasette
async def post(self, request):
from datasette.app import TableNotFound
try:
await self.ds.resolve_table(request)
except TableNotFound as e:
# Was this actually a canned query?
canned_query = await self.ds.get_canned_query(
e.database_name, e.table, request.actor
)
if canned_query:
# Handle POST to a canned query
return await QueryView(self.ds).data(
request,
canned_query["sql"],
metadata=canned_query,
editable=False,
canned_query=e.table,
named_parameters=canned_query.get("params"),
write=bool(canned_query.get("write")),
)
return Response.text("Method not allowed", status=405)
async def data(self, request, **kwargs):
from datasette.app import TableNotFound
try:
await self.ds.resolve_table(request)
except TableNotFound as not_found:
canned_query = await self.ds.get_canned_query(
not_found.database_name, not_found.table, request.actor
)
if canned_query:
return await QueryView(self.ds).data(
request,
canned_query["sql"],
metadata=canned_query,
editable=False,
canned_query=not_found.table,
named_parameters=canned_query.get("params"),
write=bool(canned_query.get("write")),
)
else:
raise
async def table_view_traced(datasette, request): async def table_view_traced(datasette, request):
from datasette.app import TableNotFound from datasette.app import TableNotFound
@ -761,10 +708,7 @@ async def table_view_traced(datasette, request):
) )
# If this is a canned query, not a table, then dispatch to QueryView instead # If this is a canned query, not a table, then dispatch to QueryView instead
if canned_query: if canned_query:
if request.method == "POST": return await QueryView()(request, datasette)
return await CannedQueryView(datasette).post(request)
else:
return await CannedQueryView(datasette).get(request)
else: else:
raise raise

View file

@ -95,12 +95,12 @@ def test_insert(canned_write_client):
csrftoken_from=True, csrftoken_from=True,
cookies={"foo": "bar"}, cookies={"foo": "bar"},
) )
assert 302 == response.status
assert "/data/add_name?success" == response.headers["Location"]
messages = canned_write_client.ds.unsign( messages = canned_write_client.ds.unsign(
response.cookies["ds_messages"], "messages" response.cookies["ds_messages"], "messages"
) )
assert [["Query executed, 1 row affected", 1]] == messages assert messages == [["Query executed, 1 row affected", 1]]
assert response.status == 302
assert response.headers["Location"] == "/data/add_name?success"
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -382,11 +382,11 @@ def test_magic_parameters_cannot_be_used_in_arbitrary_queries(magic_parameters_c
def test_canned_write_custom_template(canned_write_client): def test_canned_write_custom_template(canned_write_client):
response = canned_write_client.get("/data/update_name") response = canned_write_client.get("/data/update_name")
assert response.status == 200 assert response.status == 200
assert "!!!CUSTOM_UPDATE_NAME_TEMPLATE!!!" in response.text
assert ( assert (
"<!-- Templates considered: *query-data-update_name.html, query-data.html, query.html -->" "<!-- Templates considered: *query-data-update_name.html, query-data.html, query.html -->"
in response.text in response.text
) )
assert "!!!CUSTOM_UPDATE_NAME_TEMPLATE!!!" in response.text
# And test for link rel=alternate while we're here: # And test for link rel=alternate while we're here:
assert ( assert (
'<link rel="alternate" type="application/json+datasette" href="http://localhost/data/update_name.json">' '<link rel="alternate" type="application/json+datasette" href="http://localhost/data/update_name.json">'