New Pool/ConnectionGroup implementation

Refs #569
This commit is contained in:
Simon Willison 2019-11-15 14:56:30 -08:00
commit 32cbfd2acd
2 changed files with 156 additions and 0 deletions

View file

@ -327,3 +327,94 @@ class Database:
if tags:
tags_str = " ({})".format(", ".join(tags))
return "<Database: {}{}>".format(self.name, tags_str)
class ConnectionError(Exception):
pass
class ConnectionTimeoutError(ConnectionError):
pass
class Connection:
def __init__(self, name, connect_args=None, connect_kwargs=None):
self.name = name
self.connect_args = connect_args or tuple()
self.connect_kwargs = connect_kwargs or {}
self.lock = threading.Lock()
self._connection = None
def connection(self):
if self._connection is None:
self._connection = sqlite3.connect(
*self.connect_args, **self.connect_kwargs
)
return self._connection
def __repr__(self):
return "{} {} ({})".format(self.name, self._connection, self.lock)
class ConnectionGroup:
timeout = 2
def __init__(self, name, connect_args, connect_kwargs=None, limit=3):
self.name = name
self.connections = [
Connection(name, connect_args, connect_kwargs) for _ in range(limit)
]
self.limit = limit
self._semaphore = threading.Semaphore(value=limit)
@contextlib.contextmanager
def connection(self):
semaphore_aquired = False
reserved_connection = None
try:
semaphore_aquired = self._semaphore.acquire(timeout=self.timeout)
if not semaphore_aquired:
raise ConnectionTimeoutError(
"Timed out after {}s waiting for connection '{}'".format(
self.timeout, self.name
)
)
# Loop through connections attempting to aquire a lock
for connection in self.connections:
lock = connection.lock
if lock.acquire(False):
# We acquired the lock! use this one
reserved_connection = connection
break
else:
# If we get here, we failed to lock a connection even though
# the semaphore should have guaranteed it
raise ConnectionError(
"Failed to lock a connection despite the sempahore"
)
# We should have a connection now - yield it, then clean up locks
yield reserved_connection
finally:
reserved_connection.lock.release()
if semaphore_aquired:
self._semaphore.release()
class Pool:
def __init__(self, databases=None, max_connections_per_database=3):
self.max_connections_per_database = max_connections_per_database
self.databases = {}
self.connection_groups = {}
for key, value in (databases or {}).items():
self.add_database(key, value)
def add_database(self, name, filepath):
self.databases[name] = filepath
self.connection_groups[name] = ConnectionGroup(
name, [filepath], limit=self.max_connections_per_database
)
@contextlib.contextmanager
def connection(self, name):
with self.connection_groups[name].connection() as conn:
yield conn

65
tests/test_pool.py Normal file
View file

@ -0,0 +1,65 @@
import threading
import time
import pytest
from datasette.database import Pool
def test_lock_connection():
pool = Pool({"one": ":memory:"})
with pool.connection("one") as conn:
assert conn.lock.locked()
assert not conn.lock.locked()
def test_connect_if_one_connection_is_locked():
pool = Pool({"one": ":memory:"})
connections = pool.connection_groups["one"].connections
assert 3 == len(connections)
# They should all start unlocked:
assert all(not c.lock.locked() for c in connections)
# Now lock one for the duration of this test
first_connection = connections[0]
try:
first_connection.lock.acquire()
# This should give us a different connection
with pool.connection("one") as conn:
assert conn is not first_connection
assert conn.lock.locked()
# There should be only one UNLOCKED connection now
assert 1 == len([c for c in connections if not c.lock.locked()])
finally:
first_connection.lock.release()
# At this point, all connections should be unlocked
assert 3 == len([c for c in connections if not c.lock.locked()])
def test_block_until_connection_is_released():
# If all connections are already in use, block until one is released
pool = Pool({"one": ":memory:"}, max_connections_per_database=1)
connections = pool.connection_groups["one"].connections
assert 1 == len(connections)
def block_connection(pool):
with pool.connection("one"):
time.sleep(0.05)
t = threading.Thread(target=block_connection, args=[pool])
t.start()
# Give thread time to grab the connection:
time.sleep(0.01)
# Thread should now have grabbed and locked a connection:
assert 1 == len([c for c in connections if c.lock.locked()])
start = time.time()
# Now we attempt to use the connection. This should block.
with pool.connection("one") as conn:
# At this point, over 0.02 seconds should have passed
assert (time.time() - start) > 0.02
assert conn.lock.locked()
# Ensure thread has run to completion before ending test:
t.join()
# Connections should all be unlocked at the end
assert all(not c.lock.locked() for c in connections)