diff --git a/datasette/database.py b/datasette/database.py index 10417670..0a32442c 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -31,6 +31,8 @@ from .inspect import inspect_hash connections = threading.local() +EXECUTE_WRITE_RETURNING_LIMIT = 10 + AttachedDatabase = namedtuple("AttachedDatabase", ("seq", "name", "file")) @@ -236,11 +238,24 @@ class Database: except OSError: pass - async def execute_write(self, sql, params=None, block=True, request=None): + async def execute_write( + self, + sql, + params=None, + block=True, + request=None, + return_all=False, + returning_limit=EXECUTE_WRITE_RETURNING_LIMIT, + ): self._check_not_closed() + if returning_limit < 0: + raise ValueError("returning_limit must be >= 0") def _inner(conn): - return conn.execute(sql, params or []) + cursor = conn.execute(sql, params or []) + return ExecuteWriteResult.from_cursor( + cursor, return_all=return_all, returning_limit=returning_limit + ) with trace("sql", database=self.name, sql=sql.strip(), params=params): results = await self.execute_write_fn(_inner, block=block, request=request) @@ -877,6 +892,44 @@ class MultipleValues(Exception): pass +class ExecuteWriteResult: + def __init__(self, rowcount, lastrowid, description, rows, truncated): + self.rowcount = rowcount + self.lastrowid = lastrowid + self.description = description + self.truncated = truncated + self._rows = rows + + @classmethod + def from_cursor( + cls, cursor, return_all=False, returning_limit=EXECUTE_WRITE_RETURNING_LIMIT + ): + rows = [] + truncated = False + description = cursor.description + lastrowid = cursor.lastrowid + try: + if description is not None: + if return_all: + rows = cursor.fetchall() + else: + rows = cursor.fetchmany(returning_limit + 1) + if len(rows) > returning_limit: + rows = rows[:returning_limit] + truncated = True + rowcount = cursor.rowcount + finally: + cursor.close() + if description is not None and not return_all and truncated: + rowcount = -1 + return cls(rowcount, lastrowid, description, rows, truncated) + + def fetchall(self): + rows = self._rows + self._rows = [] + return rows + + class Results: def __init__(self, rows, truncated, description): self.rows = rows diff --git a/datasette/templates/_query_results.html b/datasette/templates/_query_results.html new file mode 100644 index 00000000..5e1e2f72 --- /dev/null +++ b/datasette/templates/_query_results.html @@ -0,0 +1,20 @@ +{% if display_rows %} +
| {{ column }} | {% endfor %} +
|---|
| {{ td }} | + {% endfor %} +
0 results
+{% endif %} diff --git a/datasette/templates/execute_write.html b/datasette/templates/execute_write.html index 394261de..a93de3a6 100644 --- a/datasette/templates/execute_write.html +++ b/datasette/templates/execute_write.html @@ -81,6 +81,17 @@ form.sql.core input[data-execute-write-submit]:disabled { {% endif %} +{% if execute_write_returns_rows %} +