Work in progress on query view, refs #2049

Simon Willison 2023-04-26 20:47:03 -07:00
commit 026429fadd
3 changed files with 426 additions and 5 deletions

@@ -1,3 +1,4 @@
from asyncinject import Registry
import os
import hashlib
import itertools
@@ -877,6 +878,416 @@ async def database_index_view(request, datasette, db):
    )

async def query_view(
    request,
    datasette,
    canned_query=None,
    _size=None,
    named_parameters=None,
    write=False,
):
    print("query_view")
    db = await datasette.resolve_database(request)
    database = db.name
    # TODO: Why do I do this? Is it to eliminate multi-args?
    # It's going to break ?_extra=...&_extra=...
    params = {key: request.args.get(key) for key in request.args}
    sql = ""
    if "sql" in params:
        sql = params.pop("sql")
    # TODO: Behave differently for canned query here:
    await datasette.ensure_permissions(request.actor, [("execute-sql", database)])
    _shape = None
    if "_shape" in params:
        _shape = params.pop("_shape")
    # ?_shape=arrays - "rows" is the default option, shown above
    # ?_shape=objects - "rows" is a list of JSON key/value objects
    # ?_shape=array - a JSON array of objects
    # ?_shape=array&_nl=on - a newline-separated list of JSON objects
    # ?_shape=arrayfirst - a flat JSON array containing just the first value from each row
    # ?_shape=object - a JSON object keyed using the primary keys of the rows
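    # Illustrative examples (assumed output, not from this commit): for
    # "select 1 as n" the shape functions implemented below would produce:
    #   arrays  -> {"ok": true, "rows": [[1]], "truncated": false}
    #   objects -> {"ok": true, "rows": [{"n": 1}], "truncated": false}
    #   array   -> [{"n": 1}]
    # (arrayfirst and object remain commented out in the dispatch dict below.)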
    async def _results(_sql, _params):
        # Returns (results, error (can be None))
        try:
            return await db.execute(_sql, _params, truncate=True), None
        except Exception as e:
            return None, e

    async def shape_arrays(_results):
        results, error = _results
        if error:
            return {"ok": False, "error": str(error)}
        return {
            "ok": True,
            "rows": [list(r) for r in results.rows],
            "truncated": results.truncated,
        }

    async def shape_objects(_results):
        results, error = _results
        if error:
            return {"ok": False, "error": str(error)}
        return {
            "ok": True,
            "rows": [dict(r) for r in results.rows],
            "truncated": results.truncated,
        }

    async def shape_array(_results):
        results, error = _results
        if error:
            return {"ok": False, "error": str(error)}
        return [dict(r) for r in results.rows]

    shape_fn = {
        "arrays": shape_arrays,
        "objects": shape_objects,
        "array": shape_array,
        # "arrayfirst": shape_arrayfirst,
        # "object": shape_object,
    }[_shape or "objects"]
    registry = Registry.from_dict(
        {
            "_results": _results,
            "_shape": shape_fn,
        },
        parallel=False,
    )

    results = await registry.resolve_multi(
        ["_shape"],
        results={
            "_sql": sql,
            "_params": params,
        },
    )
    # If "shape" does not include "rows" we return that as the response
    if "rows" not in results["_shape"]:
        return Response.json(results["_shape"])
    output = results["_shape"]
    output.update(dict((k, v) for k, v in results.items() if not k.startswith("_")))
    response = Response.json(output)
    assert False
    import pdb

    pdb.set_trace()
    if isinstance(output, dict) and output.get("ok") is False:
        # TODO: Other error codes?
        response.status_code = 400
    if datasette.cors:
        add_cors_headers(response.headers)
    return response
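    # NOTE: nothing below this return currently executes - the assert/pdb
    # scaffolding above also halts the live path. It appears to be the
    # previous implementation, kept around while this view is reworked
    # (the commit is marked "Work in progress").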
    # registry = Registry(
    #     extra_count,
    #     extra_facet_results,
    #     extra_facets_timed_out,
    #     extra_suggested_facets,
    #     facet_instances,
    #     extra_human_description_en,
    #     extra_next_url,
    #     extra_columns,
    #     extra_primary_keys,
    #     run_display_columns_and_rows,
    #     extra_display_columns,
    #     extra_display_rows,
    #     extra_debug,
    #     extra_request,
    #     extra_query,
    #     extra_metadata,
    #     extra_extras,
    #     extra_database,
    #     extra_table,
    #     extra_database_color,
    #     extra_table_actions,
    #     extra_filters,
    #     extra_renderers,
    #     extra_custom_table_templates,
    #     extra_sorted_facet_results,
    #     extra_table_definition,
    #     extra_view_definition,
    #     extra_is_view,
    #     extra_private,
    #     extra_expandable_columns,
    #     extra_form_hidden_args,
    # )
    results = await registry.resolve_multi(
        ["extra_{}".format(extra) for extra in extras]
    )
    data = {
        "ok": True,
        "next": next_value and str(next_value) or None,
    }
    data.update(
        {
            key.replace("extra_", ""): value
            for key, value in results.items()
            if key.startswith("extra_") and key.replace("extra_", "") in extras
        }
    )
    raw_sqlite_rows = rows[:page_size]
    data["rows"] = [dict(r) for r in raw_sqlite_rows]

    private = False
    if canned_query:
        # Respect canned query permissions
        visible, private = await datasette.check_visibility(
            request.actor,
            permissions=[
                ("view-query", (database, canned_query)),
                ("view-database", database),
                "view-instance",
            ],
        )
        if not visible:
            raise Forbidden("You do not have permission to view this query")
    else:
        await datasette.ensure_permissions(request.actor, [("execute-sql", database)])

    # If there's no sql, show the database index page
    if not sql:
        return await database_index_view(request, datasette, db)

    validate_sql_select(sql)

    # Extract any :named parameters
    named_parameters = named_parameters or await derive_named_parameters(db, sql)
    named_parameter_values = {
        named_parameter: params.get(named_parameter) or ""
        for named_parameter in named_parameters
        if not named_parameter.startswith("_")
    }

    # Set to blank string if missing from params
    for named_parameter in named_parameters:
        if named_parameter not in params and not named_parameter.startswith("_"):
            params[named_parameter] = ""

    extra_args = {}
    if params.get("_timelimit"):
        extra_args["custom_time_limit"] = int(params["_timelimit"])
    if _size:
        extra_args["page_size"] = _size

    templates = [f"query-{to_css_class(database)}.html", "query.html"]
    if canned_query:
        templates.insert(
            0,
            f"query-{to_css_class(database)}-{to_css_class(canned_query)}.html",
        )
    query_error = None

    # Execute query - as write or as read
    if write:
        raise NotImplementedError("Write queries not yet implemented")
        # if request.method == "POST":
        #     # If database is immutable, return an error
        #     if not db.is_mutable:
        #         raise Forbidden("Database is immutable")
        #     body = await request.post_body()
        #     body = body.decode("utf-8").strip()
        #     if body.startswith("{") and body.endswith("}"):
        #         params = json.loads(body)
        #         # But we want key=value strings
        #         for key, value in params.items():
        #             params[key] = str(value)
        #     else:
        #         params = dict(parse_qsl(body, keep_blank_values=True))
        #     # Should we return JSON?
        #     should_return_json = (
        #         request.headers.get("accept") == "application/json"
        #         or request.args.get("_json")
        #         or params.get("_json")
        #     )
        #     if canned_query:
        #         params_for_query = MagicParameters(params, request, self.ds)
        #     else:
        #         params_for_query = params
        #     ok = None
        #     try:
        #         cursor = await self.ds.databases[database].execute_write(
        #             sql, params_for_query
        #         )
        #         message = metadata.get(
        #             "on_success_message"
        #         ) or "Query executed, {} row{} affected".format(
        #             cursor.rowcount, "" if cursor.rowcount == 1 else "s"
        #         )
        #         message_type = self.ds.INFO
        #         redirect_url = metadata.get("on_success_redirect")
        #         ok = True
        #     except Exception as e:
        #         message = metadata.get("on_error_message") or str(e)
        #         message_type = self.ds.ERROR
        #         redirect_url = metadata.get("on_error_redirect")
        #         ok = False
        #     if should_return_json:
        #         return Response.json(
        #             {
        #                 "ok": ok,
        #                 "message": message,
        #                 "redirect": redirect_url,
        #             }
        #         )
        #     else:
        #         self.ds.add_message(request, message, message_type)
        #         return self.redirect(request, redirect_url or request.path)
        # else:
        #     async def extra_template():
        #         return {
        #             "request": request,
        #             "db_is_immutable": not db.is_mutable,
        #             "path_with_added_args": path_with_added_args,
        #             "path_with_removed_args": path_with_removed_args,
        #             "named_parameter_values": named_parameter_values,
        #             "canned_query": canned_query,
        #             "success_message": request.args.get("_success") or "",
        #             "canned_write": True,
        #         }
        #     return (
        #         {
        #             "database": database,
        #             "rows": [],
        #             "truncated": False,
        #             "columns": [],
        #             "query": {"sql": sql, "params": params},
        #             "private": private,
        #         },
        #         extra_template,
        #         templates,
        #     )
    # Not a write
    rows = []
    if canned_query:
        params_for_query = MagicParameters(params, request, datasette)
    else:
        params_for_query = params
    try:
        results = await datasette.execute(
            database, sql, params_for_query, truncate=True, **extra_args
        )
        columns = [r[0] for r in results.description]
        rows = list(results.rows)
    except sqlite3.DatabaseError as e:
        query_error = e
        results = None
        columns = []

    allow_execute_sql = await datasette.permission_allowed(
        request.actor, "execute-sql", database
    )
    format_ = request.url_vars.get("format") or "html"

    if format_ == "csv":
        raise NotImplementedError("CSV format not yet implemented")
    elif format_ in datasette.renderers.keys():
        # Dispatch request to the correct output format renderer
        # (CSV is not handled here due to streaming)
        result = call_with_supported_arguments(
            datasette.renderers[format_][0],
            datasette=datasette,
            columns=columns,
            rows=rows,
            sql=sql,
            query_name=None,
            database=db.name,
            table=None,
            request=request,
            view_name="table",  # TODO: should this be "query"?
            # These will be deprecated in Datasette 1.0:
            args=request.args,
            data={
                "rows": rows,
            },  # TODO what should this be?
        )
        result = await await_me_maybe(result)
        if result is None:
            raise NotFound("No data")
        if isinstance(result, dict):
            r = Response(
                body=result.get("body"),
                status=result.get("status_code") or 200,
                content_type=result.get("content_type", "text/plain"),
                headers=result.get("headers"),
            )
        elif isinstance(result, Response):
            r = result
            # if status_code is not None:
            #     # Over-ride the status code
            #     r.status = status_code
        else:
            assert False, f"{result} should be dict or Response"
elif format_ == "html":
headers = {}
templates = [f"query-{to_css_class(database)}.html", "query.html"]
template = datasette.jinja_env.select_template(templates)
alternate_url_json = datasette.absolute_url(
request,
datasette.urls.path(path_with_format(request=request, format="json")),
)
headers.update(
{
"Link": '{}; rel="alternate"; type="application/json+datasette"'.format(
alternate_url_json
)
}
)
r = Response.html(
await datasette.render_template(
template,
dict(
data,
append_querystring=append_querystring,
path_with_replaced_args=path_with_replaced_args,
fix_path=datasette.urls.path,
settings=datasette.settings_dict(),
# TODO: review up all of these hacks:
alternate_url_json=alternate_url_json,
datasette_allow_facet=(
"true" if datasette.setting("allow_facet") else "false"
),
is_sortable=any(c["sortable"] for c in data["display_columns"]),
allow_execute_sql=await datasette.permission_allowed(
request.actor, "execute-sql", resolved.db.name
),
query_ms=1.2,
select_templates=[
f"{'*' if template_name == template.name else ''}{template_name}"
for template_name in templates
],
),
request=request,
view_name="table",
),
headers=headers,
)
else:
assert False, "Invalid format: {}".format(format_)
# if next_url:
# r.headers["link"] = f'<{next_url}>; rel="next"'
return r

async def database_view_impl(
    request,
    datasette,
@@ -887,10 +1298,19 @@ async def database_view_impl(
):
    db = await datasette.resolve_database(request)
    database = db.name
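    # A raw ?sql= query is handed off to the new query_view() above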
if request.args.get("sql", "").strip():
return await query_view(
request, datasette, canned_query, _size, named_parameters, write
)
# Index page shows the tables/views/canned queries for this database
params = {key: request.args.get(key) for key in request.args} params = {key: request.args.get(key) for key in request.args}
sql = "" sql = ""
if "sql" in params: if "sql" in params:
sql = params.pop("sql") sql = params.pop("sql")
_shape = None _shape = None
if "_shape" in params: if "_shape" in params:
_shape = params.pop("_shape") _shape = params.pop("_shape")

@@ -58,7 +58,7 @@ setup(
        "mergedeep>=1.1.1",
        "itsdangerous>=1.1",
        "sqlite-utils>=3.30",
-       "asyncinject>=0.5",
+       "asyncinject>=0.6",
    ],
    entry_points="""
    [console_scripts]

@@ -1,6 +1,7 @@
from datasette.cli import cli, serve
from datasette.plugins import pm
from click.testing import CliRunner
from unittest.mock import ANY
import textwrap
import json
@@ -35,11 +36,11 @@ def test_serve_with_get(tmp_path_factory):
        ],
    )
    assert 0 == result.exit_code, result.output
-   assert {
-       "database": "_memory",
-       "truncated": False,
-       "columns": ["sqlite_version()"],
-   }.items() <= json.loads(result.output).items()
+   assert json.loads(result.output) == {
+       "ok": True,
+       "rows": [{"sqlite_version()": ANY}],
+       "truncated": False,
+   }
    # The plugin should have created hello.txt
    assert (plugins_dir / "hello.txt").read_text() == "hello"