From 026429fadd7a1f4f85ebfda1bbfe882938f455f2 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 26 Apr 2023 20:47:03 -0700 Subject: [PATCH] Work in progress on query view, refs #2049 --- datasette/views/database.py | 420 ++++++++++++++++++++++++++++++++++++ setup.py | 2 +- tests/test_cli_serve_get.py | 9 +- 3 files changed, 426 insertions(+), 5 deletions(-) diff --git a/datasette/views/database.py b/datasette/views/database.py index d097c933..33b6702b 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -1,3 +1,4 @@ +from asyncinject import Registry import os import hashlib import itertools @@ -877,6 +878,416 @@ async def database_index_view(request, datasette, db): ) +async def query_view( + request, + datasette, + canned_query=None, + _size=None, + named_parameters=None, + write=False, +): + print("query_view") + db = await datasette.resolve_database(request) + database = db.name + # TODO: Why do I do this? Is it to eliminate multi-args? + # It's going to break ?_extra=...&_extra=... + params = {key: request.args.get(key) for key in request.args} + sql = "" + if "sql" in params: + sql = params.pop("sql") + + # TODO: Behave differently for canned query here: + await datasette.ensure_permissions(request.actor, [("execute-sql", database)]) + + _shape = None + if "_shape" in params: + _shape = params.pop("_shape") + + # ?_shape=arrays - "rows" is the default option, shown above + # ?_shape=objects - "rows" is a list of JSON key/value objects + # ?_shape=array - an JSON array of objects + # ?_shape=array&_nl=on - a newline-separated list of JSON objects + # ?_shape=arrayfirst - a flat JSON array containing just the first value from each row + # ?_shape=object - a JSON object keyed using the primary keys of the rows + async def _results(_sql, _params): + # Returns (results, error (can be None)) + try: + return await db.execute(_sql, _params, truncate=True), None + except Exception as e: + return None, e + + async def shape_arrays(_results): + results, error = _results + if error: + return {"ok": False, "error": str(error)} + return { + "ok": True, + "rows": [list(r) for r in results.rows], + "truncated": results.truncated, + } + + async def shape_objects(_results): + results, error = _results + if error: + return {"ok": False, "error": str(error)} + return { + "ok": True, + "rows": [dict(r) for r in results.rows], + "truncated": results.truncated, + } + + async def shape_array(_results): + results, error = _results + if error: + return {"ok": False, "error": str(error)} + return [dict(r) for r in results.rows] + + shape_fn = { + "arrays": shape_arrays, + "objects": shape_objects, + "array": shape_array, + # "arrayfirst": shape_arrayfirst, + # "object": shape_object, + }[_shape or "objects"] + + registry = Registry.from_dict( + { + "_results": _results, + "_shape": shape_fn, + }, + parallel=False, + ) + + results = await registry.resolve_multi( + ["_shape"], + results={ + "_sql": sql, + "_params": params, + }, + ) + + # If "shape" does not include "rows" we return that as the response + if "rows" not in results["_shape"]: + return Response.json(results["_shape"]) + + output = results["_shape"] + output.update(dict((k, v) for k, v in results.items() if not k.startswith("_"))) + + response = Response.json(output) + + assert False + + import pdb + + pdb.set_trace() + + if isinstance(output, dict) and output.get("ok") is False: + # TODO: Other error codes? + + response.status_code = 400 + + if datasette.cors: + add_cors_headers(response.headers) + + return response + + # registry = Registry( + # extra_count, + # extra_facet_results, + # extra_facets_timed_out, + # extra_suggested_facets, + # facet_instances, + # extra_human_description_en, + # extra_next_url, + # extra_columns, + # extra_primary_keys, + # run_display_columns_and_rows, + # extra_display_columns, + # extra_display_rows, + # extra_debug, + # extra_request, + # extra_query, + # extra_metadata, + # extra_extras, + # extra_database, + # extra_table, + # extra_database_color, + # extra_table_actions, + # extra_filters, + # extra_renderers, + # extra_custom_table_templates, + # extra_sorted_facet_results, + # extra_table_definition, + # extra_view_definition, + # extra_is_view, + # extra_private, + # extra_expandable_columns, + # extra_form_hidden_args, + # ) + + results = await registry.resolve_multi( + ["extra_{}".format(extra) for extra in extras] + ) + data = { + "ok": True, + "next": next_value and str(next_value) or None, + } + data.update( + { + key.replace("extra_", ""): value + for key, value in results.items() + if key.startswith("extra_") and key.replace("extra_", "") in extras + } + ) + raw_sqlite_rows = rows[:page_size] + data["rows"] = [dict(r) for r in raw_sqlite_rows] + + private = False + if canned_query: + # Respect canned query permissions + visible, private = await datasette.check_visibility( + request.actor, + permissions=[ + ("view-query", (database, canned_query)), + ("view-database", database), + "view-instance", + ], + ) + if not visible: + raise Forbidden("You do not have permission to view this query") + + else: + await datasette.ensure_permissions(request.actor, [("execute-sql", database)]) + + # If there's no sql, show the database index page + if not sql: + return await database_index_view(request, datasette, db) + + validate_sql_select(sql) + + # Extract any :named parameters + named_parameters = named_parameters or await derive_named_parameters(db, sql) + named_parameter_values = { + named_parameter: params.get(named_parameter) or "" + for named_parameter in named_parameters + if not named_parameter.startswith("_") + } + + # Set to blank string if missing from params + for named_parameter in named_parameters: + if named_parameter not in params and not named_parameter.startswith("_"): + params[named_parameter] = "" + + extra_args = {} + if params.get("_timelimit"): + extra_args["custom_time_limit"] = int(params["_timelimit"]) + if _size: + extra_args["page_size"] = _size + + templates = [f"query-{to_css_class(database)}.html", "query.html"] + if canned_query: + templates.insert( + 0, + f"query-{to_css_class(database)}-{to_css_class(canned_query)}.html", + ) + + query_error = None + + # Execute query - as write or as read + if write: + raise NotImplementedError("Write queries not yet implemented") + # if request.method == "POST": + # # If database is immutable, return an error + # if not db.is_mutable: + # raise Forbidden("Database is immutable") + # body = await request.post_body() + # body = body.decode("utf-8").strip() + # if body.startswith("{") and body.endswith("}"): + # params = json.loads(body) + # # But we want key=value strings + # for key, value in params.items(): + # params[key] = str(value) + # else: + # params = dict(parse_qsl(body, keep_blank_values=True)) + # # Should we return JSON? + # should_return_json = ( + # request.headers.get("accept") == "application/json" + # or request.args.get("_json") + # or params.get("_json") + # ) + # if canned_query: + # params_for_query = MagicParameters(params, request, self.ds) + # else: + # params_for_query = params + # ok = None + # try: + # cursor = await self.ds.databases[database].execute_write( + # sql, params_for_query + # ) + # message = metadata.get( + # "on_success_message" + # ) or "Query executed, {} row{} affected".format( + # cursor.rowcount, "" if cursor.rowcount == 1 else "s" + # ) + # message_type = self.ds.INFO + # redirect_url = metadata.get("on_success_redirect") + # ok = True + # except Exception as e: + # message = metadata.get("on_error_message") or str(e) + # message_type = self.ds.ERROR + # redirect_url = metadata.get("on_error_redirect") + # ok = False + # if should_return_json: + # return Response.json( + # { + # "ok": ok, + # "message": message, + # "redirect": redirect_url, + # } + # ) + # else: + # self.ds.add_message(request, message, message_type) + # return self.redirect(request, redirect_url or request.path) + # else: + + # async def extra_template(): + # return { + # "request": request, + # "db_is_immutable": not db.is_mutable, + # "path_with_added_args": path_with_added_args, + # "path_with_removed_args": path_with_removed_args, + # "named_parameter_values": named_parameter_values, + # "canned_query": canned_query, + # "success_message": request.args.get("_success") or "", + # "canned_write": True, + # } + + # return ( + # { + # "database": database, + # "rows": [], + # "truncated": False, + # "columns": [], + # "query": {"sql": sql, "params": params}, + # "private": private, + # }, + # extra_template, + # templates, + # ) + + # Not a write + rows = [] + if canned_query: + params_for_query = MagicParameters(params, request, datasette) + else: + params_for_query = params + try: + results = await datasette.execute( + database, sql, params_for_query, truncate=True, **extra_args + ) + columns = [r[0] for r in results.description] + rows = list(results.rows) + except sqlite3.DatabaseError as e: + query_error = e + results = None + columns = [] + + allow_execute_sql = await datasette.permission_allowed( + request.actor, "execute-sql", database + ) + + format_ = request.url_vars.get("format") or "html" + + if format_ == "csv": + raise NotImplementedError("CSV format not yet implemented") + elif format_ in datasette.renderers.keys(): + # Dispatch request to the correct output format renderer + # (CSV is not handled here due to streaming) + result = call_with_supported_arguments( + datasette.renderers[format_][0], + datasette=datasette, + columns=columns, + rows=rows, + sql=sql, + query_name=None, + database=db.name, + table=None, + request=request, + view_name="table", # TODO: should this be "query"? + # These will be deprecated in Datasette 1.0: + args=request.args, + data={ + "rows": rows, + }, # TODO what should this be? + ) + result = await await_me_maybe(result) + if result is None: + raise NotFound("No data") + if isinstance(result, dict): + r = Response( + body=result.get("body"), + status=result.get("status_code") or 200, + content_type=result.get("content_type", "text/plain"), + headers=result.get("headers"), + ) + elif isinstance(result, Response): + r = result + # if status_code is not None: + # # Over-ride the status code + # r.status = status_code + else: + assert False, f"{result} should be dict or Response" + elif format_ == "html": + headers = {} + templates = [f"query-{to_css_class(database)}.html", "query.html"] + template = datasette.jinja_env.select_template(templates) + alternate_url_json = datasette.absolute_url( + request, + datasette.urls.path(path_with_format(request=request, format="json")), + ) + headers.update( + { + "Link": '{}; rel="alternate"; type="application/json+datasette"'.format( + alternate_url_json + ) + } + ) + r = Response.html( + await datasette.render_template( + template, + dict( + data, + append_querystring=append_querystring, + path_with_replaced_args=path_with_replaced_args, + fix_path=datasette.urls.path, + settings=datasette.settings_dict(), + # TODO: review up all of these hacks: + alternate_url_json=alternate_url_json, + datasette_allow_facet=( + "true" if datasette.setting("allow_facet") else "false" + ), + is_sortable=any(c["sortable"] for c in data["display_columns"]), + allow_execute_sql=await datasette.permission_allowed( + request.actor, "execute-sql", resolved.db.name + ), + query_ms=1.2, + select_templates=[ + f"{'*' if template_name == template.name else ''}{template_name}" + for template_name in templates + ], + ), + request=request, + view_name="table", + ), + headers=headers, + ) + else: + assert False, "Invalid format: {}".format(format_) + # if next_url: + # r.headers["link"] = f'<{next_url}>; rel="next"' + return r + + async def database_view_impl( request, datasette, @@ -887,10 +1298,19 @@ async def database_view_impl( ): db = await datasette.resolve_database(request) database = db.name + + if request.args.get("sql", "").strip(): + return await query_view( + request, datasette, canned_query, _size, named_parameters, write + ) + + # Index page shows the tables/views/canned queries for this database + params = {key: request.args.get(key) for key in request.args} sql = "" if "sql" in params: sql = params.pop("sql") + _shape = None if "_shape" in params: _shape = params.pop("_shape") diff --git a/setup.py b/setup.py index d41e428a..b591869e 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ setup( "mergedeep>=1.1.1", "itsdangerous>=1.1", "sqlite-utils>=3.30", - "asyncinject>=0.5", + "asyncinject>=0.6", ], entry_points=""" [console_scripts] diff --git a/tests/test_cli_serve_get.py b/tests/test_cli_serve_get.py index ac44e1e2..e484a6db 100644 --- a/tests/test_cli_serve_get.py +++ b/tests/test_cli_serve_get.py @@ -1,6 +1,7 @@ from datasette.cli import cli, serve from datasette.plugins import pm from click.testing import CliRunner +from unittest.mock import ANY import textwrap import json @@ -35,11 +36,11 @@ def test_serve_with_get(tmp_path_factory): ], ) assert 0 == result.exit_code, result.output - assert { - "database": "_memory", + assert json.loads(result.output) == { + "ok": True, + "rows": [{"sqlite_version()": ANY}], "truncated": False, - "columns": ["sqlite_version()"], - }.items() <= json.loads(result.output).items() + } # The plugin should have created hello.txt assert (plugins_dir / "hello.txt").read_text() == "hello"