From bdfd9d548286d440b3897cab7818d30bce099694 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 31 May 2026 11:36:34 -0700 Subject: [PATCH] Fix for execute write returning issue #2762 --- datasette/database.py | 43 ++++++++++- tests/test_internals_database.py | 123 ++++++++++++++++++++++++++++++- tests/test_queries.py | 29 ++++++++ 3 files changed, 192 insertions(+), 3 deletions(-) diff --git a/datasette/database.py b/datasette/database.py index 10417670..f1557366 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -31,6 +31,8 @@ from .inspect import inspect_hash connections = threading.local() +EXECUTE_WRITE_RETURNING_LIMIT = 10 + AttachedDatabase = namedtuple("AttachedDatabase", ("seq", "name", "file")) @@ -236,11 +238,14 @@ class Database: except OSError: pass - async def execute_write(self, sql, params=None, block=True, request=None): + async def execute_write( + self, sql, params=None, block=True, request=None, return_all=False + ): self._check_not_closed() def _inner(conn): - return conn.execute(sql, params or []) + cursor = conn.execute(sql, params or []) + return ExecuteWriteResult.from_cursor(cursor, return_all=return_all) with trace("sql", database=self.name, sql=sql.strip(), params=params): results = await self.execute_write_fn(_inner, block=block, request=request) @@ -877,6 +882,40 @@ class MultipleValues(Exception): pass +class ExecuteWriteResult: + def __init__(self, rowcount, lastrowid, description, rows, truncated): + self.rowcount = rowcount + self.lastrowid = lastrowid + self.description = description + self.truncated = truncated + self._rows = rows + + @classmethod + def from_cursor(cls, cursor, return_all=False): + rows = [] + truncated = False + description = cursor.description + lastrowid = cursor.lastrowid + if description is not None: + if return_all: + rows = cursor.fetchall() + else: + rows = cursor.fetchmany(EXECUTE_WRITE_RETURNING_LIMIT + 1) + if len(rows) > EXECUTE_WRITE_RETURNING_LIMIT: + rows = rows[:EXECUTE_WRITE_RETURNING_LIMIT] + truncated = True + rowcount = cursor.rowcount + cursor.close() + if description is not None and not return_all and truncated and rowcount == 0: + rowcount = -1 + return cls(rowcount, lastrowid, description, rows, truncated) + + def fetchall(self): + rows = self._rows + self._rows = [] + return rows + + class Results: def __init__(self, rows, truncated, description): self.rows = rows diff --git a/tests/test_internals_database.py b/tests/test_internals_database.py index 88f9d571..4f0aeb2c 100644 --- a/tests/test_internals_database.py +++ b/tests/test_internals_database.py @@ -469,13 +469,116 @@ async def test_view_names(db): @pytest.mark.asyncio async def test_execute_write_block_true(db): - await db.execute_write( + result = await db.execute_write( "update roadside_attractions set name = ? where pk = ?", ["Mystery!", 1] ) rows = await db.execute("select name from roadside_attractions where pk = 1") + assert result.rowcount == 1 + assert result.description is None + assert result.truncated is False + assert result.fetchall() == [] assert "Mystery!" == rows.rows[0][0] +@pytest.mark.asyncio +async def test_execute_write_with_returning(db): + await db.execute_write( + "create table write_returning (id integer primary key, name text)" + ) + result = await db.execute_write( + "insert into write_returning (name) values (?) returning id, name", + ["Cleo"], + ) + + assert result.rowcount == 1 + assert result.lastrowid == 1 + assert [column[0] for column in result.description] == ["id", "name"] + assert result.truncated is False + assert [dict(row) for row in result.fetchall()] == [{"id": 1, "name": "Cleo"}] + assert result.fetchall() == [] + assert (await db.execute("select id, name from write_returning")).dicts() == [ + {"id": 1, "name": "Cleo"} + ] + + +@pytest.mark.asyncio +async def test_execute_write_with_returning_default_limit(db): + await db.execute_write( + "create table write_returning_limit (id integer primary key)" + ) + await db.execute_write_many( + "insert into write_returning_limit (id) values (?)", + [(i,) for i in range(1, 21)], + ) + + result = await db.execute_write( + "update write_returning_limit set id = id returning id" + ) + + assert result.rowcount == -1 + assert result.truncated is True + assert len(result.fetchall()) == 10 + assert ( + await db.execute("select count(*) from write_returning_limit") + ).single_value() == 20 + + +@pytest.mark.asyncio +async def test_execute_write_with_returning_exact_default_limit(db): + await db.execute_write( + "create table write_returning_exact_limit (id integer primary key)" + ) + await db.execute_write_many( + "insert into write_returning_exact_limit (id) values (?)", + [(i,) for i in range(1, 11)], + ) + + result = await db.execute_write( + "update write_returning_exact_limit set id = id returning id" + ) + + assert result.rowcount == 10 + assert result.truncated is False + assert len(result.fetchall()) == 10 + + +@pytest.mark.asyncio +async def test_execute_write_with_returning_one_more_than_default_limit(db): + await db.execute_write( + "create table write_returning_one_more (id integer primary key)" + ) + await db.execute_write_many( + "insert into write_returning_one_more (id) values (?)", + [(i,) for i in range(1, 12)], + ) + + result = await db.execute_write( + "update write_returning_one_more set id = id returning id" + ) + + assert result.rowcount == 11 + assert result.truncated is True + assert len(result.fetchall()) == 10 + + +@pytest.mark.asyncio +async def test_execute_write_with_returning_return_all(db): + await db.execute_write("create table write_returning_all (id integer primary key)") + await db.execute_write_many( + "insert into write_returning_all (id) values (?)", + [(i,) for i in range(1, 21)], + ) + + result = await db.execute_write( + "update write_returning_all set id = id returning id", + return_all=True, + ) + + assert result.rowcount == 20 + assert result.truncated is False + assert [row["id"] for row in result.fetchall()] == list(range(1, 21)) + + @pytest.mark.asyncio async def test_execute_write_block_false(db): await db.execute_write( @@ -487,6 +590,24 @@ async def test_execute_write_block_false(db): assert "Mystery!" == rows.rows[0][0] +@pytest.mark.asyncio +async def test_execute_write_with_returning_block_false(db): + await db.execute_write( + "create table write_returning_block_false (id integer primary key, name text)" + ) + task_id = await db.execute_write( + "insert into write_returning_block_false (name) values (?) returning id", + ["Cleo"], + block=False, + ) + + assert isinstance(task_id, uuid.UUID) + time.sleep(0.1) + assert ( + await db.execute("select name from write_returning_block_false") + ).single_value() == "Cleo" + + @pytest.mark.asyncio async def test_execute_write_script(db): await db.execute_write_script( diff --git a/tests/test_queries.py b/tests/test_queries.py index 89167a1d..c75c2459 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -3002,3 +3002,32 @@ async def test_user_writable_query_execution_rechecks_table_permissions(): assert denied_response.status_code == 403 rows = (await db.execute("select name from dogs")).dicts() assert rows == [{"name": "Cleo"}] + + +@pytest.mark.asyncio +async def test_stored_write_query_with_returning(): + ds = Datasette(memory=True, default_deny=True) + ds.root_enabled = True + db = ds.add_memory_database("query_write_returning", name="data") + await db.execute_write("create table dogs (id integer primary key, name text)") + await ds.invoke_startup() + await ds.add_query( + "data", + "insert_dog", + "insert into dogs (name) values (:name) returning id, name", + is_write=True, + source="user", + owner_id="root", + ) + + response = await ds.client.post( + "/data/insert_dog?_json=1", + actor={"id": "root"}, + data={"name": "Cleo"}, + ) + + assert response.status_code == 200 + assert response.json()["ok"] is True + assert (await db.execute("select id, name from dogs")).dicts() == [ + {"id": 1, "name": "Cleo"} + ]