Fix for execute write returning issue #2762

This commit is contained in:
Simon Willison 2026-05-31 11:36:34 -07:00
commit bdfd9d5482
3 changed files with 192 additions and 3 deletions

View file

@ -31,6 +31,8 @@ from .inspect import inspect_hash
connections = threading.local()
EXECUTE_WRITE_RETURNING_LIMIT = 10
AttachedDatabase = namedtuple("AttachedDatabase", ("seq", "name", "file"))
@ -236,11 +238,14 @@ class Database:
except OSError:
pass
async def execute_write(self, sql, params=None, block=True, request=None):
async def execute_write(
self, sql, params=None, block=True, request=None, return_all=False
):
self._check_not_closed()
def _inner(conn):
return conn.execute(sql, params or [])
cursor = conn.execute(sql, params or [])
return ExecuteWriteResult.from_cursor(cursor, return_all=return_all)
with trace("sql", database=self.name, sql=sql.strip(), params=params):
results = await self.execute_write_fn(_inner, block=block, request=request)
@ -877,6 +882,40 @@ class MultipleValues(Exception):
pass
class ExecuteWriteResult:
def __init__(self, rowcount, lastrowid, description, rows, truncated):
self.rowcount = rowcount
self.lastrowid = lastrowid
self.description = description
self.truncated = truncated
self._rows = rows
@classmethod
def from_cursor(cls, cursor, return_all=False):
rows = []
truncated = False
description = cursor.description
lastrowid = cursor.lastrowid
if description is not None:
if return_all:
rows = cursor.fetchall()
else:
rows = cursor.fetchmany(EXECUTE_WRITE_RETURNING_LIMIT + 1)
if len(rows) > EXECUTE_WRITE_RETURNING_LIMIT:
rows = rows[:EXECUTE_WRITE_RETURNING_LIMIT]
truncated = True
rowcount = cursor.rowcount
cursor.close()
if description is not None and not return_all and truncated and rowcount == 0:
rowcount = -1
return cls(rowcount, lastrowid, description, rows, truncated)
def fetchall(self):
rows = self._rows
self._rows = []
return rows
class Results:
def __init__(self, rows, truncated, description):
self.rows = rows

View file

@ -469,13 +469,116 @@ async def test_view_names(db):
@pytest.mark.asyncio
async def test_execute_write_block_true(db):
await db.execute_write(
result = await db.execute_write(
"update roadside_attractions set name = ? where pk = ?", ["Mystery!", 1]
)
rows = await db.execute("select name from roadside_attractions where pk = 1")
assert result.rowcount == 1
assert result.description is None
assert result.truncated is False
assert result.fetchall() == []
assert "Mystery!" == rows.rows[0][0]
@pytest.mark.asyncio
async def test_execute_write_with_returning(db):
await db.execute_write(
"create table write_returning (id integer primary key, name text)"
)
result = await db.execute_write(
"insert into write_returning (name) values (?) returning id, name",
["Cleo"],
)
assert result.rowcount == 1
assert result.lastrowid == 1
assert [column[0] for column in result.description] == ["id", "name"]
assert result.truncated is False
assert [dict(row) for row in result.fetchall()] == [{"id": 1, "name": "Cleo"}]
assert result.fetchall() == []
assert (await db.execute("select id, name from write_returning")).dicts() == [
{"id": 1, "name": "Cleo"}
]
@pytest.mark.asyncio
async def test_execute_write_with_returning_default_limit(db):
await db.execute_write(
"create table write_returning_limit (id integer primary key)"
)
await db.execute_write_many(
"insert into write_returning_limit (id) values (?)",
[(i,) for i in range(1, 21)],
)
result = await db.execute_write(
"update write_returning_limit set id = id returning id"
)
assert result.rowcount == -1
assert result.truncated is True
assert len(result.fetchall()) == 10
assert (
await db.execute("select count(*) from write_returning_limit")
).single_value() == 20
@pytest.mark.asyncio
async def test_execute_write_with_returning_exact_default_limit(db):
await db.execute_write(
"create table write_returning_exact_limit (id integer primary key)"
)
await db.execute_write_many(
"insert into write_returning_exact_limit (id) values (?)",
[(i,) for i in range(1, 11)],
)
result = await db.execute_write(
"update write_returning_exact_limit set id = id returning id"
)
assert result.rowcount == 10
assert result.truncated is False
assert len(result.fetchall()) == 10
@pytest.mark.asyncio
async def test_execute_write_with_returning_one_more_than_default_limit(db):
await db.execute_write(
"create table write_returning_one_more (id integer primary key)"
)
await db.execute_write_many(
"insert into write_returning_one_more (id) values (?)",
[(i,) for i in range(1, 12)],
)
result = await db.execute_write(
"update write_returning_one_more set id = id returning id"
)
assert result.rowcount == 11
assert result.truncated is True
assert len(result.fetchall()) == 10
@pytest.mark.asyncio
async def test_execute_write_with_returning_return_all(db):
await db.execute_write("create table write_returning_all (id integer primary key)")
await db.execute_write_many(
"insert into write_returning_all (id) values (?)",
[(i,) for i in range(1, 21)],
)
result = await db.execute_write(
"update write_returning_all set id = id returning id",
return_all=True,
)
assert result.rowcount == 20
assert result.truncated is False
assert [row["id"] for row in result.fetchall()] == list(range(1, 21))
@pytest.mark.asyncio
async def test_execute_write_block_false(db):
await db.execute_write(
@ -487,6 +590,24 @@ async def test_execute_write_block_false(db):
assert "Mystery!" == rows.rows[0][0]
@pytest.mark.asyncio
async def test_execute_write_with_returning_block_false(db):
await db.execute_write(
"create table write_returning_block_false (id integer primary key, name text)"
)
task_id = await db.execute_write(
"insert into write_returning_block_false (name) values (?) returning id",
["Cleo"],
block=False,
)
assert isinstance(task_id, uuid.UUID)
time.sleep(0.1)
assert (
await db.execute("select name from write_returning_block_false")
).single_value() == "Cleo"
@pytest.mark.asyncio
async def test_execute_write_script(db):
await db.execute_write_script(

View file

@ -3002,3 +3002,32 @@ async def test_user_writable_query_execution_rechecks_table_permissions():
assert denied_response.status_code == 403
rows = (await db.execute("select name from dogs")).dicts()
assert rows == [{"name": "Cleo"}]
@pytest.mark.asyncio
async def test_stored_write_query_with_returning():
ds = Datasette(memory=True, default_deny=True)
ds.root_enabled = True
db = ds.add_memory_database("query_write_returning", name="data")
await db.execute_write("create table dogs (id integer primary key, name text)")
await ds.invoke_startup()
await ds.add_query(
"data",
"insert_dog",
"insert into dogs (name) values (:name) returning id, name",
is_write=True,
source="user",
owner_id="root",
)
response = await ds.client.post(
"/data/insert_dog?_json=1",
actor={"id": "root"},
data={"name": "Cleo"},
)
assert response.status_code == 200
assert response.json()["ok"] is True
assert (await db.execute("select id, name from dogs")).dicts() == [
{"id": 1, "name": "Cleo"}
]