mirror of
https://github.com/simonw/datasette.git
synced 2026-06-13 04:27:00 +02:00
parent
44e17fa3db
commit
1e81be99e4
2 changed files with 38 additions and 6 deletions
|
|
@ -239,13 +239,23 @@ class Database:
|
|||
pass
|
||||
|
||||
async def execute_write(
|
||||
self, sql, params=None, block=True, request=None, return_all=False
|
||||
self,
|
||||
sql,
|
||||
params=None,
|
||||
block=True,
|
||||
request=None,
|
||||
return_all=False,
|
||||
returning_limit=EXECUTE_WRITE_RETURNING_LIMIT,
|
||||
):
|
||||
self._check_not_closed()
|
||||
if returning_limit < 0:
|
||||
raise ValueError("returning_limit must be >= 0")
|
||||
|
||||
def _inner(conn):
|
||||
cursor = conn.execute(sql, params or [])
|
||||
return ExecuteWriteResult.from_cursor(cursor, return_all=return_all)
|
||||
return ExecuteWriteResult.from_cursor(
|
||||
cursor, return_all=return_all, returning_limit=returning_limit
|
||||
)
|
||||
|
||||
with trace("sql", database=self.name, sql=sql.strip(), params=params):
|
||||
results = await self.execute_write_fn(_inner, block=block, request=request)
|
||||
|
|
@ -891,7 +901,9 @@ class ExecuteWriteResult:
|
|||
self._rows = rows
|
||||
|
||||
@classmethod
|
||||
def from_cursor(cls, cursor, return_all=False):
|
||||
def from_cursor(
|
||||
cls, cursor, return_all=False, returning_limit=EXECUTE_WRITE_RETURNING_LIMIT
|
||||
):
|
||||
rows = []
|
||||
truncated = False
|
||||
description = cursor.description
|
||||
|
|
@ -900,9 +912,9 @@ class ExecuteWriteResult:
|
|||
if return_all:
|
||||
rows = cursor.fetchall()
|
||||
else:
|
||||
rows = cursor.fetchmany(EXECUTE_WRITE_RETURNING_LIMIT + 1)
|
||||
if len(rows) > EXECUTE_WRITE_RETURNING_LIMIT:
|
||||
rows = rows[:EXECUTE_WRITE_RETURNING_LIMIT]
|
||||
rows = cursor.fetchmany(returning_limit + 1)
|
||||
if len(rows) > returning_limit:
|
||||
rows = rows[:returning_limit]
|
||||
truncated = True
|
||||
rowcount = cursor.rowcount
|
||||
cursor.close()
|
||||
|
|
|
|||
|
|
@ -523,6 +523,26 @@ async def test_execute_write_with_returning_default_limit(db):
|
|||
).single_value() == 20
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_write_with_returning_custom_limit(db):
|
||||
await db.execute_write(
|
||||
"create table write_returning_custom (id integer primary key)"
|
||||
)
|
||||
await db.execute_write_many(
|
||||
"insert into write_returning_custom (id) values (?)",
|
||||
[(i,) for i in range(1, 6)],
|
||||
)
|
||||
|
||||
result = await db.execute_write(
|
||||
"update write_returning_custom set id = id returning id",
|
||||
returning_limit=2,
|
||||
)
|
||||
|
||||
assert result.rowcount == -1
|
||||
assert result.truncated is True
|
||||
assert [row["id"] for row in result.fetchall()] == [1, 2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_write_with_returning_exact_default_limit(db):
|
||||
await db.execute_write(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue