.execute_write() and .execute_write_fn() methods on Database (#683)

Closes #682.
This commit is contained in:
Simon Willison 2020-02-24 20:45:08 -08:00 committed by GitHub
commit a093c5f79f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 282 additions and 95 deletions

View file

@ -1,7 +1,10 @@
import asyncio
import contextlib
from pathlib import Path
import janus
import queue
import threading
import uuid
from .tracer import trace
from .utils import (
@ -30,6 +33,8 @@ class Database:
self.hash = None
self.cached_size = None
self.cached_table_counts = None
self._write_thread = None
self._write_queue = None
if not self.is_mutable:
p = Path(path)
self.hash = inspect_hash(p)
@ -41,18 +46,60 @@ class Database:
for key, value in self.ds.inspect_data[self.name]["tables"].items()
}
def connect(self):
def connect(self, write=False):
if self.is_memory:
return sqlite3.connect(":memory:")
# mode=ro or immutable=1?
if self.is_mutable:
qs = "mode=ro"
qs = "?mode=ro"
else:
qs = "immutable=1"
qs = "?immutable=1"
assert not (write and not self.is_mutable)
if write:
qs = ""
return sqlite3.connect(
"file:{}?{}".format(self.path, qs), uri=True, check_same_thread=False
"file:{}{}".format(self.path, qs), uri=True, check_same_thread=False
)
async def execute_write(self, sql, params=None, block=False):
def _inner(conn):
with conn:
return conn.execute(sql, params or [])
return await self.execute_write_fn(_inner, block=block)
async def execute_write_fn(self, fn, block=False):
task_id = uuid.uuid5(uuid.NAMESPACE_DNS, "datasette.io")
if self._write_queue is None:
self._write_queue = queue.Queue()
if self._write_thread is None:
self._write_thread = threading.Thread(
target=self._execute_writes, daemon=True
)
self._write_thread.start()
reply_queue = janus.Queue()
self._write_queue.put(WriteTask(fn, task_id, reply_queue))
if block:
result = await reply_queue.async_q.get()
if isinstance(result, Exception):
raise result
else:
return result
else:
return task_id
def _execute_writes(self):
# Infinite looping thread that protects the single write connection
# to this database
conn = self.connect(write=True)
while True:
task = self._write_queue.get()
try:
result = task.fn(conn)
except Exception as e:
result = e
task.reply_queue.sync_q.put(result)
async def execute_against_connection_in_thread(self, fn):
def in_thread():
conn = getattr(connections, self.name, None)
@ -326,3 +373,12 @@ class Database:
if tags:
tags_str = " ({})".format(", ".join(tags))
return "<Database: {}{}>".format(self.name, tags_str)
class WriteTask:
__slots__ = ("fn", "task_id", "reply_queue")
def __init__(self, fn, task_id, reply_queue):
self.fn = fn
self.task_id = task_id
self.reply_queue = reply_queue