From f95ac19e7116335b8c87a2d75fde63f6cfdc7c3a Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 6 Feb 2025 10:32:47 -0800 Subject: [PATCH] Fix to support replacing a database, closes #2465 --- datasette/database.py | 7 +++++-- tests/test_internals_database.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/datasette/database.py b/datasette/database.py index 4a0babfb..554f9fbf 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -32,6 +32,7 @@ AttachedDatabase = namedtuple("AttachedDatabase", ("seq", "name", "file")) class Database: # For table counts stop at this many rows: count_limit = 10000 + _thread_local_id_counter = 1 def __init__( self, @@ -43,6 +44,8 @@ class Database: mode=None, ): self.name = None + self._thread_local_id = f"x{self._thread_local_id_counter}" + Database._thread_local_id_counter += 1 self.route = None self.ds = ds self.path = path @@ -278,11 +281,11 @@ class Database: # threaded mode def in_thread(): - conn = getattr(connections, self.name, None) + conn = getattr(connections, self._thread_local_id, None) if not conn: conn = self.connect() self.ds._prepare_connection(conn, self.name) - setattr(connections, self.name, conn) + setattr(connections, self._thread_local_id, conn) return fn(conn) return await asyncio.get_event_loop().run_in_executor( diff --git a/tests/test_internals_database.py b/tests/test_internals_database.py index edfc6bc7..eeaf8e9a 100644 --- a/tests/test_internals_database.py +++ b/tests/test_internals_database.py @@ -721,3 +721,34 @@ async def test_hidden_tables(app_client): "r_parent", "r_rowid", ] + + +@pytest.mark.asyncio +async def test_replace_database(tmpdir): + path1 = str(tmpdir / "data1.db") + (tmpdir / "two").mkdir() + path2 = str(tmpdir / "two" / "data1.db") + sqlite3.connect(path1).executescript( + """ + create table t (id integer primary key); + insert into t (id) values (1); + insert into t (id) values (2); + """ + ) + sqlite3.connect(path2).executescript( + """ + create table t (id integer primary key); + insert into t (id) values (1); + """ + ) + datasette = Datasette([path1]) + db = datasette.get_database("data1") + count = (await db.execute("select count(*) from t")).first()[0] + assert count == 2 + # Now replace that database + datasette.get_database("data1").close() + datasette.remove_database("data1") + datasette.add_database(Database(datasette, path2), "data1") + db2 = datasette.get_database("data1") + count = (await db2.execute("select count(*) from t")).first()[0] + assert count == 1