diff --git a/datasette/app.py b/datasette/app.py index dd54446a..96683895 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -47,10 +47,11 @@ from .views import Context from .views.database import ( database_download, DatabaseView, - ExecuteWriteAnalyzeView, - ExecuteWriteView, TableCreateView, QueryView, +) +from .views.execute_write import ExecuteWriteAnalyzeView, ExecuteWriteView +from .views.stored_queries import ( QueryCreateAnalyzeView, QueryDeleteView, QueryDefinitionView, diff --git a/datasette/views/database.py b/datasette/views/database.py index 28a3b579..f11a7d16 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -12,16 +12,14 @@ import textwrap from datasette.events import AlterTableEvent, CreateTableEvent, InsertRowsEvent from datasette.database import QueryInterrupted -from datasette.resources import DatabaseResource, QueryResource, TableResource +from datasette.resources import DatabaseResource, QueryResource from datasette.utils import ( add_cors_headers, await_me_maybe, call_with_supported_arguments, named_parameters as derive_named_parameters, - escape_sqlite, format_bytes, make_slot_function, - path_from_row_pks, tilde_decode, to_css_class, validate_sql_select, @@ -37,6 +35,7 @@ from datasette.utils.asgi import AsgiFileDownload, NotFound, Response, Forbidden from datasette.plugins import pm from .base import BaseView, DatasetteError, View, _error, stream_csv +from .query_helpers import _ensure_stored_query_execution_permissions, _table_columns from . import Context @@ -425,1220 +424,6 @@ async def database_download(request, datasette): ) -_query_name_re = re.compile(r"^[^/\.\n]+$") - -_query_fields = { - "sql", - "title", - "description", - "description_html", - "hide_sql", - "fragment", - "parameters", - "params", - "is_private", - "on_success_message", - "on_success_message_sql", - "on_success_redirect", - "on_error_message", - "on_error_redirect", -} - -_query_create_fields = _query_fields | {"name", "mode", "csrftoken"} -_query_update_fields = _query_fields -_query_write_fields = { - "on_success_message", - "on_success_message_sql", - "on_success_redirect", - "on_error_message", - "on_error_redirect", -} - - -class QueryValidationError(Exception): - def __init__(self, message, status=400): - self.message = message - self.status = status - - -def _actor_id(actor): - if isinstance(actor, dict): - return actor.get("id") - return None - - -def _as_bool(value): - if isinstance(value, bool): - return value - if value is None: - return False - if isinstance(value, int): - return bool(value) - if isinstance(value, str): - return value.lower() in {"1", "true", "t", "yes", "on"} - return bool(value) - - -def _as_optional_bool(value, name): - if value is None or value == "": - return None - if isinstance(value, bool): - return value - if isinstance(value, int): - return bool(value) - if isinstance(value, str): - lowered = value.lower() - if lowered in {"1", "true", "t", "yes", "on"}: - return True - if lowered in {"0", "false", "f", "no", "off"}: - return False - raise QueryValidationError("{} must be 0 or 1".format(name)) - - -def _query_list_limit(value, default=50): - if value in (None, ""): - return default - try: - return min(max(1, int(value)), 1000) - except ValueError as ex: - raise QueryValidationError("_size must be an integer") from ex - - -def _derived_query_parameters(sql): - parameters = [] - seen = set() - for parameter in derive_named_parameters(sql): - if parameter.startswith("_"): - raise QueryValidationError("Magic parameters are not allowed") - if parameter not in seen: - parameters.append(parameter) - seen.add(parameter) - return parameters - - -def _coerce_query_parameters(value, derived): - if value is None: - return derived - if isinstance(value, str): - parameters = [ - parameter.strip() - for parameter in re.split(r"[\s,]+", value) - if parameter.strip() - ] - elif isinstance(value, list): - parameters = value - else: - raise QueryValidationError("parameters must be a list of strings") - if not all(isinstance(parameter, str) for parameter in parameters): - raise QueryValidationError("parameters must be a list of strings") - if any(parameter.startswith("_") for parameter in parameters): - raise QueryValidationError("Magic parameters are not allowed") - if set(parameters) != set(derived): - raise QueryValidationError("parameters must match SQL named parameters") - return parameters - - -def _analysis_is_write(analysis): - return any( - access.operation in {"insert", "update", "delete"} - for access in analysis.table_accesses - ) - - -def _block_framing(response): - response.headers["Content-Security-Policy"] = "frame-ancestors 'none'" - response.headers["X-Frame-Options"] = "DENY" - return response - - -def _wants_json(request, is_json, data): - return ( - is_json - or request.headers.get("accept") == "application/json" - or (isinstance(data, dict) and data.get("_json")) - ) - - -def _query_create_form_error_message(message): - return { - "Query name is required": "URL is required", - "Invalid query name": "Invalid URL", - "Query name conflicts with a table or view": ( - "URL conflicts with an existing table or view" - ), - "Query already exists": "A query already exists at that URL", - }.get(message, message) - - -async def _json_or_form_payload(request): - content_type = request.headers.get("content-type", "") - if content_type.startswith("application/json"): - body = await request.post_body() - try: - return json.loads(body or b"{}"), True - except json.JSONDecodeError as e: - raise QueryValidationError("Invalid JSON: {}".format(e)) - return await request.post_vars(), False - - -async def _check_query_name(db, name, *, existing=False): - if not name or not isinstance(name, str): - raise QueryValidationError("Query name is required") - if not _query_name_re.match(name): - raise QueryValidationError("Invalid query name") - if not existing and (await db.table_exists(name) or await db.view_exists(name)): - raise QueryValidationError("Query name conflicts with a table or view") - - -async def _analyze_user_query(datasette, db, sql, *, actor): - if not sql or not isinstance(sql, str): - raise QueryValidationError("SQL is required") - derived = _derived_query_parameters(sql) - params = {parameter: "" for parameter in derived} - try: - analysis = await db.analyze_sql(sql, params) - except sqlite3.DatabaseError as ex: - raise QueryValidationError("Could not analyze query: {}".format(ex)) from ex - - is_write = _analysis_is_write(analysis) - if is_write: - try: - await datasette.ensure_query_write_permissions( - db.name, sql, actor=actor, analysis=analysis - ) - except Forbidden as ex: - raise QueryValidationError(str(ex), status=403) from ex - else: - try: - validate_sql_select(sql) - except InvalidSql as ex: - raise QueryValidationError(str(ex)) from ex - return is_write, derived, analysis - - -def _analysis_rows(analysis): - write_actions = { - "insert": "insert-row", - "update": "update-row", - "delete": "delete-row", - } - return [ - { - "operation": access.operation, - "database": access.database, - "table": access.table, - "required_permission": write_actions.get(access.operation, ""), - "source": access.source, - } - for access in analysis.table_accesses - ] - - -async def _analysis_rows_with_permissions(datasette, analysis, actor): - rows = _analysis_rows(analysis) - for row in rows: - permission = row["required_permission"] - if permission: - row["allowed"] = await datasette.allowed( - action=permission, - resource=TableResource(row["database"], row["table"]), - actor=actor, - ) - else: - row["allowed"] = None - return rows - - -def _coerce_execute_write_payload(data, is_json): - if not isinstance(data, dict): - raise QueryValidationError("JSON must be a dictionary") - if is_json: - invalid_keys = set(data) - {"sql", "params"} - if invalid_keys: - raise QueryValidationError( - "Invalid keys: {}".format(", ".join(sorted(invalid_keys))) - ) - params = data.get("params") or {} - else: - params = { - key: value - for key, value in data.items() - if key not in {"sql", "csrftoken", "_json"} - } - if not isinstance(params, dict): - raise QueryValidationError("params must be a dictionary") - return data.get("sql"), params - - -async def _prepare_execute_write(datasette, db, sql, params, actor): - if not sql or not isinstance(sql, str): - raise QueryValidationError("SQL is required") - parameter_names = _derived_query_parameters(sql) - extra_params = set(params) - set(parameter_names) - if extra_params: - raise QueryValidationError( - "Unknown parameters: {}".format(", ".join(sorted(extra_params))) - ) - params = {name: params.get(name, "") for name in parameter_names} - try: - analysis = await db.analyze_sql(sql, params) - except sqlite3.DatabaseError as ex: - raise QueryValidationError("Could not analyze query: {}".format(ex)) from ex - if not _analysis_is_write(analysis): - raise QueryValidationError( - "Use /-/query for read-only SQL; this endpoint only executes writes" - ) - try: - await datasette.ensure_query_write_permissions( - db.name, sql, actor=actor, analysis=analysis - ) - except Forbidden as ex: - raise QueryValidationError(str(ex), status=403) from ex - return parameter_names, params, analysis - - -async def _ensure_stored_query_execution_permissions(datasette, db, query, actor): - if query.get("is_trusted"): - return - if query.get("write"): - await datasette.ensure_permission( - action="execute-write-sql", - resource=DatabaseResource(db.name), - actor=actor, - ) - await datasette.ensure_query_write_permissions( - db.name, query["sql"], actor=actor - ) - else: - await datasette.ensure_permission( - action="execute-sql", - resource=DatabaseResource(db.name), - actor=actor, - ) - - -async def _execute_write_analysis_data(datasette, db, sql, actor): - parameter_names = [] - analysis_rows = [] - analysis_error = None - if sql: - try: - parameter_names = _derived_query_parameters(sql) - params = {parameter: "" for parameter in parameter_names} - analysis = await db.analyze_sql(sql, params) - if _analysis_is_write(analysis): - analysis_rows = await _analysis_rows_with_permissions( - datasette, analysis, actor - ) - else: - analysis_error = ( - "Use /-/query for read-only SQL; " - "this endpoint only executes writes" - ) - except (QueryValidationError, sqlite3.DatabaseError) as ex: - analysis_error = getattr(ex, "message", str(ex)) - return { - "ok": analysis_error is None, - "parameters": parameter_names, - "analysis_error": analysis_error, - "analysis_rows": [row for row in analysis_rows if row["operation"] != "read"], - "execute_disabled": bool( - (not sql) - or analysis_error - or any(row["allowed"] is False for row in analysis_rows) - ), - } - - -async def _query_create_analysis_data(datasette, db, sql, actor): - has_sql = bool(sql and sql.strip()) - parameter_names = [] - analysis_rows = [] - analysis_error = None - if has_sql: - try: - parameter_names = _derived_query_parameters(sql) - params = {parameter: "" for parameter in parameter_names} - analysis = await db.analyze_sql(sql, params) - analysis_rows = await _analysis_rows_with_permissions( - datasette, analysis, actor - ) - except (QueryValidationError, sqlite3.DatabaseError) as ex: - analysis_error = getattr(ex, "message", str(ex)) - return { - "ok": analysis_error is None, - "parameters": parameter_names, - "analysis_error": analysis_error, - "analysis_rows": analysis_rows, - "has_sql": has_sql, - "analysis_is_write": bool( - analysis_rows and any(row["required_permission"] for row in analysis_rows) - ), - "save_disabled": bool( - (not has_sql) - or analysis_error - or any(row["allowed"] is False for row in analysis_rows) - ), - } - - -async def _query_create_form_context( - datasette, - request, - db, - *, - sql="", - name="", - title="", - description="", - is_private=True, -): - analysis_data = await _query_create_analysis_data(datasette, db, sql, request.actor) - return { - "database": db.name, - "database_color": db.color, - "sql": sql, - "name": name, - "title": title, - "description": description, - "is_private": is_private, - **analysis_data, - } - - -async def _inserted_row_url(datasette, db, analysis, cursor): - if cursor.rowcount != 1: - return None - lastrowid = getattr(cursor, "lastrowid", None) - if lastrowid is None: - return None - direct_inserts = [ - access - for access in analysis.table_accesses - if access.operation == "insert" - and access.source is None - and access.database == db.name - ] - if len(direct_inserts) != 1: - return None - table = direct_inserts[0].table - pks = await db.primary_keys(table) - use_rowid = not pks - select = ( - "rowid" - if use_rowid - else ", ".join(escape_sqlite(primary_key) for primary_key in pks) - ) - try: - result = await db.execute( - "select {} from {} where rowid = ?".format(select, escape_sqlite(table)), - [lastrowid], - ) - except sqlite3.DatabaseError: - return None - row = result.first() - if row is None: - return None - row_path = path_from_row_pks(row, pks, use_rowid) - return datasette.urls.row(db.name, table, row_path) - - -def _apply_query_data_types(data): - typed = dict(data) - for key in ("hide_sql", "is_private"): - if key in typed: - typed[key] = _as_bool(typed[key]) - return typed - - -async def _prepare_query_create(datasette, request, db, data): - invalid_keys = set(data) - _query_create_fields - if invalid_keys: - raise QueryValidationError("Invalid keys: {}".format(", ".join(invalid_keys))) - - data = _apply_query_data_types(data) - name = data.get("name") - await _check_query_name(db, name) - if await datasette.get_query(db.name, name) is not None: - raise QueryValidationError("Query already exists") - - is_write, derived, analysis = await _analyze_user_query( - datasette, - db, - data.get("sql"), - actor=request.actor, - ) - if not is_write and any(data.get(field) for field in _query_write_fields): - raise QueryValidationError("Writable query fields require writable SQL") - - parameters = _coerce_query_parameters( - data.get("parameters", data.get("params")), - derived, - ) - return { - "name": name, - "sql": data["sql"], - "title": data.get("title"), - "description": data.get("description"), - "description_html": data.get("description_html"), - "hide_sql": _as_bool(data.get("hide_sql")), - "fragment": data.get("fragment"), - "parameters": parameters, - "is_write": is_write, - "is_private": _as_bool(data.get("is_private", True)), - "is_trusted": False, - "source": "user", - "owner_id": _actor_id(request.actor), - "on_success_message": data.get("on_success_message"), - "on_success_message_sql": data.get("on_success_message_sql"), - "on_success_redirect": data.get("on_success_redirect"), - "on_error_message": data.get("on_error_message"), - "on_error_redirect": data.get("on_error_redirect"), - "analysis": analysis, - } - - -async def _prepare_query_update(datasette, request, db, existing, update): - invalid_keys = set(update) - _query_update_fields - if invalid_keys: - raise QueryValidationError("Invalid keys: {}".format(", ".join(invalid_keys))) - - update = _apply_query_data_types(update) - sql = update.get("sql", existing["sql"]) - query_is_write = existing["is_write"] - derived = _derived_query_parameters(sql) - parameters = None - - if "sql" in update: - query_is_write, derived, _ = await _analyze_user_query( - datasette, - db, - sql, - actor=request.actor, - ) - - if "parameters" in update or "params" in update: - parameters = _coerce_query_parameters( - update.get("parameters", update.get("params")), - derived, - ) - elif "sql" in update: - parameters = derived - - if not query_is_write and any(update.get(field) for field in _query_write_fields): - raise QueryValidationError("Writable query fields require writable SQL") - - field_values = { - "sql": sql, - "title": update.get("title"), - "description": update.get("description"), - "description_html": update.get("description_html"), - "hide_sql": update.get("hide_sql"), - "fragment": update.get("fragment"), - "parameters": parameters, - "is_write": query_is_write, - "is_private": update.get("is_private"), - "on_success_message": update.get("on_success_message"), - "on_success_message_sql": update.get("on_success_message_sql"), - "on_success_redirect": update.get("on_success_redirect"), - "on_error_message": update.get("on_error_message"), - "on_error_redirect": update.get("on_error_redirect"), - } - update_kwargs = {} - for field_name, value in field_values.items(): - if field_name in update: - update_kwargs[field_name] = value - if parameters is not None: - update_kwargs["parameters"] = parameters - if "sql" in update: - update_kwargs["is_write"] = query_is_write - return update_kwargs - - -class ExecuteWriteView(BaseView): - name = "execute-write" - has_json_alternate = False - - async def _render_form( - self, - request, - db, - *, - sql="", - parameter_values=None, - analysis=None, - analysis_error=None, - execution_message=None, - execution_links=None, - execution_ok=None, - status=200, - ): - parameter_values = parameter_values or {} - execution_links = execution_links or [] - parameter_names = [] - analysis_rows = [] - table_columns = await _table_columns(self.ds, db.name) - hidden_table_names = set(await db.hidden_table_names()) - write_template_tables = { - table: columns - for table, columns in table_columns.items() - if columns and table not in hidden_table_names - } - if sql and analysis_error is None: - try: - parameter_names = _derived_query_parameters(sql) - if analysis is None: - params = {parameter: "" for parameter in parameter_names} - analysis = await db.analyze_sql(sql, params) - if _analysis_is_write(analysis): - analysis_rows = await _analysis_rows_with_permissions( - self.ds, analysis, request.actor - ) - else: - analysis_error = ( - "Use /-/query for read-only SQL; " - "this endpoint only executes writes" - ) - except (QueryValidationError, sqlite3.DatabaseError) as ex: - analysis_error = getattr(ex, "message", str(ex)) - - allow_save_query = await self.ds.allowed( - action="execute-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ) and await self.ds.allowed( - action="store-query", - resource=DatabaseResource(db.name), - actor=request.actor, - ) - save_query_base_url = None - save_query_url = None - if allow_save_query: - save_query_base_url = self.ds.urls.database(db.name) + "/-/queries/store" - if ( - sql - and analysis_error is None - and not any(row["allowed"] is False for row in analysis_rows) - ): - save_query_url = save_query_base_url + "?" + urlencode({"sql": sql}) - - response = await self.render( - ["execute_write.html"], - request, - { - "database": db.name, - "database_color": db.color, - "sql": sql, - "parameter_names": parameter_names, - "parameter_values": parameter_values, - "analysis_error": analysis_error, - "analysis_rows": [ - row for row in analysis_rows if row["operation"] != "read" - ], - "execution_message": execution_message, - "execution_links": execution_links, - "execution_ok": execution_ok, - "execute_disabled": bool( - (not sql) - or analysis_error - or any(row["allowed"] is False for row in analysis_rows) - ), - "table_columns": table_columns, - "write_template_tables": write_template_tables, - "save_query_url": save_query_url, - "save_query_base_url": save_query_base_url, - }, - ) - response.status = status - return _block_framing(response) - - async def get(self, request): - db = await self.ds.resolve_database(request) - await self.ds.ensure_permission( - action="execute-write-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ) - if not db.is_mutable: - return _block_framing( - _error( - ["Cannot execute write SQL because this database is immutable."], - 403, - ) - ) - return await self._render_form( - request, - db, - sql=request.args.get("sql") or "", - ) - - async def post(self, request): - db = await self.ds.resolve_database(request) - if not await self.ds.allowed( - action="execute-write-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ): - return _block_framing( - _error(["Permission denied: need execute-write-sql"], 403) - ) - if not db.is_mutable: - return _block_framing(_error(["Database is immutable"], 403)) - - data = {} - is_json = request.headers.get("content-type", "").startswith("application/json") - sql = "" - provided_params = {} - try: - data, is_json = await _json_or_form_payload(request) - sql, provided_params = _coerce_execute_write_payload(data, is_json) - parameter_names, params, analysis = await _prepare_execute_write( - self.ds, db, sql, provided_params, request.actor - ) - except QueryValidationError as ex: - if _wants_json(request, is_json, data): - return _block_framing(_error([ex.message], ex.status)) - return await self._render_form( - request, - db, - sql=sql or "", - parameter_values=provided_params, - analysis_error=ex.message, - execution_message=ex.message, - execution_ok=False, - status=ex.status, - ) - - try: - cursor = await db.execute_write(sql, params, request=request) - except sqlite3.DatabaseError as ex: - message = str(ex) - if _wants_json(request, is_json, data): - return _block_framing(_error([message], 400)) - return await self._render_form( - request, - db, - sql=sql, - parameter_values=params, - analysis=analysis, - execution_message=message, - execution_ok=False, - status=400, - ) - - message = "Query executed, {} row{} affected".format( - cursor.rowcount, "" if cursor.rowcount == 1 else "s" - ) - if _wants_json(request, is_json, data): - return _block_framing( - Response.json( - { - "ok": True, - "message": message, - "rowcount": cursor.rowcount, - "analysis": _analysis_rows(analysis), - } - ) - ) - - inserted_row_url = await _inserted_row_url(self.ds, db, analysis, cursor) - execution_links = ( - [{"href": inserted_row_url, "label": "View row"}] - if inserted_row_url - else [] - ) - return await self._render_form( - request, - db, - sql=sql, - parameter_values={name: params.get(name, "") for name in parameter_names}, - analysis=analysis, - execution_message=message, - execution_links=execution_links, - execution_ok=True, - ) - - -class ExecuteWriteAnalyzeView(BaseView): - name = "execute-write-analyze" - has_json_alternate = False - - async def get(self, request): - db = await self.ds.resolve_database(request) - if not await self.ds.allowed( - action="execute-write-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ): - return _block_framing( - _error(["Permission denied: need execute-write-sql"], 403) - ) - - invalid_keys = set(request.args) - {"sql"} - if invalid_keys: - return _block_framing( - _error( - ["Invalid keys: {}".format(", ".join(sorted(invalid_keys)))], - 400, - ) - ) - sql = request.args.get("sql") or "" - return _block_framing( - Response.json( - await _execute_write_analysis_data(self.ds, db, sql, request.actor) - ) - ) - - -class QueryParametersView(BaseView): - name = "query-parameters" - has_json_alternate = False - - async def get(self, request): - db = await self.ds.resolve_database(request) - if not await self.ds.allowed( - action="execute-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ): - return _block_framing(_error(["Permission denied: need execute-sql"], 403)) - - invalid_keys = set(request.args) - {"sql"} - if invalid_keys: - return _block_framing( - _error( - ["Invalid keys: {}".format(", ".join(sorted(invalid_keys)))], - 400, - ) - ) - try: - parameters = _derived_query_parameters(request.args.get("sql") or "") - except QueryValidationError as ex: - return _block_framing(_error([ex.message], ex.status)) - return _block_framing(Response.json({"ok": True, "parameters": parameters})) - - -def _query_list_url(path, query_string, *, set_args=None, remove_args=None): - set_args = set_args or {} - remove_args = set(remove_args or ()) - skip = set(set_args) | remove_args | {"_next"} - pairs = [ - (key, value) - for key, value in parse_qsl(query_string, keep_blank_values=True) - if key not in skip - ] - for key, value in set_args.items(): - if value not in (None, ""): - pairs.append((key, value)) - return path + (("?" + urlencode(pairs)) if pairs else "") - - -class QueryListView(BaseView): - name = "query-list" - - async def database_name(self, request): - return (await self.ds.resolve_database(request)).name - - def query_list_path(self, database): - return self.ds.urls.database(database) + "/-/queries" - - async def get(self, request): - database = await self.database_name(request) - format_ = request.url_vars.get("format") or "html" - try: - limit = _query_list_limit( - request.args.get("_size"), - default=20 if format_ == "html" else 50, - ) - is_write = _as_optional_bool(request.args.get("is_write"), "is_write") - is_private = _as_optional_bool(request.args.get("is_private"), "is_private") - except QueryValidationError as ex: - return _error([ex.message], ex.status) - - page = await self.ds.list_queries( - database, - actor=request.actor, - limit=limit, - cursor=request.args.get("_next"), - q=request.args.get("q") or None, - is_write=is_write, - is_private=is_private, - source=request.args.get("source") or None, - owner_id=request.args.get("owner_id") or None, - include_private=True, - ) - query_list_path = self.query_list_path(database) - next_url = None - if page["next"]: - pairs = [ - (key, value) - for key, value in parse_qsl( - request.query_string, keep_blank_values=True - ) - if key != "_next" - ] - pairs.append(("_next", page["next"])) - next_url = "{}?{}".format( - query_list_path, - urlencode(pairs), - ) - - current_filters = { - "actor": request.actor, - "q": request.args.get("q") or None, - "is_write": is_write, - "is_private": is_private, - "source": request.args.get("source") or None, - "owner_id": request.args.get("owner_id") or None, - } - - async def facet_count(field, value): - if current_filters[field] is not None and current_filters[field] != value: - return 0 - filters = dict(current_filters) - filters[field] = value - return await self.ds.count_queries(database, **filters) - - def facet_href(field, value): - if current_filters[field] == value: - return _query_list_url( - query_list_path, - request.query_string, - remove_args=[field], - ) - if current_filters[field] is not None: - return None - return _query_list_url( - query_list_path, - request.query_string, - set_args={field: str(int(value))}, - ) - - async def facet_item(label, field, value): - count = await facet_count(field, value) - active = current_filters[field] == value - if not active and not count: - return None - return { - "label": label, - "count": count, - "href": facet_href(field, value) if active or count else None, - "active": active, - } - - async def facet_items(items): - return [ - item - for item in [ - await facet_item(label, field, value) - for label, field, value in items - ] - if item is not None - ] - - facets = [ - { - "title": "Mode", - "items": await facet_items( - [ - ("Read-only", "is_write", False), - ("Writable", "is_write", True), - ] - ), - }, - { - "title": "Visibility", - "items": await facet_items( - [ - ("Not private", "is_private", False), - ("Private", "is_private", True), - ] - ), - }, - ] - - data = { - "ok": True, - "database": database, - "database_color": ( - self.ds.get_database(database).color if database is not None else None - ), - "queries": page["queries"], - "next": page["next"], - "next_url": next_url, - "has_more": page["has_more"], - "limit": page["limit"], - "show_private_note": any(query["is_private"] for query in page["queries"]), - "show_trusted_note": any(query["is_trusted"] for query in page["queries"]), - "query_list_path": query_list_path, - "show_database": database is None, - "facets": facets, - "filters": { - "q": request.args.get("q") or "", - "is_write": request.args.get("is_write") or "", - "is_private": request.args.get("is_private") or "", - "source": request.args.get("source") or "", - "owner_id": request.args.get("owner_id") or "", - }, - } - if format_ == "json": - return Response.json(data) - return await self.render( - ["query_list.html"], - request, - data, - ) - - -class GlobalQueryListView(QueryListView): - name = "global-query-list" - - async def database_name(self, request): - return None - - def query_list_path(self, database): - return self.ds.urls.path("/-/queries") - - -class QueryCreateView(BaseView): - name = "query-create" - has_json_alternate = False - - async def _render_form( - self, - request, - db, - *, - sql="", - name="", - title="", - description="", - is_private=True, - status=200, - ): - response = await self.render( - ["query_create.html"], - request, - await _query_create_form_context( - self.ds, - request, - db, - sql=sql, - name=name, - title=title, - description=description, - is_private=is_private, - ), - ) - response.status = status - return response - - async def get(self, request): - db = await self.ds.resolve_database(request) - await self.ds.ensure_permission( - action="execute-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ) - await self.ds.ensure_permission( - action="store-query", - resource=DatabaseResource(db.name), - actor=request.actor, - ) - - return await self._render_form(request, db, sql=request.args.get("sql") or "") - - -class QueryCreateAnalyzeView(BaseView): - name = "query-create-analyze" - has_json_alternate = False - - async def get(self, request): - db = await self.ds.resolve_database(request) - if not await self.ds.allowed( - action="execute-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ): - return _block_framing(_error(["Permission denied: need execute-sql"], 403)) - if not await self.ds.allowed( - action="store-query", - resource=DatabaseResource(db.name), - actor=request.actor, - ): - return _block_framing(_error(["Permission denied: need store-query"], 403)) - - invalid_keys = set(request.args) - {"sql"} - if invalid_keys: - return _block_framing( - _error( - ["Invalid keys: {}".format(", ".join(sorted(invalid_keys)))], - 400, - ) - ) - sql = request.args.get("sql") or "" - return _block_framing( - Response.json( - await _query_create_analysis_data(self.ds, db, sql, request.actor) - ) - ) - - -class QueryStoreView(QueryCreateView): - name = "query-store" - - async def _error_response(self, request, db, query_data, message, status): - message = _query_create_form_error_message(message) - self.ds.add_message(request, message, self.ds.ERROR) - return await self._render_form( - request, - db, - sql=query_data.get("sql") or "", - name=query_data.get("name") or "", - title=query_data.get("title") or "", - description=query_data.get("description") or "", - is_private=_as_bool(query_data.get("is_private", True)), - status=status, - ) - - async def post(self, request): - db = await self.ds.resolve_database(request) - if not await self.ds.allowed( - action="execute-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ): - return _error(["Permission denied: need execute-sql"], 403) - if not await self.ds.allowed( - action="store-query", - resource=DatabaseResource(db.name), - actor=request.actor, - ): - return _error(["Permission denied: need store-query"], 403) - - is_json = False - query_data = {} - try: - data, is_json = await _json_or_form_payload(request) - if not isinstance(data, dict): - raise QueryValidationError("JSON must be a dictionary") - query_data = data.get("query") if is_json else data - if not isinstance(query_data, dict): - raise QueryValidationError("JSON must contain a query dictionary") - prepared = await _prepare_query_create(self.ds, request, db, query_data) - except QueryValidationError as ex: - if not is_json and isinstance(query_data, dict): - return await self._error_response( - request, db, query_data, ex.message, ex.status - ) - return _error([ex.message], ex.status) - - prepared.pop("analysis") - name = prepared.pop("name") - try: - await self.ds.add_query(db.name, name, replace=False, **prepared) - except sqlite3.IntegrityError as ex: - if not is_json and isinstance(query_data, dict): - return await self._error_response(request, db, query_data, str(ex), 400) - return _error([str(ex)], 400) - - query = await self.ds.get_query(db.name, name) - if is_json: - return Response.json({"ok": True, "query": query}, status=201) - self.ds.add_message(request, "Query saved", self.ds.INFO) - return Response.redirect(self.ds.urls.path(self.ds.urls.table(db.name, name))) - - -class QueryDefinitionView(BaseView): - name = "query-definition" - - async def get(self, request): - db = await self.ds.resolve_database(request) - query_name = tilde_decode(request.url_vars["query"]) - query = await self.ds.get_query(db.name, query_name) - if query is None: - return _error(["Query not found: {}".format(query_name)], 404) - if not await self.ds.allowed( - action="view-query", - resource=QueryResource(db.name, query_name), - actor=request.actor, - ): - return _error(["Permission denied"], 403) - return Response.json({"ok": True, "query": query}) - - -class QueryUpdateView(BaseView): - name = "query-update" - - async def post(self, request): - db = await self.ds.resolve_database(request) - query_name = tilde_decode(request.url_vars["query"]) - existing = await self.ds.get_query(db.name, query_name) - if existing is None: - return _error(["Query not found: {}".format(query_name)], 404) - if not await self.ds.allowed( - action="update-query", - resource=QueryResource(db.name, query_name), - actor=request.actor, - ): - return _error(["Permission denied: need update-query"], 403) - - try: - data, _ = await _json_or_form_payload(request) - if not isinstance(data, dict): - raise QueryValidationError("JSON must be a dictionary") - invalid_keys = set(data) - {"update", "return"} - if invalid_keys: - raise QueryValidationError( - "Invalid keys: {}".format(", ".join(invalid_keys)) - ) - update = data.get("update") - if not isinstance(update, dict): - raise QueryValidationError("JSON must contain an update dictionary") - if "sql" in update and not await self.ds.allowed( - action="execute-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ): - raise QueryValidationError( - "Permission denied: need execute-sql", status=403 - ) - update_kwargs = await _prepare_query_update( - self.ds, request, db, existing, update - ) - except QueryValidationError as ex: - return _error([ex.message], ex.status) - - await self.ds.update_query(db.name, query_name, **update_kwargs) - if data.get("return"): - return Response.json( - { - "ok": True, - "query": await self.ds.get_query(db.name, query_name), - } - ) - return Response.json({"ok": True}) - - -class QueryDeleteView(BaseView): - name = "query-delete" - - async def post(self, request): - db = await self.ds.resolve_database(request) - query_name = tilde_decode(request.url_vars["query"]) - existing = await self.ds.get_query(db.name, query_name) - if existing is None: - return _error(["Query not found: {}".format(query_name)], 404) - if not await self.ds.allowed( - action="delete-query", - resource=QueryResource(db.name, query_name), - actor=request.actor, - ): - return _error(["Permission denied: need delete-query"], 403) - await self.ds.remove_query(db.name, query_name) - return Response.json({"ok": True}) - - class QueryView(View): async def post(self, request, datasette): from datasette.app import TableNotFound @@ -2435,22 +1220,6 @@ class TableCreateView(BaseView): return Response.json(details, status=201) -async def _table_columns(datasette, database_name): - internal_db = datasette.get_internal_database() - result = await internal_db.execute( - "select table_name, name from catalog_columns where database_name = ?", - [database_name], - ) - table_columns = {} - for row in result.rows: - table_columns.setdefault(row["table_name"], []).append(row["name"]) - # Add views - db = datasette.get_database(database_name) - for view_name in await db.view_names(): - table_columns[view_name] = [] - return table_columns - - async def display_rows(datasette, database, request, rows, columns): display_rows = [] truncate_cells = datasette.setting("truncate_cells_html") diff --git a/datasette/views/execute_write.py b/datasette/views/execute_write.py new file mode 100644 index 00000000..0054300c --- /dev/null +++ b/datasette/views/execute_write.py @@ -0,0 +1,257 @@ +from urllib.parse import urlencode + +from datasette.resources import DatabaseResource +from datasette.utils import sqlite3 +from datasette.utils.asgi import Response + +from .base import BaseView, _error +from .query_helpers import ( + QueryValidationError, + _analysis_is_write, + _analysis_rows, + _analysis_rows_with_permissions, + _block_framing, + _coerce_execute_write_payload, + _derived_query_parameters, + _execute_write_analysis_data, + _inserted_row_url, + _json_or_form_payload, + _prepare_execute_write, + _table_columns, + _wants_json, +) + + +class ExecuteWriteView(BaseView): + name = "execute-write" + has_json_alternate = False + + async def _render_form( + self, + request, + db, + *, + sql="", + parameter_values=None, + analysis=None, + analysis_error=None, + execution_message=None, + execution_links=None, + execution_ok=None, + status=200, + ): + parameter_values = parameter_values or {} + execution_links = execution_links or [] + parameter_names = [] + analysis_rows = [] + table_columns = await _table_columns(self.ds, db.name) + hidden_table_names = set(await db.hidden_table_names()) + write_template_tables = { + table: columns + for table, columns in table_columns.items() + if columns and table not in hidden_table_names + } + if sql and analysis_error is None: + try: + parameter_names = _derived_query_parameters(sql) + if analysis is None: + params = {parameter: "" for parameter in parameter_names} + analysis = await db.analyze_sql(sql, params) + if _analysis_is_write(analysis): + analysis_rows = await _analysis_rows_with_permissions( + self.ds, analysis, request.actor + ) + else: + analysis_error = ( + "Use /-/query for read-only SQL; " + "this endpoint only executes writes" + ) + except (QueryValidationError, sqlite3.DatabaseError) as ex: + analysis_error = getattr(ex, "message", str(ex)) + + allow_save_query = await self.ds.allowed( + action="execute-sql", + resource=DatabaseResource(db.name), + actor=request.actor, + ) and await self.ds.allowed( + action="store-query", + resource=DatabaseResource(db.name), + actor=request.actor, + ) + save_query_base_url = None + save_query_url = None + if allow_save_query: + save_query_base_url = self.ds.urls.database(db.name) + "/-/queries/store" + if ( + sql + and analysis_error is None + and not any(row["allowed"] is False for row in analysis_rows) + ): + save_query_url = save_query_base_url + "?" + urlencode({"sql": sql}) + + response = await self.render( + ["execute_write.html"], + request, + { + "database": db.name, + "database_color": db.color, + "sql": sql, + "parameter_names": parameter_names, + "parameter_values": parameter_values, + "analysis_error": analysis_error, + "analysis_rows": [ + row for row in analysis_rows if row["operation"] != "read" + ], + "execution_message": execution_message, + "execution_links": execution_links, + "execution_ok": execution_ok, + "execute_disabled": bool( + (not sql) + or analysis_error + or any(row["allowed"] is False for row in analysis_rows) + ), + "table_columns": table_columns, + "write_template_tables": write_template_tables, + "save_query_url": save_query_url, + "save_query_base_url": save_query_base_url, + }, + ) + response.status = status + return _block_framing(response) + + async def get(self, request): + db = await self.ds.resolve_database(request) + await self.ds.ensure_permission( + action="execute-write-sql", + resource=DatabaseResource(db.name), + actor=request.actor, + ) + if not db.is_mutable: + return _block_framing( + _error( + ["Cannot execute write SQL because this database is immutable."], + 403, + ) + ) + return await self._render_form( + request, + db, + sql=request.args.get("sql") or "", + ) + + async def post(self, request): + db = await self.ds.resolve_database(request) + if not await self.ds.allowed( + action="execute-write-sql", + resource=DatabaseResource(db.name), + actor=request.actor, + ): + return _block_framing( + _error(["Permission denied: need execute-write-sql"], 403) + ) + if not db.is_mutable: + return _block_framing(_error(["Database is immutable"], 403)) + + data = {} + is_json = request.headers.get("content-type", "").startswith("application/json") + sql = "" + provided_params = {} + try: + data, is_json = await _json_or_form_payload(request) + sql, provided_params = _coerce_execute_write_payload(data, is_json) + parameter_names, params, analysis = await _prepare_execute_write( + self.ds, db, sql, provided_params, request.actor + ) + except QueryValidationError as ex: + if _wants_json(request, is_json, data): + return _block_framing(_error([ex.message], ex.status)) + return await self._render_form( + request, + db, + sql=sql or "", + parameter_values=provided_params, + analysis_error=ex.message, + execution_message=ex.message, + execution_ok=False, + status=ex.status, + ) + + try: + cursor = await db.execute_write(sql, params, request=request) + except sqlite3.DatabaseError as ex: + message = str(ex) + if _wants_json(request, is_json, data): + return _block_framing(_error([message], 400)) + return await self._render_form( + request, + db, + sql=sql, + parameter_values=params, + analysis=analysis, + execution_message=message, + execution_ok=False, + status=400, + ) + + message = "Query executed, {} row{} affected".format( + cursor.rowcount, "" if cursor.rowcount == 1 else "s" + ) + if _wants_json(request, is_json, data): + return _block_framing( + Response.json( + { + "ok": True, + "message": message, + "rowcount": cursor.rowcount, + "analysis": _analysis_rows(analysis), + } + ) + ) + + inserted_row_url = await _inserted_row_url(self.ds, db, analysis, cursor) + execution_links = ( + [{"href": inserted_row_url, "label": "View row"}] + if inserted_row_url + else [] + ) + return await self._render_form( + request, + db, + sql=sql, + parameter_values={name: params.get(name, "") for name in parameter_names}, + analysis=analysis, + execution_message=message, + execution_links=execution_links, + execution_ok=True, + ) + + +class ExecuteWriteAnalyzeView(BaseView): + name = "execute-write-analyze" + has_json_alternate = False + + async def get(self, request): + db = await self.ds.resolve_database(request) + if not await self.ds.allowed( + action="execute-write-sql", + resource=DatabaseResource(db.name), + actor=request.actor, + ): + return _block_framing( + _error(["Permission denied: need execute-write-sql"], 403) + ) + + invalid_keys = set(request.args) - {"sql"} + if invalid_keys: + return _block_framing( + _error( + ["Invalid keys: {}".format(", ".join(sorted(invalid_keys)))], + 400, + ) + ) + sql = request.args.get("sql") or "" + return _block_framing( + Response.json( + await _execute_write_analysis_data(self.ds, db, sql, request.actor) + ) + ) diff --git a/datasette/views/query_helpers.py b/datasette/views/query_helpers.py new file mode 100644 index 00000000..d8763a6d --- /dev/null +++ b/datasette/views/query_helpers.py @@ -0,0 +1,558 @@ +import json +import re + +from datasette.resources import DatabaseResource, TableResource +from datasette.utils import ( + named_parameters as derive_named_parameters, + escape_sqlite, + path_from_row_pks, + sqlite3, + validate_sql_select, + InvalidSql, +) +from datasette.utils.asgi import Forbidden + +_query_name_re = re.compile(r"^[^/\.\n]+$") + +_query_fields = { + "sql", + "title", + "description", + "description_html", + "hide_sql", + "fragment", + "parameters", + "params", + "is_private", + "on_success_message", + "on_success_message_sql", + "on_success_redirect", + "on_error_message", + "on_error_redirect", +} + +_query_create_fields = _query_fields | {"name", "mode", "csrftoken"} +_query_update_fields = _query_fields +_query_write_fields = { + "on_success_message", + "on_success_message_sql", + "on_success_redirect", + "on_error_message", + "on_error_redirect", +} + + +class QueryValidationError(Exception): + def __init__(self, message, status=400): + self.message = message + self.status = status + + +def _actor_id(actor): + if isinstance(actor, dict): + return actor.get("id") + return None + + +def _as_bool(value): + if isinstance(value, bool): + return value + if value is None: + return False + if isinstance(value, int): + return bool(value) + if isinstance(value, str): + return value.lower() in {"1", "true", "t", "yes", "on"} + return bool(value) + + +def _as_optional_bool(value, name): + if value is None or value == "": + return None + if isinstance(value, bool): + return value + if isinstance(value, int): + return bool(value) + if isinstance(value, str): + lowered = value.lower() + if lowered in {"1", "true", "t", "yes", "on"}: + return True + if lowered in {"0", "false", "f", "no", "off"}: + return False + raise QueryValidationError("{} must be 0 or 1".format(name)) + + +def _query_list_limit(value, default=50): + if value in (None, ""): + return default + try: + return min(max(1, int(value)), 1000) + except ValueError as ex: + raise QueryValidationError("_size must be an integer") from ex + + +def _derived_query_parameters(sql): + parameters = [] + seen = set() + for parameter in derive_named_parameters(sql): + if parameter.startswith("_"): + raise QueryValidationError("Magic parameters are not allowed") + if parameter not in seen: + parameters.append(parameter) + seen.add(parameter) + return parameters + + +def _coerce_query_parameters(value, derived): + if value is None: + return derived + if isinstance(value, str): + parameters = [ + parameter.strip() + for parameter in re.split(r"[\s,]+", value) + if parameter.strip() + ] + elif isinstance(value, list): + parameters = value + else: + raise QueryValidationError("parameters must be a list of strings") + if not all(isinstance(parameter, str) for parameter in parameters): + raise QueryValidationError("parameters must be a list of strings") + if any(parameter.startswith("_") for parameter in parameters): + raise QueryValidationError("Magic parameters are not allowed") + if set(parameters) != set(derived): + raise QueryValidationError("parameters must match SQL named parameters") + return parameters + + +def _analysis_is_write(analysis): + return any( + access.operation in {"insert", "update", "delete"} + for access in analysis.table_accesses + ) + + +def _block_framing(response): + response.headers["Content-Security-Policy"] = "frame-ancestors 'none'" + response.headers["X-Frame-Options"] = "DENY" + return response + + +def _wants_json(request, is_json, data): + return ( + is_json + or request.headers.get("accept") == "application/json" + or (isinstance(data, dict) and data.get("_json")) + ) + + +def _query_create_form_error_message(message): + return { + "Query name is required": "URL is required", + "Invalid query name": "Invalid URL", + "Query name conflicts with a table or view": ( + "URL conflicts with an existing table or view" + ), + "Query already exists": "A query already exists at that URL", + }.get(message, message) + + +async def _json_or_form_payload(request): + content_type = request.headers.get("content-type", "") + if content_type.startswith("application/json"): + body = await request.post_body() + try: + return json.loads(body or b"{}"), True + except json.JSONDecodeError as e: + raise QueryValidationError("Invalid JSON: {}".format(e)) + return await request.post_vars(), False + + +async def _check_query_name(db, name, *, existing=False): + if not name or not isinstance(name, str): + raise QueryValidationError("Query name is required") + if not _query_name_re.match(name): + raise QueryValidationError("Invalid query name") + if not existing and (await db.table_exists(name) or await db.view_exists(name)): + raise QueryValidationError("Query name conflicts with a table or view") + + +async def _analyze_user_query(datasette, db, sql, *, actor): + if not sql or not isinstance(sql, str): + raise QueryValidationError("SQL is required") + derived = _derived_query_parameters(sql) + params = {parameter: "" for parameter in derived} + try: + analysis = await db.analyze_sql(sql, params) + except sqlite3.DatabaseError as ex: + raise QueryValidationError("Could not analyze query: {}".format(ex)) from ex + + is_write = _analysis_is_write(analysis) + if is_write: + try: + await datasette.ensure_query_write_permissions( + db.name, sql, actor=actor, analysis=analysis + ) + except Forbidden as ex: + raise QueryValidationError(str(ex), status=403) from ex + else: + try: + validate_sql_select(sql) + except InvalidSql as ex: + raise QueryValidationError(str(ex)) from ex + return is_write, derived, analysis + + +def _analysis_rows(analysis): + write_actions = { + "insert": "insert-row", + "update": "update-row", + "delete": "delete-row", + } + return [ + { + "operation": access.operation, + "database": access.database, + "table": access.table, + "required_permission": write_actions.get(access.operation, ""), + "source": access.source, + } + for access in analysis.table_accesses + ] + + +async def _analysis_rows_with_permissions(datasette, analysis, actor): + rows = _analysis_rows(analysis) + for row in rows: + permission = row["required_permission"] + if permission: + row["allowed"] = await datasette.allowed( + action=permission, + resource=TableResource(row["database"], row["table"]), + actor=actor, + ) + else: + row["allowed"] = None + return rows + + +def _coerce_execute_write_payload(data, is_json): + if not isinstance(data, dict): + raise QueryValidationError("JSON must be a dictionary") + if is_json: + invalid_keys = set(data) - {"sql", "params"} + if invalid_keys: + raise QueryValidationError( + "Invalid keys: {}".format(", ".join(sorted(invalid_keys))) + ) + params = data.get("params") or {} + else: + params = { + key: value + for key, value in data.items() + if key not in {"sql", "csrftoken", "_json"} + } + if not isinstance(params, dict): + raise QueryValidationError("params must be a dictionary") + return data.get("sql"), params + + +async def _prepare_execute_write(datasette, db, sql, params, actor): + if not sql or not isinstance(sql, str): + raise QueryValidationError("SQL is required") + parameter_names = _derived_query_parameters(sql) + extra_params = set(params) - set(parameter_names) + if extra_params: + raise QueryValidationError( + "Unknown parameters: {}".format(", ".join(sorted(extra_params))) + ) + params = {name: params.get(name, "") for name in parameter_names} + try: + analysis = await db.analyze_sql(sql, params) + except sqlite3.DatabaseError as ex: + raise QueryValidationError("Could not analyze query: {}".format(ex)) from ex + if not _analysis_is_write(analysis): + raise QueryValidationError( + "Use /-/query for read-only SQL; this endpoint only executes writes" + ) + try: + await datasette.ensure_query_write_permissions( + db.name, sql, actor=actor, analysis=analysis + ) + except Forbidden as ex: + raise QueryValidationError(str(ex), status=403) from ex + return parameter_names, params, analysis + + +async def _ensure_stored_query_execution_permissions(datasette, db, query, actor): + if query.get("is_trusted"): + return + if query.get("write"): + await datasette.ensure_permission( + action="execute-write-sql", + resource=DatabaseResource(db.name), + actor=actor, + ) + await datasette.ensure_query_write_permissions( + db.name, query["sql"], actor=actor + ) + else: + await datasette.ensure_permission( + action="execute-sql", + resource=DatabaseResource(db.name), + actor=actor, + ) + + +async def _execute_write_analysis_data(datasette, db, sql, actor): + parameter_names = [] + analysis_rows = [] + analysis_error = None + if sql: + try: + parameter_names = _derived_query_parameters(sql) + params = {parameter: "" for parameter in parameter_names} + analysis = await db.analyze_sql(sql, params) + if _analysis_is_write(analysis): + analysis_rows = await _analysis_rows_with_permissions( + datasette, analysis, actor + ) + else: + analysis_error = ( + "Use /-/query for read-only SQL; " + "this endpoint only executes writes" + ) + except (QueryValidationError, sqlite3.DatabaseError) as ex: + analysis_error = getattr(ex, "message", str(ex)) + return { + "ok": analysis_error is None, + "parameters": parameter_names, + "analysis_error": analysis_error, + "analysis_rows": [row for row in analysis_rows if row["operation"] != "read"], + "execute_disabled": bool( + (not sql) + or analysis_error + or any(row["allowed"] is False for row in analysis_rows) + ), + } + + +async def _query_create_analysis_data(datasette, db, sql, actor): + has_sql = bool(sql and sql.strip()) + parameter_names = [] + analysis_rows = [] + analysis_error = None + if has_sql: + try: + parameter_names = _derived_query_parameters(sql) + params = {parameter: "" for parameter in parameter_names} + analysis = await db.analyze_sql(sql, params) + analysis_rows = await _analysis_rows_with_permissions( + datasette, analysis, actor + ) + except (QueryValidationError, sqlite3.DatabaseError) as ex: + analysis_error = getattr(ex, "message", str(ex)) + return { + "ok": analysis_error is None, + "parameters": parameter_names, + "analysis_error": analysis_error, + "analysis_rows": analysis_rows, + "has_sql": has_sql, + "analysis_is_write": bool( + analysis_rows and any(row["required_permission"] for row in analysis_rows) + ), + "save_disabled": bool( + (not has_sql) + or analysis_error + or any(row["allowed"] is False for row in analysis_rows) + ), + } + + +async def _query_create_form_context( + datasette, + request, + db, + *, + sql="", + name="", + title="", + description="", + is_private=True, +): + analysis_data = await _query_create_analysis_data(datasette, db, sql, request.actor) + return { + "database": db.name, + "database_color": db.color, + "sql": sql, + "name": name, + "title": title, + "description": description, + "is_private": is_private, + **analysis_data, + } + + +async def _inserted_row_url(datasette, db, analysis, cursor): + if cursor.rowcount != 1: + return None + lastrowid = getattr(cursor, "lastrowid", None) + if lastrowid is None: + return None + direct_inserts = [ + access + for access in analysis.table_accesses + if access.operation == "insert" + and access.source is None + and access.database == db.name + ] + if len(direct_inserts) != 1: + return None + table = direct_inserts[0].table + pks = await db.primary_keys(table) + use_rowid = not pks + select = ( + "rowid" + if use_rowid + else ", ".join(escape_sqlite(primary_key) for primary_key in pks) + ) + try: + result = await db.execute( + "select {} from {} where rowid = ?".format(select, escape_sqlite(table)), + [lastrowid], + ) + except sqlite3.DatabaseError: + return None + row = result.first() + if row is None: + return None + row_path = path_from_row_pks(row, pks, use_rowid) + return datasette.urls.row(db.name, table, row_path) + + +def _apply_query_data_types(data): + typed = dict(data) + for key in ("hide_sql", "is_private"): + if key in typed: + typed[key] = _as_bool(typed[key]) + return typed + + +async def _prepare_query_create(datasette, request, db, data): + invalid_keys = set(data) - _query_create_fields + if invalid_keys: + raise QueryValidationError("Invalid keys: {}".format(", ".join(invalid_keys))) + + data = _apply_query_data_types(data) + name = data.get("name") + await _check_query_name(db, name) + if await datasette.get_query(db.name, name) is not None: + raise QueryValidationError("Query already exists") + + is_write, derived, analysis = await _analyze_user_query( + datasette, + db, + data.get("sql"), + actor=request.actor, + ) + if not is_write and any(data.get(field) for field in _query_write_fields): + raise QueryValidationError("Writable query fields require writable SQL") + + parameters = _coerce_query_parameters( + data.get("parameters", data.get("params")), + derived, + ) + return { + "name": name, + "sql": data["sql"], + "title": data.get("title"), + "description": data.get("description"), + "description_html": data.get("description_html"), + "hide_sql": _as_bool(data.get("hide_sql")), + "fragment": data.get("fragment"), + "parameters": parameters, + "is_write": is_write, + "is_private": _as_bool(data.get("is_private", True)), + "is_trusted": False, + "source": "user", + "owner_id": _actor_id(request.actor), + "on_success_message": data.get("on_success_message"), + "on_success_message_sql": data.get("on_success_message_sql"), + "on_success_redirect": data.get("on_success_redirect"), + "on_error_message": data.get("on_error_message"), + "on_error_redirect": data.get("on_error_redirect"), + "analysis": analysis, + } + + +async def _prepare_query_update(datasette, request, db, existing, update): + invalid_keys = set(update) - _query_update_fields + if invalid_keys: + raise QueryValidationError("Invalid keys: {}".format(", ".join(invalid_keys))) + + update = _apply_query_data_types(update) + sql = update.get("sql", existing["sql"]) + query_is_write = existing["is_write"] + derived = _derived_query_parameters(sql) + parameters = None + + if "sql" in update: + query_is_write, derived, _ = await _analyze_user_query( + datasette, + db, + sql, + actor=request.actor, + ) + + if "parameters" in update or "params" in update: + parameters = _coerce_query_parameters( + update.get("parameters", update.get("params")), + derived, + ) + elif "sql" in update: + parameters = derived + + if not query_is_write and any(update.get(field) for field in _query_write_fields): + raise QueryValidationError("Writable query fields require writable SQL") + + field_values = { + "sql": sql, + "title": update.get("title"), + "description": update.get("description"), + "description_html": update.get("description_html"), + "hide_sql": update.get("hide_sql"), + "fragment": update.get("fragment"), + "parameters": parameters, + "is_write": query_is_write, + "is_private": update.get("is_private"), + "on_success_message": update.get("on_success_message"), + "on_success_message_sql": update.get("on_success_message_sql"), + "on_success_redirect": update.get("on_success_redirect"), + "on_error_message": update.get("on_error_message"), + "on_error_redirect": update.get("on_error_redirect"), + } + update_kwargs = {} + for field_name, value in field_values.items(): + if field_name in update: + update_kwargs[field_name] = value + if parameters is not None: + update_kwargs["parameters"] = parameters + if "sql" in update: + update_kwargs["is_write"] = query_is_write + return update_kwargs + + +async def _table_columns(datasette, database_name): + internal_db = datasette.get_internal_database() + result = await internal_db.execute( + "select table_name, name from catalog_columns where database_name = ?", + [database_name], + ) + table_columns = {} + for row in result.rows: + table_columns.setdefault(row["table_name"], []).append(row["name"]) + # Add views + db = datasette.get_database(database_name) + for view_name in await db.view_names(): + table_columns[view_name] = [] + return table_columns diff --git a/datasette/views/stored_queries.py b/datasette/views/stored_queries.py new file mode 100644 index 00000000..b3813f14 --- /dev/null +++ b/datasette/views/stored_queries.py @@ -0,0 +1,470 @@ +from urllib.parse import parse_qsl, urlencode + +from datasette.resources import DatabaseResource, QueryResource +from datasette.utils import sqlite3, tilde_decode +from datasette.utils.asgi import Response + +from .base import BaseView, _error +from .query_helpers import ( + QueryValidationError, + _as_bool, + _as_optional_bool, + _block_framing, + _derived_query_parameters, + _json_or_form_payload, + _prepare_query_create, + _prepare_query_update, + _query_create_analysis_data, + _query_create_form_context, + _query_create_form_error_message, + _query_list_limit, +) + + +class QueryParametersView(BaseView): + name = "query-parameters" + has_json_alternate = False + + async def get(self, request): + db = await self.ds.resolve_database(request) + if not await self.ds.allowed( + action="execute-sql", + resource=DatabaseResource(db.name), + actor=request.actor, + ): + return _block_framing(_error(["Permission denied: need execute-sql"], 403)) + + invalid_keys = set(request.args) - {"sql"} + if invalid_keys: + return _block_framing( + _error( + ["Invalid keys: {}".format(", ".join(sorted(invalid_keys)))], + 400, + ) + ) + try: + parameters = _derived_query_parameters(request.args.get("sql") or "") + except QueryValidationError as ex: + return _block_framing(_error([ex.message], ex.status)) + return _block_framing(Response.json({"ok": True, "parameters": parameters})) + + +def _query_list_url(path, query_string, *, set_args=None, remove_args=None): + set_args = set_args or {} + remove_args = set(remove_args or ()) + skip = set(set_args) | remove_args | {"_next"} + pairs = [ + (key, value) + for key, value in parse_qsl(query_string, keep_blank_values=True) + if key not in skip + ] + for key, value in set_args.items(): + if value not in (None, ""): + pairs.append((key, value)) + return path + (("?" + urlencode(pairs)) if pairs else "") + + +class QueryListView(BaseView): + name = "query-list" + + async def database_name(self, request): + return (await self.ds.resolve_database(request)).name + + def query_list_path(self, database): + return self.ds.urls.database(database) + "/-/queries" + + async def get(self, request): + database = await self.database_name(request) + format_ = request.url_vars.get("format") or "html" + try: + limit = _query_list_limit( + request.args.get("_size"), + default=20 if format_ == "html" else 50, + ) + is_write = _as_optional_bool(request.args.get("is_write"), "is_write") + is_private = _as_optional_bool(request.args.get("is_private"), "is_private") + except QueryValidationError as ex: + return _error([ex.message], ex.status) + + page = await self.ds.list_queries( + database, + actor=request.actor, + limit=limit, + cursor=request.args.get("_next"), + q=request.args.get("q") or None, + is_write=is_write, + is_private=is_private, + source=request.args.get("source") or None, + owner_id=request.args.get("owner_id") or None, + include_private=True, + ) + query_list_path = self.query_list_path(database) + next_url = None + if page["next"]: + pairs = [ + (key, value) + for key, value in parse_qsl( + request.query_string, keep_blank_values=True + ) + if key != "_next" + ] + pairs.append(("_next", page["next"])) + next_url = "{}?{}".format( + query_list_path, + urlencode(pairs), + ) + + current_filters = { + "actor": request.actor, + "q": request.args.get("q") or None, + "is_write": is_write, + "is_private": is_private, + "source": request.args.get("source") or None, + "owner_id": request.args.get("owner_id") or None, + } + + async def facet_count(field, value): + if current_filters[field] is not None and current_filters[field] != value: + return 0 + filters = dict(current_filters) + filters[field] = value + return await self.ds.count_queries(database, **filters) + + def facet_href(field, value): + if current_filters[field] == value: + return _query_list_url( + query_list_path, + request.query_string, + remove_args=[field], + ) + if current_filters[field] is not None: + return None + return _query_list_url( + query_list_path, + request.query_string, + set_args={field: str(int(value))}, + ) + + async def facet_item(label, field, value): + count = await facet_count(field, value) + active = current_filters[field] == value + if not active and not count: + return None + return { + "label": label, + "count": count, + "href": facet_href(field, value) if active or count else None, + "active": active, + } + + async def facet_items(items): + return [ + item + for item in [ + await facet_item(label, field, value) + for label, field, value in items + ] + if item is not None + ] + + facets = [ + { + "title": "Mode", + "items": await facet_items( + [ + ("Read-only", "is_write", False), + ("Writable", "is_write", True), + ] + ), + }, + { + "title": "Visibility", + "items": await facet_items( + [ + ("Not private", "is_private", False), + ("Private", "is_private", True), + ] + ), + }, + ] + + data = { + "ok": True, + "database": database, + "database_color": ( + self.ds.get_database(database).color if database is not None else None + ), + "queries": page["queries"], + "next": page["next"], + "next_url": next_url, + "has_more": page["has_more"], + "limit": page["limit"], + "show_private_note": any(query["is_private"] for query in page["queries"]), + "show_trusted_note": any(query["is_trusted"] for query in page["queries"]), + "query_list_path": query_list_path, + "show_database": database is None, + "facets": facets, + "filters": { + "q": request.args.get("q") or "", + "is_write": request.args.get("is_write") or "", + "is_private": request.args.get("is_private") or "", + "source": request.args.get("source") or "", + "owner_id": request.args.get("owner_id") or "", + }, + } + if format_ == "json": + return Response.json(data) + return await self.render( + ["query_list.html"], + request, + data, + ) + + +class GlobalQueryListView(QueryListView): + name = "global-query-list" + + async def database_name(self, request): + return None + + def query_list_path(self, database): + return self.ds.urls.path("/-/queries") + + +class QueryCreateView(BaseView): + name = "query-create" + has_json_alternate = False + + async def _render_form( + self, + request, + db, + *, + sql="", + name="", + title="", + description="", + is_private=True, + status=200, + ): + response = await self.render( + ["query_create.html"], + request, + await _query_create_form_context( + self.ds, + request, + db, + sql=sql, + name=name, + title=title, + description=description, + is_private=is_private, + ), + ) + response.status = status + return response + + async def get(self, request): + db = await self.ds.resolve_database(request) + await self.ds.ensure_permission( + action="execute-sql", + resource=DatabaseResource(db.name), + actor=request.actor, + ) + await self.ds.ensure_permission( + action="store-query", + resource=DatabaseResource(db.name), + actor=request.actor, + ) + + return await self._render_form(request, db, sql=request.args.get("sql") or "") + + +class QueryCreateAnalyzeView(BaseView): + name = "query-create-analyze" + has_json_alternate = False + + async def get(self, request): + db = await self.ds.resolve_database(request) + if not await self.ds.allowed( + action="execute-sql", + resource=DatabaseResource(db.name), + actor=request.actor, + ): + return _block_framing(_error(["Permission denied: need execute-sql"], 403)) + if not await self.ds.allowed( + action="store-query", + resource=DatabaseResource(db.name), + actor=request.actor, + ): + return _block_framing(_error(["Permission denied: need store-query"], 403)) + + invalid_keys = set(request.args) - {"sql"} + if invalid_keys: + return _block_framing( + _error( + ["Invalid keys: {}".format(", ".join(sorted(invalid_keys)))], + 400, + ) + ) + sql = request.args.get("sql") or "" + return _block_framing( + Response.json( + await _query_create_analysis_data(self.ds, db, sql, request.actor) + ) + ) + + +class QueryStoreView(QueryCreateView): + name = "query-store" + + async def _error_response(self, request, db, query_data, message, status): + message = _query_create_form_error_message(message) + self.ds.add_message(request, message, self.ds.ERROR) + return await self._render_form( + request, + db, + sql=query_data.get("sql") or "", + name=query_data.get("name") or "", + title=query_data.get("title") or "", + description=query_data.get("description") or "", + is_private=_as_bool(query_data.get("is_private", True)), + status=status, + ) + + async def post(self, request): + db = await self.ds.resolve_database(request) + if not await self.ds.allowed( + action="execute-sql", + resource=DatabaseResource(db.name), + actor=request.actor, + ): + return _error(["Permission denied: need execute-sql"], 403) + if not await self.ds.allowed( + action="store-query", + resource=DatabaseResource(db.name), + actor=request.actor, + ): + return _error(["Permission denied: need store-query"], 403) + + is_json = False + query_data = {} + try: + data, is_json = await _json_or_form_payload(request) + if not isinstance(data, dict): + raise QueryValidationError("JSON must be a dictionary") + query_data = data.get("query") if is_json else data + if not isinstance(query_data, dict): + raise QueryValidationError("JSON must contain a query dictionary") + prepared = await _prepare_query_create(self.ds, request, db, query_data) + except QueryValidationError as ex: + if not is_json and isinstance(query_data, dict): + return await self._error_response( + request, db, query_data, ex.message, ex.status + ) + return _error([ex.message], ex.status) + + prepared.pop("analysis") + name = prepared.pop("name") + try: + await self.ds.add_query(db.name, name, replace=False, **prepared) + except sqlite3.IntegrityError as ex: + if not is_json and isinstance(query_data, dict): + return await self._error_response(request, db, query_data, str(ex), 400) + return _error([str(ex)], 400) + + query = await self.ds.get_query(db.name, name) + if is_json: + return Response.json({"ok": True, "query": query}, status=201) + self.ds.add_message(request, "Query saved", self.ds.INFO) + return Response.redirect(self.ds.urls.path(self.ds.urls.table(db.name, name))) + + +class QueryDefinitionView(BaseView): + name = "query-definition" + + async def get(self, request): + db = await self.ds.resolve_database(request) + query_name = tilde_decode(request.url_vars["query"]) + query = await self.ds.get_query(db.name, query_name) + if query is None: + return _error(["Query not found: {}".format(query_name)], 404) + if not await self.ds.allowed( + action="view-query", + resource=QueryResource(db.name, query_name), + actor=request.actor, + ): + return _error(["Permission denied"], 403) + return Response.json({"ok": True, "query": query}) + + +class QueryUpdateView(BaseView): + name = "query-update" + + async def post(self, request): + db = await self.ds.resolve_database(request) + query_name = tilde_decode(request.url_vars["query"]) + existing = await self.ds.get_query(db.name, query_name) + if existing is None: + return _error(["Query not found: {}".format(query_name)], 404) + if not await self.ds.allowed( + action="update-query", + resource=QueryResource(db.name, query_name), + actor=request.actor, + ): + return _error(["Permission denied: need update-query"], 403) + + try: + data, _ = await _json_or_form_payload(request) + if not isinstance(data, dict): + raise QueryValidationError("JSON must be a dictionary") + invalid_keys = set(data) - {"update", "return"} + if invalid_keys: + raise QueryValidationError( + "Invalid keys: {}".format(", ".join(invalid_keys)) + ) + update = data.get("update") + if not isinstance(update, dict): + raise QueryValidationError("JSON must contain an update dictionary") + if "sql" in update and not await self.ds.allowed( + action="execute-sql", + resource=DatabaseResource(db.name), + actor=request.actor, + ): + raise QueryValidationError( + "Permission denied: need execute-sql", status=403 + ) + update_kwargs = await _prepare_query_update( + self.ds, request, db, existing, update + ) + except QueryValidationError as ex: + return _error([ex.message], ex.status) + + await self.ds.update_query(db.name, query_name, **update_kwargs) + if data.get("return"): + return Response.json( + { + "ok": True, + "query": await self.ds.get_query(db.name, query_name), + } + ) + return Response.json({"ok": True}) + + +class QueryDeleteView(BaseView): + name = "query-delete" + + async def post(self, request): + db = await self.ds.resolve_database(request) + query_name = tilde_decode(request.url_vars["query"]) + existing = await self.ds.get_query(db.name, query_name) + if existing is None: + return _error(["Query not found: {}".format(query_name)], 404) + if not await self.ds.allowed( + action="delete-query", + resource=QueryResource(db.name, query_name), + actor=request.actor, + ): + return _error(["Permission denied: need delete-query"], 403) + await self.ds.remove_query(db.name, query_name) + return Response.json({"ok": True}) diff --git a/docs/authentication.rst b/docs/authentication.rst index 96224ef9..f720c12f 100644 --- a/docs/authentication.rst +++ b/docs/authentication.rst @@ -496,7 +496,7 @@ Here's how to restrict access to your entire Datasette instance to just the ``"i title: My private Datasette instance allow: id: root - + .. tab:: datasette.json