mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
.execute_write() and .execute_write_fn() methods on Database (#683)
Closes #682.
This commit is contained in:
parent
411056c4c4
commit
a093c5f79f
8 changed files with 282 additions and 95 deletions
|
|
@ -1,7 +1,10 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
from pathlib import Path
|
||||
import janus
|
||||
import queue
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
from .tracer import trace
|
||||
from .utils import (
|
||||
|
|
@ -30,6 +33,8 @@ class Database:
|
|||
self.hash = None
|
||||
self.cached_size = None
|
||||
self.cached_table_counts = None
|
||||
self._write_thread = None
|
||||
self._write_queue = None
|
||||
if not self.is_mutable:
|
||||
p = Path(path)
|
||||
self.hash = inspect_hash(p)
|
||||
|
|
@ -41,18 +46,60 @@ class Database:
|
|||
for key, value in self.ds.inspect_data[self.name]["tables"].items()
|
||||
}
|
||||
|
||||
def connect(self):
|
||||
def connect(self, write=False):
|
||||
if self.is_memory:
|
||||
return sqlite3.connect(":memory:")
|
||||
# mode=ro or immutable=1?
|
||||
if self.is_mutable:
|
||||
qs = "mode=ro"
|
||||
qs = "?mode=ro"
|
||||
else:
|
||||
qs = "immutable=1"
|
||||
qs = "?immutable=1"
|
||||
assert not (write and not self.is_mutable)
|
||||
if write:
|
||||
qs = ""
|
||||
return sqlite3.connect(
|
||||
"file:{}?{}".format(self.path, qs), uri=True, check_same_thread=False
|
||||
"file:{}{}".format(self.path, qs), uri=True, check_same_thread=False
|
||||
)
|
||||
|
||||
async def execute_write(self, sql, params=None, block=False):
|
||||
def _inner(conn):
|
||||
with conn:
|
||||
return conn.execute(sql, params or [])
|
||||
|
||||
return await self.execute_write_fn(_inner, block=block)
|
||||
|
||||
async def execute_write_fn(self, fn, block=False):
|
||||
task_id = uuid.uuid5(uuid.NAMESPACE_DNS, "datasette.io")
|
||||
if self._write_queue is None:
|
||||
self._write_queue = queue.Queue()
|
||||
if self._write_thread is None:
|
||||
self._write_thread = threading.Thread(
|
||||
target=self._execute_writes, daemon=True
|
||||
)
|
||||
self._write_thread.start()
|
||||
reply_queue = janus.Queue()
|
||||
self._write_queue.put(WriteTask(fn, task_id, reply_queue))
|
||||
if block:
|
||||
result = await reply_queue.async_q.get()
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
else:
|
||||
return result
|
||||
else:
|
||||
return task_id
|
||||
|
||||
def _execute_writes(self):
|
||||
# Infinite looping thread that protects the single write connection
|
||||
# to this database
|
||||
conn = self.connect(write=True)
|
||||
while True:
|
||||
task = self._write_queue.get()
|
||||
try:
|
||||
result = task.fn(conn)
|
||||
except Exception as e:
|
||||
result = e
|
||||
task.reply_queue.sync_q.put(result)
|
||||
|
||||
async def execute_against_connection_in_thread(self, fn):
|
||||
def in_thread():
|
||||
conn = getattr(connections, self.name, None)
|
||||
|
|
@ -326,3 +373,12 @@ class Database:
|
|||
if tags:
|
||||
tags_str = " ({})".format(", ".join(tags))
|
||||
return "<Database: {}{}>".format(self.name, tags_str)
|
||||
|
||||
|
||||
class WriteTask:
|
||||
__slots__ = ("fn", "task_id", "reply_queue")
|
||||
|
||||
def __init__(self, fn, task_id, reply_queue):
|
||||
self.fn = fn
|
||||
self.task_id = task_id
|
||||
self.reply_queue = reply_queue
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue