mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
parent
a9909c29cc
commit
32cbfd2acd
2 changed files with 156 additions and 0 deletions
|
|
@ -327,3 +327,94 @@ class Database:
|
||||||
if tags:
|
if tags:
|
||||||
tags_str = " ({})".format(", ".join(tags))
|
tags_str = " ({})".format(", ".join(tags))
|
||||||
return "<Database: {}{}>".format(self.name, tags_str)
|
return "<Database: {}{}>".format(self.name, tags_str)
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionTimeoutError(ConnectionError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Connection:
|
||||||
|
def __init__(self, name, connect_args=None, connect_kwargs=None):
|
||||||
|
self.name = name
|
||||||
|
self.connect_args = connect_args or tuple()
|
||||||
|
self.connect_kwargs = connect_kwargs or {}
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
self._connection = None
|
||||||
|
|
||||||
|
def connection(self):
|
||||||
|
if self._connection is None:
|
||||||
|
self._connection = sqlite3.connect(
|
||||||
|
*self.connect_args, **self.connect_kwargs
|
||||||
|
)
|
||||||
|
return self._connection
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{} {} ({})".format(self.name, self._connection, self.lock)
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionGroup:
|
||||||
|
timeout = 2
|
||||||
|
|
||||||
|
def __init__(self, name, connect_args, connect_kwargs=None, limit=3):
|
||||||
|
self.name = name
|
||||||
|
self.connections = [
|
||||||
|
Connection(name, connect_args, connect_kwargs) for _ in range(limit)
|
||||||
|
]
|
||||||
|
self.limit = limit
|
||||||
|
self._semaphore = threading.Semaphore(value=limit)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def connection(self):
|
||||||
|
semaphore_aquired = False
|
||||||
|
reserved_connection = None
|
||||||
|
try:
|
||||||
|
semaphore_aquired = self._semaphore.acquire(timeout=self.timeout)
|
||||||
|
if not semaphore_aquired:
|
||||||
|
raise ConnectionTimeoutError(
|
||||||
|
"Timed out after {}s waiting for connection '{}'".format(
|
||||||
|
self.timeout, self.name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Loop through connections attempting to aquire a lock
|
||||||
|
for connection in self.connections:
|
||||||
|
lock = connection.lock
|
||||||
|
if lock.acquire(False):
|
||||||
|
# We acquired the lock! use this one
|
||||||
|
reserved_connection = connection
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# If we get here, we failed to lock a connection even though
|
||||||
|
# the semaphore should have guaranteed it
|
||||||
|
raise ConnectionError(
|
||||||
|
"Failed to lock a connection despite the sempahore"
|
||||||
|
)
|
||||||
|
# We should have a connection now - yield it, then clean up locks
|
||||||
|
yield reserved_connection
|
||||||
|
finally:
|
||||||
|
reserved_connection.lock.release()
|
||||||
|
if semaphore_aquired:
|
||||||
|
self._semaphore.release()
|
||||||
|
|
||||||
|
|
||||||
|
class Pool:
|
||||||
|
def __init__(self, databases=None, max_connections_per_database=3):
|
||||||
|
self.max_connections_per_database = max_connections_per_database
|
||||||
|
self.databases = {}
|
||||||
|
self.connection_groups = {}
|
||||||
|
for key, value in (databases or {}).items():
|
||||||
|
self.add_database(key, value)
|
||||||
|
|
||||||
|
def add_database(self, name, filepath):
|
||||||
|
self.databases[name] = filepath
|
||||||
|
self.connection_groups[name] = ConnectionGroup(
|
||||||
|
name, [filepath], limit=self.max_connections_per_database
|
||||||
|
)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def connection(self, name):
|
||||||
|
with self.connection_groups[name].connection() as conn:
|
||||||
|
yield conn
|
||||||
|
|
|
||||||
65
tests/test_pool.py
Normal file
65
tests/test_pool.py
Normal file
|
|
@ -0,0 +1,65 @@
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from datasette.database import Pool
|
||||||
|
|
||||||
|
|
||||||
|
def test_lock_connection():
|
||||||
|
pool = Pool({"one": ":memory:"})
|
||||||
|
with pool.connection("one") as conn:
|
||||||
|
assert conn.lock.locked()
|
||||||
|
assert not conn.lock.locked()
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_if_one_connection_is_locked():
|
||||||
|
pool = Pool({"one": ":memory:"})
|
||||||
|
connections = pool.connection_groups["one"].connections
|
||||||
|
assert 3 == len(connections)
|
||||||
|
# They should all start unlocked:
|
||||||
|
assert all(not c.lock.locked() for c in connections)
|
||||||
|
# Now lock one for the duration of this test
|
||||||
|
first_connection = connections[0]
|
||||||
|
try:
|
||||||
|
first_connection.lock.acquire()
|
||||||
|
# This should give us a different connection
|
||||||
|
with pool.connection("one") as conn:
|
||||||
|
assert conn is not first_connection
|
||||||
|
assert conn.lock.locked()
|
||||||
|
# There should be only one UNLOCKED connection now
|
||||||
|
assert 1 == len([c for c in connections if not c.lock.locked()])
|
||||||
|
finally:
|
||||||
|
first_connection.lock.release()
|
||||||
|
# At this point, all connections should be unlocked
|
||||||
|
assert 3 == len([c for c in connections if not c.lock.locked()])
|
||||||
|
|
||||||
|
|
||||||
|
def test_block_until_connection_is_released():
|
||||||
|
# If all connections are already in use, block until one is released
|
||||||
|
pool = Pool({"one": ":memory:"}, max_connections_per_database=1)
|
||||||
|
connections = pool.connection_groups["one"].connections
|
||||||
|
assert 1 == len(connections)
|
||||||
|
|
||||||
|
def block_connection(pool):
|
||||||
|
with pool.connection("one"):
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
t = threading.Thread(target=block_connection, args=[pool])
|
||||||
|
t.start()
|
||||||
|
# Give thread time to grab the connection:
|
||||||
|
time.sleep(0.01)
|
||||||
|
# Thread should now have grabbed and locked a connection:
|
||||||
|
assert 1 == len([c for c in connections if c.lock.locked()])
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
# Now we attempt to use the connection. This should block.
|
||||||
|
with pool.connection("one") as conn:
|
||||||
|
# At this point, over 0.02 seconds should have passed
|
||||||
|
assert (time.time() - start) > 0.02
|
||||||
|
assert conn.lock.locked()
|
||||||
|
|
||||||
|
# Ensure thread has run to completion before ending test:
|
||||||
|
t.join()
|
||||||
|
# Connections should all be unlocked at the end
|
||||||
|
assert all(not c.lock.locked() for c in connections)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue