diff --git a/datasette/database.py b/datasette/database.py index c92d1f76..9641ae32 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -908,16 +908,18 @@ class ExecuteWriteResult: truncated = False description = cursor.description lastrowid = cursor.lastrowid - if description is not None: - if return_all: - rows = cursor.fetchall() - else: - rows = cursor.fetchmany(returning_limit + 1) - if len(rows) > returning_limit: - rows = rows[:returning_limit] - truncated = True - rowcount = cursor.rowcount - cursor.close() + try: + if description is not None: + if return_all: + rows = cursor.fetchall() + else: + rows = cursor.fetchmany(returning_limit + 1) + if len(rows) > returning_limit: + rows = rows[:returning_limit] + truncated = True + rowcount = cursor.rowcount + finally: + cursor.close() if description is not None and not return_all and truncated and rowcount == 0: rowcount = -1 return cls(rowcount, lastrowid, description, rows, truncated) diff --git a/tests/test_internals_database.py b/tests/test_internals_database.py index ddd15adc..1df1d947 100644 --- a/tests/test_internals_database.py +++ b/tests/test_internals_database.py @@ -5,7 +5,7 @@ Tests for the datasette.database.Database class import asyncio from types import SimpleNamespace from datasette.app import Datasette -from datasette.database import Database, Results, MultipleValues +from datasette.database import Database, ExecuteWriteResult, Results, MultipleValues from datasette.database import DatasetteClosedError from datasette.database import _deliver_write_result from datasette.utils.sqlite import sqlite3 @@ -628,6 +628,29 @@ async def test_execute_write_with_returning_block_false(db): ).single_value() == "Cleo" +def test_execute_write_result_closes_cursor_on_fetch_error(): + class Cursor: + description = (("id", None, None, None, None, None, None),) + lastrowid = 1 + rowcount = 0 + + def __init__(self): + self.closed = False + + def fetchmany(self, size): + raise sqlite3.DatabaseError("fetch failed") + + def close(self): + self.closed = True + + cursor = Cursor() + + with pytest.raises(sqlite3.DatabaseError): + ExecuteWriteResult.from_cursor(cursor) + + assert cursor.closed is True + + @pytest.mark.asyncio async def test_execute_write_script(db): await db.execute_write_script(