database.execute_write_fn(transaction=True) parameter, closes #2277

This commit is contained in:
Simon Willison 2024-02-17 20:28:15 -08:00
commit 5e0e440f2c
3 changed files with 60 additions and 13 deletions

View file

@ -179,17 +179,25 @@ class Database:
# Threaded mode - send to write thread
return await self._send_to_write_thread(fn, isolated_connection=True)
async def execute_write_fn(self, fn, block=True):
async def execute_write_fn(self, fn, block=True, transaction=True):
if self.ds.executor is None:
# non-threaded mode
if self._write_connection is None:
self._write_connection = self.connect(write=True)
self.ds._prepare_connection(self._write_connection, self.name)
return fn(self._write_connection)
if transaction:
with self._write_connection:
return fn(self._write_connection)
else:
return fn(self._write_connection)
else:
return await self._send_to_write_thread(fn, block)
return await self._send_to_write_thread(
fn, block=block, transaction=transaction
)
async def _send_to_write_thread(self, fn, block=True, isolated_connection=False):
async def _send_to_write_thread(
self, fn, block=True, isolated_connection=False, transaction=True
):
if self._write_queue is None:
self._write_queue = queue.Queue()
if self._write_thread is None:
@ -202,7 +210,9 @@ class Database:
self._write_thread.start()
task_id = uuid.uuid5(uuid.NAMESPACE_DNS, "datasette.io")
reply_queue = janus.Queue()
self._write_queue.put(WriteTask(fn, task_id, reply_queue, isolated_connection))
self._write_queue.put(
WriteTask(fn, task_id, reply_queue, isolated_connection, transaction)
)
if block:
result = await reply_queue.async_q.get()
if isinstance(result, Exception):
@ -244,7 +254,11 @@ class Database:
pass
else:
try:
result = task.fn(conn)
if task.transaction:
with conn:
result = task.fn(conn)
else:
result = task.fn(conn)
except Exception as e:
sys.stderr.write("{}\n".format(e))
sys.stderr.flush()
@ -554,13 +568,14 @@ class Database:
class WriteTask:
__slots__ = ("fn", "task_id", "reply_queue", "isolated_connection")
__slots__ = ("fn", "task_id", "reply_queue", "isolated_connection", "transaction")
def __init__(self, fn, task_id, reply_queue, isolated_connection):
def __init__(self, fn, task_id, reply_queue, isolated_connection, transaction):
self.fn = fn
self.task_id = task_id
self.reply_queue = reply_queue
self.isolated_connection = isolated_connection
self.transaction = transaction
class QueryInterrupted(Exception):