From b1f3e4368c81490c1468b1c641e02fa15771b013 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 31 May 2026 16:15:34 -0700 Subject: [PATCH] Fixes for SQL write with RETURNING (#2763) * Fix for execute write returning, closes #2762 * Fix stored write returning rowcount message * Add configurable execute_write returning limit * Return rows/truncated from execute query if it used RETURNING * INSERT ... RETURNING shows rows in /-/execute-write * Skip RETURNING tests if SQLite version does not support it Screenshot: https://github.com/simonw/datasette/issues/2762#issuecomment-4588111545 --- datasette/database.py | 57 ++++++- datasette/templates/_query_results.html | 20 +++ datasette/templates/execute_write.html | 11 ++ datasette/templates/query.html | 22 +-- datasette/utils/sqlite.py | 16 ++ datasette/views/database.py | 10 +- datasette/views/execute_write.py | 59 +++++-- docs/internals.rst | 29 +++- docs/json_api.rst | 42 ++++- tests/test_internals_database.py | 181 ++++++++++++++++++++- tests/test_queries.py | 201 ++++++++++++++++++++++++ tests/test_utils.py | 21 ++- 12 files changed, 622 insertions(+), 47 deletions(-) create mode 100644 datasette/templates/_query_results.html diff --git a/datasette/database.py b/datasette/database.py index 10417670..0a32442c 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -31,6 +31,8 @@ from .inspect import inspect_hash connections = threading.local() +EXECUTE_WRITE_RETURNING_LIMIT = 10 + AttachedDatabase = namedtuple("AttachedDatabase", ("seq", "name", "file")) @@ -236,11 +238,24 @@ class Database: except OSError: pass - async def execute_write(self, sql, params=None, block=True, request=None): + async def execute_write( + self, + sql, + params=None, + block=True, + request=None, + return_all=False, + returning_limit=EXECUTE_WRITE_RETURNING_LIMIT, + ): self._check_not_closed() + if returning_limit < 0: + raise ValueError("returning_limit must be >= 0") def _inner(conn): - return conn.execute(sql, params or []) + cursor = conn.execute(sql, params or []) + return ExecuteWriteResult.from_cursor( + cursor, return_all=return_all, returning_limit=returning_limit + ) with trace("sql", database=self.name, sql=sql.strip(), params=params): results = await self.execute_write_fn(_inner, block=block, request=request) @@ -877,6 +892,44 @@ class MultipleValues(Exception): pass +class ExecuteWriteResult: + def __init__(self, rowcount, lastrowid, description, rows, truncated): + self.rowcount = rowcount + self.lastrowid = lastrowid + self.description = description + self.truncated = truncated + self._rows = rows + + @classmethod + def from_cursor( + cls, cursor, return_all=False, returning_limit=EXECUTE_WRITE_RETURNING_LIMIT + ): + rows = [] + truncated = False + description = cursor.description + lastrowid = cursor.lastrowid + try: + if description is not None: + if return_all: + rows = cursor.fetchall() + else: + rows = cursor.fetchmany(returning_limit + 1) + if len(rows) > returning_limit: + rows = rows[:returning_limit] + truncated = True + rowcount = cursor.rowcount + finally: + cursor.close() + if description is not None and not return_all and truncated: + rowcount = -1 + return cls(rowcount, lastrowid, description, rows, truncated) + + def fetchall(self): + rows = self._rows + self._rows = [] + return rows + + class Results: def __init__(self, rows, truncated, description): self.rows = rows diff --git a/datasette/templates/_query_results.html b/datasette/templates/_query_results.html new file mode 100644 index 00000000..5e1e2f72 --- /dev/null +++ b/datasette/templates/_query_results.html @@ -0,0 +1,20 @@ +{% if display_rows %} +
+ + + {% for column in columns %}{% endfor %} + + + + {% for row in display_rows %} + + {% for column, td in zip(columns, row) %} + + {% endfor %} + + {% endfor %} + +
{{ column }}
{{ td }}
+{% elif show_zero_results %} +

0 results

+{% endif %} diff --git a/datasette/templates/execute_write.html b/datasette/templates/execute_write.html index 394261de..a93de3a6 100644 --- a/datasette/templates/execute_write.html +++ b/datasette/templates/execute_write.html @@ -81,6 +81,17 @@ form.sql.core input[data-execute-write-submit]:disabled {

{{ execution_message }}{% for link in execution_links %} {{ link.label }}{% endfor %}

{% endif %} +{% if execute_write_returns_rows %} +

Returned rows

+ {% if execute_write_truncated %} +

Only the first {{ "{:,}".format(execute_write_display_rows|length) }} returned rows are shown.

+ {% endif %} + {% set columns = execute_write_columns %} + {% set display_rows = execute_write_display_rows %} + {% set show_zero_results = true %} + {% include "_query_results.html" %} +{% endif %} +
{% if write_template_tables %}
diff --git a/datasette/templates/query.html b/datasette/templates/query.html index 168a636b..8dd1037f 100644 --- a/datasette/templates/query.html +++ b/datasette/templates/query.html @@ -73,27 +73,9 @@ {% if display_rows %} -
- - - {% for column in columns %}{% endfor %} - - - - {% for row in display_rows %} - - {% for column, td in zip(columns, row) %} - - {% endfor %} - - {% endfor %} - -
{{ column }}
{{ td }}
-{% else %} - {% if not stored_query_write and not error %} -

0 results

- {% endif %} {% endif %} +{% set show_zero_results = not stored_query_write and not error %} +{% include "_query_results.html" %} {% include "_codemirror_foot.html" %} {% include "_sql_parameter_scripts.html" %} diff --git a/datasette/utils/sqlite.py b/datasette/utils/sqlite.py index 5a7c6c38..4743ae4c 100644 --- a/datasette/utils/sqlite.py +++ b/datasette/utils/sqlite.py @@ -13,6 +13,7 @@ if hasattr(sqlite3, "enable_callback_tracebacks"): sqlite3.enable_callback_tracebacks(True) _cached_sqlite_version = None +_cached_supports_returning = None SQLiteTableType = Literal["table", "view", "virtual", "shadow"] _VIRTUAL_TABLE_MODULE_RE = re.compile( r"\bCREATE\s+VIRTUAL\s+TABLE\b.*?\bUSING\s+([^\s(]+)", @@ -59,6 +60,21 @@ def supports_generated_columns(): return sqlite_version() >= (3, 31, 0) +def supports_returning(): + global _cached_supports_returning + if _cached_supports_returning is None: + conn = sqlite3.connect(":memory:") + try: + conn.execute("create table t (id integer primary key)") + conn.execute("insert into t default values returning id").fetchone() + _cached_supports_returning = True + except sqlite3.DatabaseError: + _cached_supports_returning = False + finally: + conn.close() + return _cached_supports_returning + + def sqlite_table_type( conn, table: str, diff --git a/datasette/views/database.py b/datasette/views/database.py index d6c88962..a1647ca9 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -528,12 +528,14 @@ class QueryView(View): message = "Error running on_success_message_sql: {}".format(ex) message_type = datasette.ERROR if not message: - message = ( - stored_query.on_success_message - or "Query executed, {} row{} affected".format( + if stored_query.on_success_message: + message = stored_query.on_success_message + elif cursor.rowcount == -1: + message = "Query executed" + else: + message = "Query executed, {} row{} affected".format( cursor.rowcount, "" if cursor.rowcount == 1 else "s" ) - ) redirect_url = stored_query.on_success_redirect ok = True diff --git a/datasette/views/execute_write.py b/datasette/views/execute_write.py index cff20847..c5d55b80 100644 --- a/datasette/views/execute_write.py +++ b/datasette/views/execute_write.py @@ -6,6 +6,7 @@ from datasette.utils import sqlite3 from datasette.utils.asgi import Response from .base import BaseView, _error +from .database import display_rows as display_query_rows from .query_helpers import ( QueryValidationError, _analysis_is_write, @@ -221,10 +222,16 @@ class ExecuteWriteView(BaseView): execution_message=None, execution_links=None, execution_ok=None, + execute_write_returns_rows=False, + execute_write_columns=None, + execute_write_display_rows=None, + execute_write_truncated=False, status=200, ): parameter_values = parameter_values or {} execution_links = execution_links or [] + execute_write_columns = execute_write_columns or [] + execute_write_display_rows = execute_write_display_rows or [] parameter_names = [] analysis_rows = [] table_columns = await _table_columns(self.ds, db.name) @@ -284,6 +291,10 @@ class ExecuteWriteView(BaseView): "execution_message": execution_message, "execution_links": execution_links, "execution_ok": execution_ok, + "execute_write_returns_rows": execute_write_returns_rows, + "execute_write_columns": execute_write_columns, + "execute_write_display_rows": execute_write_display_rows, + "execute_write_truncated": execute_write_truncated, "execute_disabled": bool(execute_disabled_reason), "execute_disabled_reason": execute_disabled_reason, "table_columns": table_columns, @@ -355,11 +366,13 @@ class ExecuteWriteView(BaseView): status=ex.status, ) + wants_json = _wants_json(request, is_json, data) try: - cursor = await db.execute_write(sql, params, request=request) + execute_write_kwargs = {"request": request} + cursor = await db.execute_write(sql, params, **execute_write_kwargs) except sqlite3.DatabaseError as ex: message = str(ex) - if _wants_json(request, is_json, data): + if wants_json: return _block_framing(_error([message], 400)) return await self._render_form( request, @@ -378,17 +391,19 @@ class ExecuteWriteView(BaseView): message = "Query executed, {} row{} affected".format( cursor.rowcount, "" if cursor.rowcount == 1 else "s" ) - if _wants_json(request, is_json, data): - return _block_framing( - Response.json( - { - "ok": True, - "message": message, - "rowcount": cursor.rowcount, - "analysis": _analysis_rows(analysis), - } - ) - ) + if wants_json: + data = { + "ok": True, + "message": message, + "rowcount": cursor.rowcount, + "rows": [], + "truncated": False, + "analysis": _analysis_rows(analysis), + } + if cursor.description is not None: + data["rows"] = [dict(row) for row in cursor.fetchall()] + data["truncated"] = cursor.truncated + return _block_framing(Response.json(data)) inserted_row_url = await _inserted_row_url(self.ds, db, analysis, cursor) execution_links = ( @@ -396,6 +411,20 @@ class ExecuteWriteView(BaseView): if inserted_row_url else [] ) + execute_write_returns_rows = cursor.description is not None + execute_write_columns = [] + execute_write_display_rows = [] + if execute_write_returns_rows: + execute_write_columns = [ + description[0] for description in cursor.description + ] + execute_write_display_rows = await display_query_rows( + self.ds, + db.name, + request, + cursor.fetchall(), + execute_write_columns, + ) return await self._render_form( request, db, @@ -405,6 +434,10 @@ class ExecuteWriteView(BaseView): execution_message=message, execution_links=execution_links, execution_ok=True, + execute_write_returns_rows=execute_write_returns_rows, + execute_write_columns=execute_write_columns, + execute_write_display_rows=execute_write_display_rows, + execute_write_truncated=cursor.truncated, ) diff --git a/docs/internals.rst b/docs/internals.rst index 4980ee8b..f269155a 100644 --- a/docs/internals.rst +++ b/docs/internals.rst @@ -1928,8 +1928,8 @@ Example usage: .. _database_execute_write: -await db.execute_write(sql, params=None, block=True) ----------------------------------------------------- +await db.execute_write(sql, params=None, block=True, request=None, return_all=False, returning_limit=10) +-------------------------------------------------------------------------------------------------------- SQLite only allows one database connection to write at a time. Datasette handles this for you by maintaining a queue of writes to be executed against a given database. Plugins can submit write operations to this queue and they will be executed in the order in which they are received. @@ -1937,7 +1937,30 @@ This method can be used to queue up a non-SELECT SQL query to be executed agains You can pass additional SQL parameters as a tuple or dictionary. -The method will block until the operation is completed, and the return value will be the return from calling ``conn.execute(...)`` using the underlying ``sqlite3`` Python library. +The optional ``request=`` argument is used internally by Datasette to pass request context to :ref:`write_wrapper plugin hooks `. + +The method will block until the operation is completed, and the return value will be an ``ExecuteWriteResult`` object. This imitates a subset of the ``sqlite3.Cursor`` object: + +``.rowcount`` + The number of rows modified by the statement, or ``-1`` if that number is unavailable. + +``.lastrowid`` + The row ID of the last modified row, as returned by ``sqlite3.Cursor.lastrowid``. + +``.description`` + The same column metadata exposed by Python's `sqlite3.Cursor.description `__: one tuple per returned column, or ``None`` if the statement does not return rows. + +``.truncated`` + ``True`` if the statement returned more rows than ``returning_limit``. + +``.fetchall()`` + Returns any rows buffered by Datasette from the statement, such as rows from SQLite's ``RETURNING`` clause. This may be limited by ``returning_limit`` unless ``return_all=True`` was used. This method empties the buffer, so calling it again will return an empty list. + +SQLite statements using ``RETURNING`` must have their rows consumed before the transaction can commit. Datasette will fetch up to ``returning_limit + 1`` rows before committing, store up to ``returning_limit`` rows on the result object and set ``.truncated`` if there were more. The default ``returning_limit`` is ``10``. + +When ``.truncated`` is ``True``, ``.rowcount`` will be ``-1``. SQLite only reports the final row count for a ``RETURNING`` statement after every returned row has been fetched, and Datasette has deliberately stopped fetching rows after ``returning_limit`` to avoid buffering a potentially large result in memory. + +If you need to retrieve every row returned by a statement, pass ``return_all=True``. This will buffer all returned rows in memory before committing. If you pass ``block=False`` this behavior changes to "fire and forget" - queries will be added to the write queue and executed in a separate thread while your code can continue to do other things. The method will return a UUID representing the queued task. diff --git a/docs/json_api.rst b/docs/json_api.rst index 4bd76717..65031bf4 100644 --- a/docs/json_api.rst +++ b/docs/json_api.rst @@ -554,7 +554,8 @@ Datasette analyzes the SQL before executing it. The actor must have ``execute-wr Unsupported SQL operations are rejected by default. ``VACUUM`` is not allowed in arbitrary write SQL, and writes to SQLite virtual tables or shadow tables are rejected. SQL functions are allowed and are not separately restricted by Datasette permissions. -A successful response includes a message, the SQLite ``rowcount`` and a summary of the operations that were executed: +A successful response includes a message, the SQLite ``rowcount``, a ``"rows"`` +list, a ``"truncated"`` flag and a summary of the operations that were executed: The shape of the ``"analysis"`` block is not yet considered a stable API and may change in future Datasette releases. @@ -564,6 +565,8 @@ The shape of the ``"analysis"`` block is not yet considered a stable API and may "ok": true, "message": "Query executed, 1 row affected", "rowcount": 1, + "rows": [], + "truncated": false, "analysis": [ { "operation": "insert", @@ -577,6 +580,43 @@ The shape of the ``"analysis"`` block is not yet considered a stable API and may If SQLite reports ``-1`` for the row count, the message will be ``"Query executed"``. +For most write statements ``"rows"`` will be an empty list and ``"truncated"`` +will be ``false``. If the SQL uses SQLite's ``RETURNING`` clause, ``"rows"`` +will contain returned rows using the same default representation as table and +query JSON responses. ``"truncated"`` indicates if more rows were returned than +the execute-write returning row limit, which defaults to 10: + +.. code-block:: json + + { + "ok": true, + "message": "Query executed, 1 row affected", + "rowcount": 1, + "rows": [ + { + "id": 1, + "name": "Cleo" + } + ], + "truncated": false, + "analysis": [ + { + "operation": "insert", + "database": "data", + "table": "dogs", + "required_permission": "insert-row, update-row, delete-row", + "source": null + }, + { + "operation": "read", + "database": "data", + "table": "dogs", + "required_permission": "view-table", + "source": null + } + ] + } + Errors use the standard Datasette error format: .. code-block:: json diff --git a/tests/test_internals_database.py b/tests/test_internals_database.py index 88f9d571..bb209649 100644 --- a/tests/test_internals_database.py +++ b/tests/test_internals_database.py @@ -5,15 +5,19 @@ Tests for the datasette.database.Database class import asyncio from types import SimpleNamespace from datasette.app import Datasette -from datasette.database import Database, Results, MultipleValues +from datasette.database import Database, ExecuteWriteResult, Results, MultipleValues from datasette.database import DatasetteClosedError from datasette.database import _deliver_write_result -from datasette.utils.sqlite import sqlite3 +from datasette.utils.sqlite import sqlite3, supports_returning from datasette.utils import Column import pytest import time import uuid +requires_sqlite_returning = pytest.mark.skipif( + not supports_returning(), reason="SQLite does not support RETURNING" +) + @pytest.fixture def db(app_client): @@ -469,13 +473,142 @@ async def test_view_names(db): @pytest.mark.asyncio async def test_execute_write_block_true(db): - await db.execute_write( + result = await db.execute_write( "update roadside_attractions set name = ? where pk = ?", ["Mystery!", 1] ) rows = await db.execute("select name from roadside_attractions where pk = 1") + assert result.rowcount == 1 + assert result.description is None + assert result.truncated is False + assert result.fetchall() == [] assert "Mystery!" == rows.rows[0][0] +@pytest.mark.asyncio +@requires_sqlite_returning +async def test_execute_write_with_returning(db): + await db.execute_write( + "create table write_returning (id integer primary key, name text)" + ) + result = await db.execute_write( + "insert into write_returning (name) values (?) returning id, name", + ["Cleo"], + ) + + assert result.rowcount == 1 + assert result.lastrowid == 1 + assert [column[0] for column in result.description] == ["id", "name"] + assert result.truncated is False + assert [dict(row) for row in result.fetchall()] == [{"id": 1, "name": "Cleo"}] + assert result.fetchall() == [] + assert (await db.execute("select id, name from write_returning")).dicts() == [ + {"id": 1, "name": "Cleo"} + ] + + +@pytest.mark.asyncio +@requires_sqlite_returning +async def test_execute_write_with_returning_default_limit(db): + await db.execute_write( + "create table write_returning_limit (id integer primary key)" + ) + await db.execute_write_many( + "insert into write_returning_limit (id) values (?)", + [(i,) for i in range(1, 21)], + ) + + result = await db.execute_write( + "update write_returning_limit set id = id returning id" + ) + + assert result.rowcount == -1 + assert result.truncated is True + assert len(result.fetchall()) == 10 + assert ( + await db.execute("select count(*) from write_returning_limit") + ).single_value() == 20 + + +@pytest.mark.asyncio +@requires_sqlite_returning +async def test_execute_write_with_returning_custom_limit(db): + await db.execute_write( + "create table write_returning_custom (id integer primary key)" + ) + await db.execute_write_many( + "insert into write_returning_custom (id) values (?)", + [(i,) for i in range(1, 6)], + ) + + result = await db.execute_write( + "update write_returning_custom set id = id returning id", + returning_limit=2, + ) + + assert result.rowcount == -1 + assert result.truncated is True + assert [row["id"] for row in result.fetchall()] == [1, 2] + + +@pytest.mark.asyncio +@requires_sqlite_returning +async def test_execute_write_with_returning_exact_default_limit(db): + await db.execute_write( + "create table write_returning_exact_limit (id integer primary key)" + ) + await db.execute_write_many( + "insert into write_returning_exact_limit (id) values (?)", + [(i,) for i in range(1, 11)], + ) + + result = await db.execute_write( + "update write_returning_exact_limit set id = id returning id" + ) + + assert result.rowcount == 10 + assert result.truncated is False + assert len(result.fetchall()) == 10 + + +@pytest.mark.asyncio +@requires_sqlite_returning +async def test_execute_write_with_returning_one_more_than_default_limit(db): + await db.execute_write( + "create table write_returning_one_more (id integer primary key)" + ) + await db.execute_write_many( + "insert into write_returning_one_more (id) values (?)", + [(i,) for i in range(1, 12)], + ) + + result = await db.execute_write( + "update write_returning_one_more set id = id returning id" + ) + + assert result.rowcount == -1 + assert result.truncated is True + assert len(result.fetchall()) == 10 + + +@pytest.mark.asyncio +@requires_sqlite_returning +async def test_execute_write_with_returning_return_all(db): + await db.execute_write("create table write_returning_all (id integer primary key)") + await db.execute_write_many( + "insert into write_returning_all (id) values (?)", + [(i,) for i in range(1, 21)], + ) + + result = await db.execute_write( + "update write_returning_all set id = id returning id", + return_all=True, + ) + + assert result.rowcount == 20 + assert result.truncated is False + assert [row["id"] for row in result.fetchall()] == list(range(1, 21)) + + @pytest.mark.asyncio async def test_execute_write_block_false(db): await db.execute_write( @@ -487,6 +620,48 @@ async def test_execute_write_block_false(db): assert "Mystery!" == rows.rows[0][0] +@pytest.mark.asyncio +@requires_sqlite_returning +async def test_execute_write_with_returning_block_false(db): + await db.execute_write( + "create table write_returning_block_false (id integer primary key, name text)" + ) + task_id = await db.execute_write( + "insert into write_returning_block_false (name) values (?) returning id", + ["Cleo"], + block=False, + ) + + assert isinstance(task_id, uuid.UUID) + time.sleep(0.1) + assert ( + await db.execute("select name from write_returning_block_false") + ).single_value() == "Cleo" + + +def test_execute_write_result_closes_cursor_on_fetch_error(): + class Cursor: + description = (("id", None, None, None, None, None, None),) + lastrowid = 1 + rowcount = 0 + + def __init__(self): + self.closed = False + + def fetchmany(self, size): + raise sqlite3.DatabaseError("fetch failed") + + def close(self): + self.closed = True + + cursor = Cursor() + + with pytest.raises(sqlite3.DatabaseError): + ExecuteWriteResult.from_cursor(cursor) + + assert cursor.closed is True + + @pytest.mark.asyncio async def test_execute_write_script(db): await db.execute_write_script( diff --git a/tests/test_queries.py b/tests/test_queries.py index 89167a1d..cef06d7f 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -8,6 +8,11 @@ from datasette.app import Datasette from datasette.resources import DatabaseResource, QueryResource from datasette.stored_queries import StoredQuery, StoredQueryPage from datasette.utils.asgi import Forbidden +from datasette.utils.sqlite import supports_returning + +requires_sqlite_returning = pytest.mark.skipif( + not supports_returning(), reason="SQLite does not support RETURNING" +) def _template_option_attributes(html, table): @@ -1884,10 +1889,144 @@ async def test_execute_write_post_requires_database_and_table_permissions(): assert allowed.status_code == 200 assert allowed.json()["ok"] is True assert allowed.json()["rowcount"] == 1 + assert allowed.json()["rows"] == [] + assert allowed.json()["truncated"] is False assert allowed.json()["analysis"][0]["operation"] == "insert" assert (await db.execute("select name from dogs")).first()[0] == "Cleo" +@pytest.mark.asyncio +@requires_sqlite_returning +async def test_execute_write_json_includes_returning_rows(): + ds = Datasette(memory=True, default_deny=True) + ds.root_enabled = True + db = ds.add_memory_database("execute_write_returning_json", name="data") + await db.execute_write("create table dogs (id integer primary key, name text)") + await ds.invoke_startup() + + response = await ds.client.post( + "/data/-/execute-write", + actor={"id": "root"}, + json={ + "sql": "insert into dogs (name) values (:name) returning id, name", + "params": {"name": "Cleo"}, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["ok"] is True + assert data["message"] == "Query executed, 1 row affected" + assert data["rowcount"] == 1 + assert data["rows"] == [{"id": 1, "name": "Cleo"}] + assert data["truncated"] is False + assert [row["operation"] for row in data["analysis"]] == ["insert", "read"] + assert (await db.execute("select id, name from dogs")).dicts() == [ + {"id": 1, "name": "Cleo"} + ] + + +@pytest.mark.asyncio +@requires_sqlite_returning +async def test_execute_write_json_returning_rows_can_be_truncated(): + ds = Datasette(memory=True, default_deny=True) + ds.root_enabled = True + db = ds.add_memory_database("execute_write_returning_json_truncated", name="data") + await db.execute_write("create table dogs (id integer primary key, name text)") + for index in range(1, 12): + await db.execute_write( + "insert into dogs (name) values (?)", ["Dog {}".format(index)] + ) + await ds.invoke_startup() + + response = await ds.client.post( + "/data/-/execute-write", + actor={"id": "root"}, + json={"sql": "update dogs set name = name || '!' returning id, name"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["ok"] is True + assert data["message"] == "Query executed" + assert data["rowcount"] == -1 + assert data["rows"] == [ + {"id": index, "name": "Dog {}!".format(index)} for index in range(1, 11) + ] + assert data["truncated"] is True + assert (await db.execute("select count(*) from dogs where name like '%!'")).first()[ + 0 + ] == 11 + + +@pytest.mark.asyncio +@requires_sqlite_returning +async def test_execute_write_html_displays_returning_rows(): + ds = Datasette(memory=True, default_deny=True) + ds.root_enabled = True + db = ds.add_memory_database("execute_write_returning_html", name="data") + await db.execute_write("create table dogs (id integer primary key, name text)") + await ds.invoke_startup() + + response = await ds.client.post( + "/data/-/execute-write", + actor={"id": "root"}, + data={ + "sql": "insert into dogs (name) values (:name) returning id, name", + "name": "Cleo", + }, + ) + non_returning_response = await ds.client.post( + "/data/-/execute-write", + actor={"id": "root"}, + data={"sql": "insert into dogs (name) values ('Pancakes')"}, + ) + + assert response.status_code == 200 + assert "Query executed, 1 row affected" in response.text + assert "

Returned rows

" in response.text + assert '' in response.text + assert '' in response.text + assert '' in response.text + assert '' in response.text + assert '' in response.text + + assert non_returning_response.status_code == 200 + assert "Query executed, 1 row affected" in non_returning_response.text + assert "

Returned rows

" not in non_returning_response.text + assert '

0 results

' not in non_returning_response.text + + +@pytest.mark.asyncio +@requires_sqlite_returning +async def test_execute_write_html_returning_rows_can_be_truncated(): + ds = Datasette(memory=True, default_deny=True) + ds.root_enabled = True + db = ds.add_memory_database("execute_write_returning_html_truncated", name="data") + await db.execute_write("create table dogs (id integer primary key, name text)") + for index in range(1, 12): + await db.execute_write( + "insert into dogs (name) values (?)", ["Dog {}".format(index)] + ) + await ds.invoke_startup() + + response = await ds.client.post( + "/data/-/execute-write", + actor={"id": "root"}, + data={"sql": "update dogs set name = name || '!' returning id, name"}, + ) + + assert response.status_code == 200 + assert "

Returned rows

" in response.text + assert "Only the first 10 returned rows are shown." in response.text + assert '' in response.text + assert '' in response.text + assert '' in response.text + assert '' in response.text + assert '' not in response.text + assert '' not in response.text + + @pytest.mark.parametrize( "database_name, sql", ( @@ -3002,3 +3141,65 @@ async def test_user_writable_query_execution_rechecks_table_permissions(): assert denied_response.status_code == 403 rows = (await db.execute("select name from dogs")).dicts() assert rows == [{"name": "Cleo"}] + + +@pytest.mark.asyncio +@requires_sqlite_returning +async def test_stored_write_query_with_returning(): + ds = Datasette(memory=True, default_deny=True) + ds.root_enabled = True + db = ds.add_memory_database("query_write_returning", name="data") + await db.execute_write("create table dogs (id integer primary key, name text)") + await ds.invoke_startup() + await ds.add_query( + "data", + "insert_dog", + "insert into dogs (name) values (:name) returning id, name", + is_write=True, + source="user", + owner_id="root", + ) + + response = await ds.client.post( + "/data/insert_dog?_json=1", + actor={"id": "root"}, + data={"name": "Cleo"}, + ) + + assert response.status_code == 200 + assert response.json()["ok"] is True + assert (await db.execute("select id, name from dogs")).dicts() == [ + {"id": 1, "name": "Cleo"} + ] + + +@pytest.mark.asyncio +@requires_sqlite_returning +async def test_stored_write_query_with_truncated_returning_message(): + ds = Datasette(memory=True, default_deny=True) + ds.root_enabled = True + db = ds.add_memory_database("query_write_truncated_returning", name="data") + await db.execute_write("create table dogs (id integer primary key, name text)") + await db.execute_write_many( + "insert into dogs (name) values (?)", + [("Cleo",) for _ in range(20)], + ) + await ds.invoke_startup() + await ds.add_query( + "data", + "update_dogs", + "update dogs set name = name returning id", + is_write=True, + source="user", + owner_id="root", + ) + + response = await ds.client.post( + "/data/update_dogs?_json=1", + actor={"id": "root"}, + data={}, + ) + + assert response.status_code == 200 + assert response.json()["ok"] is True + assert response.json()["message"] == "Query executed" diff --git a/tests/test_utils.py b/tests/test_utils.py index f6de3b46..64607244 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,7 +5,12 @@ Tests for various datasette helper functions. from datasette.app import Datasette from datasette import utils from datasette.utils.asgi import Request -from datasette.utils.sqlite import sqlite3, sqlite_hidden_table_names, sqlite_table_type +from datasette.utils.sqlite import ( + sqlite3, + sqlite_hidden_table_names, + sqlite_table_type, + supports_returning, +) import json import os import pathlib @@ -226,6 +231,20 @@ def test_detect_fts_different_table_names(table): conn.close() +def test_supports_returning(): + conn = utils.sqlite3.connect(":memory:") + try: + conn.execute("create table t (id integer primary key)") + conn.execute("insert into t default values returning id").fetchone() + expected = True + except sqlite3.DatabaseError: + expected = False + finally: + conn.close() + + assert supports_returning() is expected + + @pytest.mark.parametrize("use_fallback", (False, True)) def test_sqlite_table_type_detects_virtual_and_shadow_tables(monkeypatch, use_fallback): if use_fallback:
idname1Cleo1Dog 1!10Dog 10!11Dog 11!