mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
parent
a9909c29cc
commit
32cbfd2acd
2 changed files with 156 additions and 0 deletions
|
|
@ -327,3 +327,94 @@ class Database:
|
|||
if tags:
|
||||
tags_str = " ({})".format(", ".join(tags))
|
||||
return "<Database: {}{}>".format(self.name, tags_str)
|
||||
|
||||
|
||||
class ConnectionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ConnectionTimeoutError(ConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class Connection:
|
||||
def __init__(self, name, connect_args=None, connect_kwargs=None):
|
||||
self.name = name
|
||||
self.connect_args = connect_args or tuple()
|
||||
self.connect_kwargs = connect_kwargs or {}
|
||||
self.lock = threading.Lock()
|
||||
self._connection = None
|
||||
|
||||
def connection(self):
|
||||
if self._connection is None:
|
||||
self._connection = sqlite3.connect(
|
||||
*self.connect_args, **self.connect_kwargs
|
||||
)
|
||||
return self._connection
|
||||
|
||||
def __repr__(self):
|
||||
return "{} {} ({})".format(self.name, self._connection, self.lock)
|
||||
|
||||
|
||||
class ConnectionGroup:
|
||||
timeout = 2
|
||||
|
||||
def __init__(self, name, connect_args, connect_kwargs=None, limit=3):
|
||||
self.name = name
|
||||
self.connections = [
|
||||
Connection(name, connect_args, connect_kwargs) for _ in range(limit)
|
||||
]
|
||||
self.limit = limit
|
||||
self._semaphore = threading.Semaphore(value=limit)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def connection(self):
|
||||
semaphore_aquired = False
|
||||
reserved_connection = None
|
||||
try:
|
||||
semaphore_aquired = self._semaphore.acquire(timeout=self.timeout)
|
||||
if not semaphore_aquired:
|
||||
raise ConnectionTimeoutError(
|
||||
"Timed out after {}s waiting for connection '{}'".format(
|
||||
self.timeout, self.name
|
||||
)
|
||||
)
|
||||
# Loop through connections attempting to aquire a lock
|
||||
for connection in self.connections:
|
||||
lock = connection.lock
|
||||
if lock.acquire(False):
|
||||
# We acquired the lock! use this one
|
||||
reserved_connection = connection
|
||||
break
|
||||
else:
|
||||
# If we get here, we failed to lock a connection even though
|
||||
# the semaphore should have guaranteed it
|
||||
raise ConnectionError(
|
||||
"Failed to lock a connection despite the sempahore"
|
||||
)
|
||||
# We should have a connection now - yield it, then clean up locks
|
||||
yield reserved_connection
|
||||
finally:
|
||||
reserved_connection.lock.release()
|
||||
if semaphore_aquired:
|
||||
self._semaphore.release()
|
||||
|
||||
|
||||
class Pool:
|
||||
def __init__(self, databases=None, max_connections_per_database=3):
|
||||
self.max_connections_per_database = max_connections_per_database
|
||||
self.databases = {}
|
||||
self.connection_groups = {}
|
||||
for key, value in (databases or {}).items():
|
||||
self.add_database(key, value)
|
||||
|
||||
def add_database(self, name, filepath):
|
||||
self.databases[name] = filepath
|
||||
self.connection_groups[name] = ConnectionGroup(
|
||||
name, [filepath], limit=self.max_connections_per_database
|
||||
)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def connection(self, name):
|
||||
with self.connection_groups[name].connection() as conn:
|
||||
yield conn
|
||||
|
|
|
|||
65
tests/test_pool.py
Normal file
65
tests/test_pool.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from datasette.database import Pool
|
||||
|
||||
|
||||
def test_lock_connection():
|
||||
pool = Pool({"one": ":memory:"})
|
||||
with pool.connection("one") as conn:
|
||||
assert conn.lock.locked()
|
||||
assert not conn.lock.locked()
|
||||
|
||||
|
||||
def test_connect_if_one_connection_is_locked():
|
||||
pool = Pool({"one": ":memory:"})
|
||||
connections = pool.connection_groups["one"].connections
|
||||
assert 3 == len(connections)
|
||||
# They should all start unlocked:
|
||||
assert all(not c.lock.locked() for c in connections)
|
||||
# Now lock one for the duration of this test
|
||||
first_connection = connections[0]
|
||||
try:
|
||||
first_connection.lock.acquire()
|
||||
# This should give us a different connection
|
||||
with pool.connection("one") as conn:
|
||||
assert conn is not first_connection
|
||||
assert conn.lock.locked()
|
||||
# There should be only one UNLOCKED connection now
|
||||
assert 1 == len([c for c in connections if not c.lock.locked()])
|
||||
finally:
|
||||
first_connection.lock.release()
|
||||
# At this point, all connections should be unlocked
|
||||
assert 3 == len([c for c in connections if not c.lock.locked()])
|
||||
|
||||
|
||||
def test_block_until_connection_is_released():
|
||||
# If all connections are already in use, block until one is released
|
||||
pool = Pool({"one": ":memory:"}, max_connections_per_database=1)
|
||||
connections = pool.connection_groups["one"].connections
|
||||
assert 1 == len(connections)
|
||||
|
||||
def block_connection(pool):
|
||||
with pool.connection("one"):
|
||||
time.sleep(0.05)
|
||||
|
||||
t = threading.Thread(target=block_connection, args=[pool])
|
||||
t.start()
|
||||
# Give thread time to grab the connection:
|
||||
time.sleep(0.01)
|
||||
# Thread should now have grabbed and locked a connection:
|
||||
assert 1 == len([c for c in connections if c.lock.locked()])
|
||||
|
||||
start = time.time()
|
||||
# Now we attempt to use the connection. This should block.
|
||||
with pool.connection("one") as conn:
|
||||
# At this point, over 0.02 seconds should have passed
|
||||
assert (time.time() - start) > 0.02
|
||||
assert conn.lock.locked()
|
||||
|
||||
# Ensure thread has run to completion before ending test:
|
||||
t.join()
|
||||
# Connections should all be unlocked at the end
|
||||
assert all(not c.lock.locked() for c in connections)
|
||||
Loading…
Add table
Add a link
Reference in a new issue