Compare commits

...

2 commits

Author SHA1 Message Date
Simon Willison
8895c4a202 Fixed the rest of the warnings
Refs https://github.com/simonw/datasette/pull/2615#issuecomment-3649771920
2025-12-13 12:16:57 -08:00
Simon Willison
35ea721469 Fixed hundreds of database connection closing warnings
From 409 warnings down to 52 warnings.

Claude Code says:

Fixed connection leaks in:
1. datasette/utils/sqlite.py - _sqlite_version() now closes connection
2. datasette/cli.py - --create flag now closes connection
3. datasette/app.py - _versions() now closes connection
4. datasette/utils/__init__.py - detect_json1() now closes connection when created internally
5. tests/conftest.py - pytest_report_header() now closes connection
6. tests/utils.py - has_load_extension() now closes connection
7. tests/fixtures.py - app_client_no_files and CLI fixtures now close connections
8. tests/test_api_write.py - ds_write fixture closes both connections
9. tests/test_cli.py - Multiple test functions now close connections
10. tests/test_config_dir.py - config_dir and config_dir_client fixtures now close connections
11. tests/test_crossdb.py - Loop connections now closed
12. tests/test_internals_database.py - Test setup connections now closed
13. tests/test_plugins.py - view_names_client fixture and test now close connections
14. tests/test_utils.py - Multiple test functions now close connections

Refs #2614
2025-12-12 22:38:04 -08:00
19 changed files with 603 additions and 292 deletions

View file

@ -1608,6 +1608,7 @@ class Datasette:
break break
except importlib.metadata.PackageNotFoundError: except importlib.metadata.PackageNotFoundError:
pass pass
conn.close()
return info return info
def _plugins(self, request=None, all=False): def _plugins(self, request=None, all=False):

View file

@ -619,7 +619,9 @@ def serve(
for file in file_paths: for file in file_paths:
if not pathlib.Path(file).exists(): if not pathlib.Path(file).exists():
if create: if create:
sqlite3.connect(file).execute("vacuum") conn = sqlite3.connect(file)
conn.execute("vacuum")
conn.close()
else: else:
raise click.ClickException( raise click.ClickException(
"Invalid value for '[FILES]...': Path '{}' does not exist.".format( "Invalid value for '[FILES]...': Path '{}' does not exist.".format(

View file

@ -28,6 +28,9 @@ connections = threading.local()
AttachedDatabase = namedtuple("AttachedDatabase", ("seq", "name", "file")) AttachedDatabase = namedtuple("AttachedDatabase", ("seq", "name", "file"))
# Sentinel object to signal write thread shutdown
_SHUTDOWN_SENTINEL = object()
class Database: class Database:
# For table counts stop at this many rows: # For table counts stop at this many rows:
@ -62,10 +65,25 @@ class Database:
# These are used when in non-threaded mode: # These are used when in non-threaded mode:
self._read_connection = None self._read_connection = None
self._write_connection = None self._write_connection = None
# This is used to track all file connections so they can be closed # This is used to track all connections so they can be closed
self._all_file_connections = [] self._all_connections = []
self._closed = False
self.mode = mode self.mode = mode
def __del__(self):
# Ensure connections are closed when Database is garbage collected
# This prevents ResourceWarning about unclosed database connections
if not self._closed:
# Close all tracked connections without executor cleanup
# (executor might already be gone during garbage collection)
for connection in self._all_connections:
try:
connection.close()
except Exception:
pass
self._all_connections.clear()
self._closed = True
@property @property
def cached_table_counts(self): def cached_table_counts(self):
if self._cached_table_counts is not None: if self._cached_table_counts is not None:
@ -103,9 +121,12 @@ class Database:
) )
if not write: if not write:
conn.execute("PRAGMA query_only=1") conn.execute("PRAGMA query_only=1")
self._all_connections.append(conn)
return conn return conn
if self.is_memory: if self.is_memory:
return sqlite3.connect(":memory:", uri=True) conn = sqlite3.connect(":memory:", uri=True, check_same_thread=False)
self._all_connections.append(conn)
return conn
# mode=ro or immutable=1? # mode=ro or immutable=1?
if self.is_mutable: if self.is_mutable:
@ -122,13 +143,69 @@ class Database:
conn = sqlite3.connect( conn = sqlite3.connect(
f"file:{self.path}{qs}", uri=True, check_same_thread=False, **extra_kwargs f"file:{self.path}{qs}", uri=True, check_same_thread=False, **extra_kwargs
) )
self._all_file_connections.append(conn) self._all_connections.append(conn)
return conn return conn
def close(self): def close(self):
# Close all connections - useful to avoid running out of file handles in tests # Close all connections - useful to avoid running out of file handles in tests
for connection in self._all_file_connections: self._closed = True
connection.close() # First, signal the write thread to shut down if it exists
if self._write_thread is not None and self._write_queue is not None:
self._write_queue.put(_SHUTDOWN_SENTINEL)
self._write_thread.join(timeout=1.0)
# Clear the instance variable references (connections will be closed below)
self._read_connection = None
self._write_connection = None
# Close and clear thread-local connection if it exists in the current thread
main_thread_conn = getattr(connections, self._thread_local_id, None)
if main_thread_conn is not None:
try:
main_thread_conn.close()
except Exception:
pass
delattr(connections, self._thread_local_id)
# If executor is available, use a barrier to ensure cleanup runs on ALL threads
thread_local_id = self._thread_local_id
if self.ds.executor is not None:
import concurrent.futures
max_workers = getattr(self.ds.executor, "_max_workers", None) or 1
barrier = threading.Barrier(max_workers, timeout=2.0)
def clear_thread_local():
# Close and clear this database's thread-local connection in this thread
conn = getattr(connections, thread_local_id, None)
if conn is not None:
try:
conn.close()
except Exception:
pass # Connection might already be closed
delattr(connections, thread_local_id)
# Wait for all threads to reach this point - this ensures
# all threads are processing cleanup simultaneously
try:
barrier.wait()
except threading.BrokenBarrierError:
pass
try:
# Submit exactly max_workers tasks - the barrier ensures all
# threads must be occupied with our cleanup tasks
futures = [
self.ds.executor.submit(clear_thread_local)
for _ in range(max_workers)
]
# Wait for all cleanup tasks to complete
concurrent.futures.wait(futures, timeout=3.0)
except Exception:
pass # Executor might be shutting down
# Close all tracked connections
for connection in self._all_connections:
try:
connection.close()
except Exception:
pass # Connection might already be closed
self._all_connections.clear()
async def execute_write(self, sql, params=None, block=True): async def execute_write(self, sql, params=None, block=True):
def _inner(conn): def _inner(conn):
@ -178,7 +255,7 @@ class Database:
finally: finally:
isolated_connection.close() isolated_connection.close()
try: try:
self._all_file_connections.remove(isolated_connection) self._all_connections.remove(isolated_connection)
except ValueError: except ValueError:
# Was probably a memory connection # Was probably a memory connection
pass pass
@ -242,6 +319,15 @@ class Database:
conn_exception = e conn_exception = e
while True: while True:
task = self._write_queue.get() task = self._write_queue.get()
# Check for shutdown sentinel
if task is _SHUTDOWN_SENTINEL:
if conn is not None:
conn.close()
try:
self._all_connections.remove(conn)
except ValueError:
pass
return
if conn_exception is not None: if conn_exception is not None:
result = conn_exception result = conn_exception
else: else:
@ -256,7 +342,7 @@ class Database:
finally: finally:
isolated_connection.close() isolated_connection.close()
try: try:
self._all_file_connections.remove(isolated_connection) self._all_connections.remove(isolated_connection)
except ValueError: except ValueError:
# Was probably a memory connection # Was probably a memory connection
pass pass
@ -284,6 +370,10 @@ class Database:
# threaded mode # threaded mode
def in_thread(): def in_thread():
conn = getattr(connections, self._thread_local_id, None) conn = getattr(connections, self._thread_local_id, None)
# Check if database was closed - if so, clear the stale cached connection
if conn and self._closed:
delattr(connections, self._thread_local_id)
conn = None
if not conn: if not conn:
conn = self.connect() conn = self.connect()
self.ds._prepare_connection(conn, self.name) self.ds._prepare_connection(conn, self.name)

View file

@ -671,13 +671,18 @@ def detect_fts_sql(table):
def detect_json1(conn=None): def detect_json1(conn=None):
close_conn = False
if conn is None: if conn is None:
conn = sqlite3.connect(":memory:") conn = sqlite3.connect(":memory:")
close_conn = True
try: try:
conn.execute("SELECT json('{}')") conn.execute("SELECT json('{}')")
return True return True
except Exception: except Exception:
return False return False
finally:
if close_conn:
conn.close()
def table_columns(conn, table): def table_columns(conn, table):

View file

@ -20,15 +20,16 @@ def sqlite_version():
def _sqlite_version(): def _sqlite_version():
return tuple( conn = sqlite3.connect(":memory:")
map( try:
int, return tuple(
sqlite3.connect(":memory:") map(
.execute("select sqlite_version()") int,
.fetchone()[0] conn.execute("select sqlite_version()").fetchone()[0].split("."),
.split("."), )
) )
) finally:
conn.close()
def supports_table_xinfo(): def supports_table_xinfo():

View file

@ -30,6 +30,7 @@ UNDOCUMENTED_PERMISSIONS = {
} }
_ds_client = None _ds_client = None
_ds_instance = None
def wait_until_responds(url, timeout=5.0, client=httpx, **kwargs): def wait_until_responds(url, timeout=5.0, client=httpx, **kwargs):
@ -50,7 +51,7 @@ async def ds_client():
from .fixtures import CONFIG, METADATA, PLUGINS_DIR from .fixtures import CONFIG, METADATA, PLUGINS_DIR
import secrets import secrets
global _ds_client global _ds_client, _ds_instance
if _ds_client is not None: if _ds_client is not None:
return _ds_client return _ds_client
@ -86,13 +87,15 @@ async def ds_client():
await db.execute_write_fn(prepare) await db.execute_write_fn(prepare)
await ds.invoke_startup() await ds.invoke_startup()
_ds_client = ds.client _ds_client = ds.client
_ds_instance = ds
return _ds_client return _ds_client
def pytest_report_header(config): def pytest_report_header(config):
return "SQLite: {}".format( conn = sqlite3.connect(":memory:")
sqlite3.connect(":memory:").execute("select sqlite_version()").fetchone()[0] version = conn.execute("select sqlite_version()").fetchone()[0]
) conn.close()
return "SQLite: {}".format(version)
def pytest_configure(config): def pytest_configure(config):
@ -106,6 +109,19 @@ def pytest_unconfigure(config):
del sys._called_from_test del sys._called_from_test
# Clean up the global ds_client fixture
global _ds_instance
if _ds_instance is not None:
# Close databases first (while executor is still running)
for db in _ds_instance.databases.values():
db.close()
if hasattr(_ds_instance, "_internal_database"):
_ds_instance._internal_database.close()
# Then shut down executor
if _ds_instance.executor is not None:
_ds_instance.executor.shutdown(wait=True)
_ds_instance = None
def pytest_collection_modifyitems(items): def pytest_collection_modifyitems(items):
# Ensure test_cli.py and test_black.py and test_inspect.py run first before any asyncio code kicks in # Ensure test_cli.py and test_black.py and test_inspect.py run first before any asyncio code kicks in
@ -217,6 +233,27 @@ def ds_localhost_http_server():
yield ds_proc yield ds_proc
# Shut it down at the end of the pytest session # Shut it down at the end of the pytest session
ds_proc.terminate() ds_proc.terminate()
ds_proc.wait()
if ds_proc.stdout:
ds_proc.stdout.close()
def wait_until_uds_responds(uds_path, timeout=5.0):
"""Wait for a Unix domain socket to accept connections."""
import socket as socket_module
start = time.time()
while time.time() - start < timeout:
sock = socket_module.socket(socket_module.AF_UNIX, socket_module.SOCK_STREAM)
try:
sock.connect(uds_path)
# Connection successful, now close and return
sock.close()
return
except (ConnectionRefusedError, FileNotFoundError):
sock.close()
time.sleep(0.1)
raise AssertionError("Timed out waiting for {} to respond".format(uds_path))
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -232,15 +269,16 @@ def ds_unix_domain_socket_server(tmp_path_factory):
stderr=subprocess.STDOUT, stderr=subprocess.STDOUT,
cwd=tempfile.gettempdir(), cwd=tempfile.gettempdir(),
) )
# Poll until available # Poll until available using raw socket to avoid httpx connection pool leaks
transport = httpx.HTTPTransport(uds=uds) wait_until_uds_responds(uds)
client = httpx.Client(transport=transport)
wait_until_responds("http://localhost/_memory.json", client=client)
# Check it started successfully # Check it started successfully
assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8") assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8")
yield ds_proc, uds yield ds_proc, uds
# Shut it down at the end of the pytest session # Shut it down at the end of the pytest session
ds_proc.terminate() ds_proc.terminate()
ds_proc.wait()
if ds_proc.stdout:
ds_proc.stdout.close()
# Import fixtures from fixtures.py to make them available # Import fixtures from fixtures.py to make them available

View file

@ -171,11 +171,15 @@ def make_app_client(
crossdb=crossdb, crossdb=crossdb,
) )
yield TestClient(ds) yield TestClient(ds)
# Close as many database connections as possible # Close all database connections first (while executor is still running)
# to try and avoid too many open files error # This allows db.close() to submit cleanup tasks to executor threads
for db in ds.databases.values(): for db in ds.databases.values():
if not db.is_memory: db.close()
db.close() if hasattr(ds, "_internal_database"):
ds._internal_database.close()
# Then shut down executor
if ds.executor is not None:
ds.executor.shutdown(wait=True)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -188,6 +192,14 @@ def app_client():
def app_client_no_files(): def app_client_no_files():
ds = Datasette([]) ds = Datasette([])
yield TestClient(ds) yield TestClient(ds)
# Close databases first (while executor is still running)
for db in ds.databases.values():
db.close()
if hasattr(ds, "_internal_database"):
ds._internal_database.close()
# Then shut down executor
if ds.executor is not None:
ds.executor.shutdown(wait=True)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -833,6 +845,7 @@ def cli(db_filename, config, metadata, plugins_path, recreate, extra_db_filename
for sql, params in TABLE_PARAMETERIZED_SQL: for sql, params in TABLE_PARAMETERIZED_SQL:
with conn: with conn:
conn.execute(sql, params) conn.execute(sql, params)
conn.close()
print(f"Test tables written to {db_filename}") print(f"Test tables written to {db_filename}")
if metadata: if metadata:
with open(metadata, "w") as fp: with open(metadata, "w") as fp:
@ -861,6 +874,7 @@ def cli(db_filename, config, metadata, plugins_path, recreate, extra_db_filename
pathlib.Path(extra_db_filename).unlink() pathlib.Path(extra_db_filename).unlink()
conn = sqlite3.connect(extra_db_filename) conn = sqlite3.connect(extra_db_filename)
conn.executescript(EXTRA_DATABASE_SQL) conn.executescript(EXTRA_DATABASE_SQL)
conn.close()
print(f"Test tables written to {extra_db_filename}") print(f"Test tables written to {extra_db_filename}")

View file

@ -17,10 +17,11 @@ def ds_write(tmp_path_factory):
db.execute( db.execute(
"create table docs (id integer primary key, title text, score float, age integer)" "create table docs (id integer primary key, title text, score float, age integer)"
) )
db1.close()
db2.close()
ds = Datasette([db_path], immutables=[db_path_immutable]) ds = Datasette([db_path], immutables=[db_path_immutable])
ds.root_enabled = True ds.root_enabled = True
yield ds yield ds
db.close()
def write_token(ds, actor_id="root", permissions=None): def write_token(ds, actor_id="root", permissions=None):

View file

@ -442,7 +442,9 @@ def test_serve_duplicate_database_names(tmpdir):
nested.mkdir() nested.mkdir()
db_2_path = str(tmpdir / "nested" / "db.db") db_2_path = str(tmpdir / "nested" / "db.db")
for path in (db_1_path, db_2_path): for path in (db_1_path, db_2_path):
sqlite3.connect(path).execute("vacuum") conn = sqlite3.connect(path)
conn.execute("vacuum")
conn.close()
result = runner.invoke(cli, [db_1_path, db_2_path, "--get", "/-/databases.json"]) result = runner.invoke(cli, [db_1_path, db_2_path, "--get", "/-/databases.json"])
assert result.exit_code == 0, result.output assert result.exit_code == 0, result.output
databases = json.loads(result.output) databases = json.loads(result.output)
@ -456,7 +458,9 @@ def test_weird_database_names(tmpdir, filename):
# https://github.com/simonw/datasette/issues/1181 # https://github.com/simonw/datasette/issues/1181
runner = CliRunner() runner = CliRunner()
db_path = str(tmpdir / filename) db_path = str(tmpdir / filename)
sqlite3.connect(db_path).execute("vacuum") conn = sqlite3.connect(db_path)
conn.execute("vacuum")
conn.close()
result1 = runner.invoke(cli, [db_path, "--get", "/"]) result1 = runner.invoke(cli, [db_path, "--get", "/"])
assert result1.exit_code == 0, result1.output assert result1.exit_code == 0, result1.output
filename_no_stem = filename.rsplit(".", 1)[0] filename_no_stem = filename.rsplit(".", 1)[0]
@ -493,7 +497,9 @@ def test_duplicate_database_files_error(tmpdir):
"""Test that passing the same database file multiple times raises an error""" """Test that passing the same database file multiple times raises an error"""
runner = CliRunner() runner = CliRunner()
db_path = str(tmpdir / "test.db") db_path = str(tmpdir / "test.db")
sqlite3.connect(db_path).execute("vacuum") conn = sqlite3.connect(db_path)
conn.execute("vacuum")
conn.close()
# Test with exact duplicate # Test with exact duplicate
result = runner.invoke(cli, ["serve", db_path, db_path, "--get", "/"]) result = runner.invoke(cli, ["serve", db_path, db_path, "--get", "/"])
@ -512,7 +518,9 @@ def test_duplicate_database_files_error(tmpdir):
config_dir = tmpdir / "config" config_dir = tmpdir / "config"
config_dir.mkdir() config_dir.mkdir()
config_db_path = str(config_dir / "data.db") config_db_path = str(config_dir / "data.db")
sqlite3.connect(config_db_path).execute("vacuum") conn = sqlite3.connect(config_db_path)
conn.execute("vacuum")
conn.close()
result3 = runner.invoke( result3 = runner.invoke(
cli, ["serve", config_db_path, str(config_dir), "--get", "/"] cli, ["serve", config_db_path, str(config_dir), "--get", "/"]
@ -523,7 +531,9 @@ def test_duplicate_database_files_error(tmpdir):
# Test that mixing a file NOT in the directory with a directory works fine # Test that mixing a file NOT in the directory with a directory works fine
other_db_path = str(tmpdir / "other.db") other_db_path = str(tmpdir / "other.db")
sqlite3.connect(other_db_path).execute("vacuum") conn = sqlite3.connect(other_db_path)
conn.execute("vacuum")
conn.close()
result4 = runner.invoke( result4 = runner.invoke(
cli, ["serve", other_db_path, str(config_dir), "--get", "/-/databases.json"] cli, ["serve", other_db_path, str(config_dir), "--get", "/-/databases.json"]

View file

@ -5,12 +5,13 @@ import socket
@pytest.mark.serial @pytest.mark.serial
def test_serve_localhost_http(ds_localhost_http_server): def test_serve_localhost_http(ds_localhost_http_server):
response = httpx.get("http://localhost:8041/_memory.json") with httpx.Client() as client:
assert { response = client.get("http://localhost:8041/_memory.json")
"database": "_memory", assert {
"path": "/_memory", "database": "_memory",
"tables": [], "path": "/_memory",
}.items() <= response.json().items() "tables": [],
}.items() <= response.json().items()
@pytest.mark.serial @pytest.mark.serial
@ -20,10 +21,13 @@ def test_serve_localhost_http(ds_localhost_http_server):
def test_serve_unix_domain_socket(ds_unix_domain_socket_server): def test_serve_unix_domain_socket(ds_unix_domain_socket_server):
_, uds = ds_unix_domain_socket_server _, uds = ds_unix_domain_socket_server
transport = httpx.HTTPTransport(uds=uds) transport = httpx.HTTPTransport(uds=uds)
client = httpx.Client(transport=transport) try:
response = client.get("http://localhost/_memory.json") with httpx.Client(transport=transport) as client:
assert { response = client.get("http://localhost/_memory.json")
"database": "_memory", assert {
"path": "/_memory", "database": "_memory",
"tables": [], "path": "/_memory",
}.items() <= response.json().items() "tables": [],
}.items() <= response.json().items()
finally:
transport.close()

View file

@ -62,6 +62,7 @@ def config_dir(tmp_path_factory):
; ;
""" """
) )
db.close()
# Mark "immutable.db" as immutable # Mark "immutable.db" as immutable
(config_dir / "inspect-data.json").write_text( (config_dir / "inspect-data.json").write_text(
@ -97,6 +98,10 @@ def test_invalid_settings(config_dir):
def config_dir_client(config_dir): def config_dir_client(config_dir):
ds = Datasette([], config_dir=config_dir) ds = Datasette([], config_dir=config_dir)
yield _TestClient(ds) yield _TestClient(ds)
for db in ds.databases.values():
db.close()
if hasattr(ds, "_internal_database"):
ds._internal_database.close()
def test_settings(config_dir_client): def test_settings(config_dir_client):

View file

@ -43,6 +43,7 @@ def test_crossdb_warning_if_too_many_databases(tmp_path_factory):
path = str(db_dir / "db_{}.db".format(i)) path = str(db_dir / "db_{}.db".format(i))
conn = sqlite3.connect(path) conn = sqlite3.connect(path)
conn.execute("vacuum") conn.execute("vacuum")
conn.close()
dbs.append(path) dbs.append(path)
runner = CliRunner() runner = CliRunner()
result = runner.invoke( result = runner.invoke(

View file

@ -23,6 +23,14 @@ async def datasette_with_plugin():
yield datasette yield datasette
finally: finally:
datasette.pm.unregister(name="undo") datasette.pm.unregister(name="undo")
# Close databases first (while executor is still running)
for db in datasette.databases.values():
db.close()
if hasattr(datasette, "_internal_database"):
datasette._internal_database.close()
# Then shut down executor
if datasette.executor is not None:
datasette.executor.shutdown(wait=True)
# -- end datasette_with_plugin_fixture -- # -- end datasette_with_plugin_fixture --

View file

@ -407,55 +407,68 @@ async def test_array_facet_results(ds_client):
@pytest.mark.skipif(not detect_json1(), reason="Requires the SQLite json1 module") @pytest.mark.skipif(not detect_json1(), reason="Requires the SQLite json1 module")
async def test_array_facet_handle_duplicate_tags(): async def test_array_facet_handle_duplicate_tags():
ds = Datasette([], memory=True) ds = Datasette([], memory=True)
db = ds.add_database(Database(ds, memory_name="test_array_facet")) try:
await db.execute_write("create table otters(name text, tags text)") db = ds.add_database(Database(ds, memory_name="test_array_facet"))
for name, tags in ( await db.execute_write("create table otters(name text, tags text)")
("Charles", ["friendly", "cunning", "friendly"]), for name, tags in (
("Shaun", ["cunning", "empathetic", "friendly"]), ("Charles", ["friendly", "cunning", "friendly"]),
("Tracy", ["empathetic", "eager"]), ("Shaun", ["cunning", "empathetic", "friendly"]),
): ("Tracy", ["empathetic", "eager"]),
await db.execute_write( ):
"insert into otters (name, tags) values (?, ?)", [name, json.dumps(tags)] await db.execute_write(
) "insert into otters (name, tags) values (?, ?)",
[name, json.dumps(tags)],
)
response = await ds.client.get("/test_array_facet/otters.json?_facet_array=tags") response = await ds.client.get(
assert response.json()["facet_results"]["results"]["tags"] == { "/test_array_facet/otters.json?_facet_array=tags"
"name": "tags", )
"type": "array", assert response.json()["facet_results"]["results"]["tags"] == {
"results": [ "name": "tags",
{ "type": "array",
"value": "cunning", "results": [
"label": "cunning", {
"count": 2, "value": "cunning",
"toggle_url": "http://localhost/test_array_facet/otters.json?_facet_array=tags&tags__arraycontains=cunning", "label": "cunning",
"selected": False, "count": 2,
}, "toggle_url": "http://localhost/test_array_facet/otters.json?_facet_array=tags&tags__arraycontains=cunning",
{ "selected": False,
"value": "empathetic", },
"label": "empathetic", {
"count": 2, "value": "empathetic",
"toggle_url": "http://localhost/test_array_facet/otters.json?_facet_array=tags&tags__arraycontains=empathetic", "label": "empathetic",
"selected": False, "count": 2,
}, "toggle_url": "http://localhost/test_array_facet/otters.json?_facet_array=tags&tags__arraycontains=empathetic",
{ "selected": False,
"value": "friendly", },
"label": "friendly", {
"count": 2, "value": "friendly",
"toggle_url": "http://localhost/test_array_facet/otters.json?_facet_array=tags&tags__arraycontains=friendly", "label": "friendly",
"selected": False, "count": 2,
}, "toggle_url": "http://localhost/test_array_facet/otters.json?_facet_array=tags&tags__arraycontains=friendly",
{ "selected": False,
"value": "eager", },
"label": "eager", {
"count": 1, "value": "eager",
"toggle_url": "http://localhost/test_array_facet/otters.json?_facet_array=tags&tags__arraycontains=eager", "label": "eager",
"selected": False, "count": 1,
}, "toggle_url": "http://localhost/test_array_facet/otters.json?_facet_array=tags&tags__arraycontains=eager",
], "selected": False,
"hideable": True, },
"toggle_url": "/test_array_facet/otters.json", ],
"truncated": False, "hideable": True,
} "toggle_url": "/test_array_facet/otters.json",
"truncated": False,
}
finally:
# Close databases first (while executor is still running)
for db_obj in ds.databases.values():
db_obj.close()
if hasattr(ds, "_internal_database"):
ds._internal_database.close()
# Then shut down executor
if ds.executor is not None:
ds.executor.shutdown(wait=True)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -513,99 +526,124 @@ async def test_date_facet_results(ds_client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_json_array_with_blanks_and_nulls(): async def test_json_array_with_blanks_and_nulls():
ds = Datasette([], memory=True) ds = Datasette([], memory=True)
db = ds.add_database(Database(ds, memory_name="test_json_array")) try:
await db.execute_write("create table foo(json_column text)") db = ds.add_database(Database(ds, memory_name="test_json_array"))
for value in ('["a", "b", "c"]', '["a", "b"]', "", None): await db.execute_write("create table foo(json_column text)")
await db.execute_write("insert into foo (json_column) values (?)", [value]) for value in ('["a", "b", "c"]', '["a", "b"]', "", None):
response = await ds.client.get("/test_json_array/foo.json?_extra=suggested_facets") await db.execute_write("insert into foo (json_column) values (?)", [value])
data = response.json() response = await ds.client.get(
assert data["suggested_facets"] == [ "/test_json_array/foo.json?_extra=suggested_facets"
{ )
"name": "json_column", data = response.json()
"type": "array", assert data["suggested_facets"] == [
"toggle_url": "http://localhost/test_json_array/foo.json?_extra=suggested_facets&_facet_array=json_column", {
} "name": "json_column",
] "type": "array",
"toggle_url": "http://localhost/test_json_array/foo.json?_extra=suggested_facets&_facet_array=json_column",
}
]
finally:
# Close databases first (while executor is still running)
for db_obj in ds.databases.values():
db_obj.close()
if hasattr(ds, "_internal_database"):
ds._internal_database.close()
# Then shut down executor
if ds.executor is not None:
ds.executor.shutdown(wait=True)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_facet_size(): async def test_facet_size():
ds = Datasette([], memory=True, settings={"max_returned_rows": 50}) ds = Datasette([], memory=True, settings={"max_returned_rows": 50})
db = ds.add_database(Database(ds, memory_name="test_facet_size"))
await db.execute_write("create table neighbourhoods(city text, neighbourhood text)")
for i in range(1, 51):
for j in range(1, 4):
await db.execute_write(
"insert into neighbourhoods (city, neighbourhood) values (?, ?)",
["City {}".format(i), "Neighbourhood {}".format(j)],
)
response = await ds.client.get(
"/test_facet_size/neighbourhoods.json?_extra=suggested_facets"
)
data = response.json()
assert data["suggested_facets"] == [
{
"name": "neighbourhood",
"toggle_url": "http://localhost/test_facet_size/neighbourhoods.json?_extra=suggested_facets&_facet=neighbourhood",
}
]
# Bump up _facet_size= to suggest city too
response2 = await ds.client.get(
"/test_facet_size/neighbourhoods.json?_facet_size=50&_extra=suggested_facets"
)
data2 = response2.json()
assert sorted(data2["suggested_facets"], key=lambda f: f["name"]) == [
{
"name": "city",
"toggle_url": "http://localhost/test_facet_size/neighbourhoods.json?_facet_size=50&_extra=suggested_facets&_facet=city",
},
{
"name": "neighbourhood",
"toggle_url": "http://localhost/test_facet_size/neighbourhoods.json?_facet_size=50&_extra=suggested_facets&_facet=neighbourhood",
},
]
# Facet by city should return expected number of results
response3 = await ds.client.get(
"/test_facet_size/neighbourhoods.json?_facet_size=50&_facet=city"
)
data3 = response3.json()
assert len(data3["facet_results"]["results"]["city"]["results"]) == 50
# Reduce max_returned_rows and check that it's respected
ds._settings["max_returned_rows"] = 20
response4 = await ds.client.get(
"/test_facet_size/neighbourhoods.json?_facet_size=50&_facet=city"
)
data4 = response4.json()
assert len(data4["facet_results"]["results"]["city"]["results"]) == 20
# Test _facet_size=max
response5 = await ds.client.get(
"/test_facet_size/neighbourhoods.json?_facet_size=max&_facet=city"
)
data5 = response5.json()
assert len(data5["facet_results"]["results"]["city"]["results"]) == 20
# Now try messing with facet_size in the table metadata
orig_config = ds.config
try: try:
ds.config = { db = ds.add_database(Database(ds, memory_name="test_facet_size"))
"databases": { await db.execute_write(
"test_facet_size": {"tables": {"neighbourhoods": {"facet_size": 6}}} "create table neighbourhoods(city text, neighbourhood text)"
}
}
response6 = await ds.client.get(
"/test_facet_size/neighbourhoods.json?_facet=city"
) )
data6 = response6.json() for i in range(1, 51):
assert len(data6["facet_results"]["results"]["city"]["results"]) == 6 for j in range(1, 4):
# Setting it to max bumps it up to 50 again await db.execute_write(
ds.config["databases"]["test_facet_size"]["tables"]["neighbourhoods"][ "insert into neighbourhoods (city, neighbourhood) values (?, ?)",
"facet_size" ["City {}".format(i), "Neighbourhood {}".format(j)],
] = "max" )
data7 = ( response = await ds.client.get(
await ds.client.get("/test_facet_size/neighbourhoods.json?_facet=city") "/test_facet_size/neighbourhoods.json?_extra=suggested_facets"
).json() )
assert len(data7["facet_results"]["results"]["city"]["results"]) == 20 data = response.json()
assert data["suggested_facets"] == [
{
"name": "neighbourhood",
"toggle_url": "http://localhost/test_facet_size/neighbourhoods.json?_extra=suggested_facets&_facet=neighbourhood",
}
]
# Bump up _facet_size= to suggest city too
response2 = await ds.client.get(
"/test_facet_size/neighbourhoods.json?_facet_size=50&_extra=suggested_facets"
)
data2 = response2.json()
assert sorted(data2["suggested_facets"], key=lambda f: f["name"]) == [
{
"name": "city",
"toggle_url": "http://localhost/test_facet_size/neighbourhoods.json?_facet_size=50&_extra=suggested_facets&_facet=city",
},
{
"name": "neighbourhood",
"toggle_url": "http://localhost/test_facet_size/neighbourhoods.json?_facet_size=50&_extra=suggested_facets&_facet=neighbourhood",
},
]
# Facet by city should return expected number of results
response3 = await ds.client.get(
"/test_facet_size/neighbourhoods.json?_facet_size=50&_facet=city"
)
data3 = response3.json()
assert len(data3["facet_results"]["results"]["city"]["results"]) == 50
# Reduce max_returned_rows and check that it's respected
ds._settings["max_returned_rows"] = 20
response4 = await ds.client.get(
"/test_facet_size/neighbourhoods.json?_facet_size=50&_facet=city"
)
data4 = response4.json()
assert len(data4["facet_results"]["results"]["city"]["results"]) == 20
# Test _facet_size=max
response5 = await ds.client.get(
"/test_facet_size/neighbourhoods.json?_facet_size=max&_facet=city"
)
data5 = response5.json()
assert len(data5["facet_results"]["results"]["city"]["results"]) == 20
# Now try messing with facet_size in the table metadata
orig_config = ds.config
try:
ds.config = {
"databases": {
"test_facet_size": {"tables": {"neighbourhoods": {"facet_size": 6}}}
}
}
response6 = await ds.client.get(
"/test_facet_size/neighbourhoods.json?_facet=city"
)
data6 = response6.json()
assert len(data6["facet_results"]["results"]["city"]["results"]) == 6
# Setting it to max bumps it up to 50 again
ds.config["databases"]["test_facet_size"]["tables"]["neighbourhoods"][
"facet_size"
] = "max"
data7 = (
await ds.client.get("/test_facet_size/neighbourhoods.json?_facet=city")
).json()
assert len(data7["facet_results"]["results"]["city"]["results"]) == 20
finally:
ds.config = orig_config
finally: finally:
ds.config = orig_config # Close databases first (while executor is still running)
# This allows db.close() to clear thread-local storage in executor threads
for db_obj in list(ds.databases.values()):
db_obj.close()
if hasattr(ds, "_internal_database"):
ds._internal_database.close()
# Then shut down executor
if ds.executor is not None:
ds.executor.shutdown(wait=True)
def test_other_types_of_facet_in_metadata(): def test_other_types_of_facet_in_metadata():
@ -648,20 +686,30 @@ async def test_conflicting_facet_names_json(ds_client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_facet_against_in_memory_database(): async def test_facet_against_in_memory_database():
ds = Datasette() ds = Datasette()
db = ds.add_memory_database("mem") try:
await db.execute_write( db = ds.add_memory_database("mem")
"create table t (id integer primary key, name text, name2 text)" await db.execute_write(
) "create table t (id integer primary key, name text, name2 text)"
to_insert = [{"name": "one", "name2": "1"} for _ in range(800)] + [ )
{"name": "two", "name2": "2"} for _ in range(300) to_insert = [{"name": "one", "name2": "1"} for _ in range(800)] + [
] {"name": "two", "name2": "2"} for _ in range(300)
await db.execute_write_many( ]
"insert into t (name, name2) values (:name, :name2)", to_insert await db.execute_write_many(
) "insert into t (name, name2) values (:name, :name2)", to_insert
response1 = await ds.client.get("/mem/t") )
assert response1.status_code == 200 response1 = await ds.client.get("/mem/t")
response2 = await ds.client.get("/mem/t?_facet=name&_facet=name2") assert response1.status_code == 200
assert response2.status_code == 200 response2 = await ds.client.get("/mem/t?_facet=name&_facet=name2")
assert response2.status_code == 200
finally:
# Close databases first (while executor is still running)
for db_obj in ds.databases.values():
db_obj.close()
if hasattr(ds, "_internal_database"):
ds._internal_database.close()
# Then shut down executor
if ds.executor is not None:
ds.executor.shutdown(wait=True)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -698,3 +746,9 @@ async def test_facet_only_considers_first_x_rows():
assert data2["suggested_facets"] == [] assert data2["suggested_facets"] == []
finally: finally:
Facet.suggest_consider = original_suggest_consider Facet.suggest_consider = original_suggest_consider
if ds.executor is not None:
ds.executor.shutdown(wait=True)
for db_obj in ds.databases.values():
db_obj.close()
if hasattr(ds, "_internal_database"):
ds._internal_database.close()

View file

@ -542,7 +542,9 @@ async def test_execute_write_fn_exception(db):
@pytest.mark.timeout(1) @pytest.mark.timeout(1)
async def test_execute_write_fn_connection_exception(tmpdir, app_client): async def test_execute_write_fn_connection_exception(tmpdir, app_client):
path = str(tmpdir / "immutable.db") path = str(tmpdir / "immutable.db")
sqlite3.connect(path).execute("vacuum") conn = sqlite3.connect(path)
conn.execute("vacuum")
conn.close()
db = Database(app_client.ds, path=path, is_mutable=False) db = Database(app_client.ds, path=path, is_mutable=False)
app_client.ds.add_database(db, name="immutable-db") app_client.ds.add_database(db, name="immutable-db")
@ -746,19 +748,23 @@ async def test_replace_database(tmpdir):
path1 = str(tmpdir / "data1.db") path1 = str(tmpdir / "data1.db")
(tmpdir / "two").mkdir() (tmpdir / "two").mkdir()
path2 = str(tmpdir / "two" / "data1.db") path2 = str(tmpdir / "two" / "data1.db")
sqlite3.connect(path1).executescript( conn1 = sqlite3.connect(path1)
conn1.executescript(
""" """
create table t (id integer primary key); create table t (id integer primary key);
insert into t (id) values (1); insert into t (id) values (1);
insert into t (id) values (2); insert into t (id) values (2);
""" """
) )
sqlite3.connect(path2).executescript( conn1.close()
conn2 = sqlite3.connect(path2)
conn2.executescript(
""" """
create table t (id integer primary key); create table t (id integer primary key);
insert into t (id) values (1); insert into t (id) values (1);
""" """
) )
conn2.close()
datasette = Datasette([path1]) datasette = Datasette([path1])
db = datasette.get_database("data1") db = datasette.get_database("data1")
count = (await db.execute("select count(*) from t")).first()[0] count = (await db.execute("select count(*) from t")).first()[0]

View file

@ -42,7 +42,17 @@ async def perms_ds():
await two.execute_write("create table if not exists t1 (id integer primary key)") await two.execute_write("create table if not exists t1 (id integer primary key)")
# Trigger catalog refresh so allowed_resources() can be called # Trigger catalog refresh so allowed_resources() can be called
await ds.client.get("/") await ds.client.get("/")
return ds try:
yield ds
finally:
# Close databases first (while executor is still running)
for db in ds.databases.values():
db.close()
if hasattr(ds, "_internal_database"):
ds._internal_database.close()
# Then shut down executor
if ds.executor is not None:
ds.executor.shutdown(wait=True)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -946,24 +956,34 @@ async def test_permissions_in_config(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_actor_endpoint_allows_any_token(): async def test_actor_endpoint_allows_any_token():
ds = Datasette() ds = Datasette()
token = ds.sign( try:
{ token = ds.sign(
"a": "root", {
"a": "root",
"token": "dstok",
"t": int(time.time()),
"_r": {"a": ["debug-menu"]},
},
namespace="token",
)
response = await ds.client.get(
"/-/actor.json", headers={"Authorization": f"Bearer dstok_{token}"}
)
assert response.status_code == 200
assert response.json()["actor"] == {
"id": "root",
"token": "dstok", "token": "dstok",
"t": int(time.time()),
"_r": {"a": ["debug-menu"]}, "_r": {"a": ["debug-menu"]},
}, }
namespace="token", finally:
) # Close databases first (while executor is still running)
response = await ds.client.get( for db in ds.databases.values():
"/-/actor.json", headers={"Authorization": f"Bearer dstok_{token}"} db.close()
) if hasattr(ds, "_internal_database"):
assert response.status_code == 200 ds._internal_database.close()
assert response.json()["actor"] == { # Then shut down executor
"id": "root", if ds.executor is not None:
"token": "dstok", ds.executor.shutdown(wait=True)
"_r": {"a": ["debug-menu"]},
}
@pytest.mark.serial @pytest.mark.serial
@ -1341,9 +1361,19 @@ async def test_actor_restrictions(
) )
async def test_restrictions_allow_action(restrictions, action, resource, expected): async def test_restrictions_allow_action(restrictions, action, resource, expected):
ds = Datasette() ds = Datasette()
await ds.invoke_startup() try:
actual = restrictions_allow_action(ds, restrictions, action, resource) await ds.invoke_startup()
assert actual == expected actual = restrictions_allow_action(ds, restrictions, action, resource)
assert actual == expected
finally:
# Close databases first (while executor is still running)
for db in ds.databases.values():
db.close()
if hasattr(ds, "_internal_database"):
ds._internal_database.close()
# Then shut down executor
if ds.executor is not None:
ds.executor.shutdown(wait=True)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -1524,28 +1554,36 @@ async def test_actor_restrictions_cannot_be_overridden_by_config():
} }
ds = Datasette(config=config) ds = Datasette(config=config)
await ds.invoke_startup() try:
db = ds.add_memory_database("test_db") await ds.invoke_startup()
await db.execute_write("create table t1 (id integer primary key)") db = ds.add_memory_database("test_db")
await db.execute_write("create table t2 (id integer primary key)") await db.execute_write("create table t1 (id integer primary key)")
await db.execute_write("create table t2 (id integer primary key)")
# Actor restricted to ONLY t1 (not t2) # Actor restricted to ONLY t1 (not t2)
# Even though config allows t2, restrictions should deny it # Even though config allows t2, restrictions should deny it
actor = {"id": "user", "_r": {"r": {"test_db": {"t1": ["vt"]}}}} actor = {"id": "user", "_r": {"r": {"test_db": {"t1": ["vt"]}}}}
# t1 should be allowed (in restrictions AND config allows) # t1 should be allowed (in restrictions AND config allows)
result = await ds.allowed( result = await ds.allowed(
action="view-table", resource=TableResource("test_db", "t1"), actor=actor action="view-table", resource=TableResource("test_db", "t1"), actor=actor
) )
assert result is True, "t1 should be allowed - in restriction allowlist" assert result is True, "t1 should be allowed - in restriction allowlist"
# t2 should be DENIED (not in restrictions, even though config allows) # t2 should be DENIED (not in restrictions, even though config allows)
result = await ds.allowed( result = await ds.allowed(
action="view-table", resource=TableResource("test_db", "t2"), actor=actor action="view-table", resource=TableResource("test_db", "t2"), actor=actor
) )
assert ( assert (
result is False result is False
), "t2 should be denied - NOT in restriction allowlist, config cannot override" ), "t2 should be denied - NOT in restriction allowlist, config cannot override"
finally:
if ds.executor is not None:
ds.executor.shutdown(wait=True)
for db_obj in ds.databases.values():
db_obj.close()
if hasattr(ds, "_internal_database"):
ds._internal_database.close()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -1644,29 +1682,42 @@ async def test_permission_check_view_requires_debug_permission():
"""Test that /-/check requires permissions-debug permission""" """Test that /-/check requires permissions-debug permission"""
# Anonymous user should be denied # Anonymous user should be denied
ds = Datasette() ds = Datasette()
response = await ds.client.get("/-/check.json?action=view-instance") ds_with_root = None
assert response.status_code == 403 try:
assert "permissions-debug" in response.text response = await ds.client.get("/-/check.json?action=view-instance")
assert response.status_code == 403
assert "permissions-debug" in response.text
# User without permissions-debug should be denied # User without permissions-debug should be denied
response = await ds.client.get( response = await ds.client.get(
"/-/check.json?action=view-instance", "/-/check.json?action=view-instance",
cookies={"ds_actor": ds.sign({"id": "user"}, "actor")}, cookies={"ds_actor": ds.sign({"id": "user"}, "actor")},
) )
assert response.status_code == 403 assert response.status_code == 403
# Root user should have access (root has all permissions) # Root user should have access (root has all permissions)
ds_with_root = Datasette() ds_with_root = Datasette()
ds_with_root.root_enabled = True ds_with_root.root_enabled = True
root_token = ds_with_root.create_token("root") root_token = ds_with_root.create_token("root")
response = await ds_with_root.client.get( response = await ds_with_root.client.get(
"/-/check.json?action=view-instance", "/-/check.json?action=view-instance",
headers={"Authorization": f"Bearer {root_token}"}, headers={"Authorization": f"Bearer {root_token}"},
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["action"] == "view-instance" assert data["action"] == "view-instance"
assert data["allowed"] is True assert data["allowed"] is True
finally:
for ds_obj in [ds, ds_with_root]:
if ds_obj is not None:
# Close databases first (while executor is still running)
for db in ds_obj.databases.values():
db.close()
if hasattr(ds_obj, "_internal_database"):
ds_obj._internal_database.close()
# Then shut down executor
if ds_obj.executor is not None:
ds_obj.executor.shutdown(wait=True)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -1686,29 +1737,37 @@ async def test_root_allow_block_with_table_restricted_actor():
"allow": {"id": "admin"}, # Root-level allow block "allow": {"id": "admin"}, # Root-level allow block
} }
) )
await ds.invoke_startup() try:
db = ds.add_memory_database("mydb") await ds.invoke_startup()
await db.execute_write("create table t1 (id integer primary key)") db = ds.add_memory_database("mydb")
await ds.client.get("/") # Trigger catalog refresh await db.execute_write("create table t1 (id integer primary key)")
await ds.client.get("/") # Trigger catalog refresh
# Actor with table-level restrictions only (not global) # Actor with table-level restrictions only (not global)
actor = {"id": "user", "_r": {"r": {"mydb": {"t1": ["view-table"]}}}} actor = {"id": "user", "_r": {"r": {"mydb": {"t1": ["view-table"]}}}}
# The root-level allow: {id: admin} should be processed and deny this user # The root-level allow: {id: admin} should be processed and deny this user
# because they're not "admin", even though they have table restrictions # because they're not "admin", even though they have table restrictions
result = await ds.allowed( result = await ds.allowed(
action="view-table", action="view-table",
resource=TableResource("mydb", "t1"), resource=TableResource("mydb", "t1"),
actor=actor, actor=actor,
) )
# Should be False because root allow: {id: admin} denies non-admin users # Should be False because root allow: {id: admin} denies non-admin users
assert result is False assert result is False
# But admin with same restrictions should be allowed # But admin with same restrictions should be allowed
admin_actor = {"id": "admin", "_r": {"r": {"mydb": {"t1": ["view-table"]}}}} admin_actor = {"id": "admin", "_r": {"r": {"mydb": {"t1": ["view-table"]}}}}
result = await ds.allowed( result = await ds.allowed(
action="view-table", action="view-table",
resource=TableResource("mydb", "t1"), resource=TableResource("mydb", "t1"),
actor=admin_actor, actor=admin_actor,
) )
assert result is True assert result is True
finally:
if ds.executor is not None:
ds.executor.shutdown(wait=True)
for db_obj in ds.databases.values():
db_obj.close()
if hasattr(ds, "_internal_database"):
ds._internal_database.close()

View file

@ -378,9 +378,9 @@ def test_plugins_async_template_function(restore_working_directory):
.select("pre.extra_from_awaitable_function")[0] .select("pre.extra_from_awaitable_function")[0]
.text .text
) )
expected = ( conn = sqlite3.connect(":memory:")
sqlite3.connect(":memory:").execute("select sqlite_version()").fetchone()[0] expected = conn.execute("select sqlite_version()").fetchone()[0]
) conn.close()
assert expected == extra_from_awaitable_function assert expected == extra_from_awaitable_function
@ -424,6 +424,7 @@ def view_names_client(tmp_path_factory):
db_path = str(tmpdir / "fixtures.db") db_path = str(tmpdir / "fixtures.db")
conn = sqlite3.connect(db_path) conn = sqlite3.connect(db_path)
conn.executescript(TABLES) conn.executescript(TABLES)
conn.close()
return _TestClient( return _TestClient(
Datasette([db_path], template_dir=str(templates), plugins_dir=str(plugins)) Datasette([db_path], template_dir=str(templates), plugins_dir=str(plugins))
) )

View file

@ -210,6 +210,7 @@ def test_detect_fts(open_quote, close_quote):
assert None is utils.detect_fts(conn, "Test_View") assert None is utils.detect_fts(conn, "Test_View")
assert None is utils.detect_fts(conn, "r") assert None is utils.detect_fts(conn, "r")
assert "Street_Tree_List_fts" == utils.detect_fts(conn, "Street_Tree_List") assert "Street_Tree_List_fts" == utils.detect_fts(conn, "Street_Tree_List")
conn.close()
@pytest.mark.parametrize("table", ("regular", "has'single quote")) @pytest.mark.parametrize("table", ("regular", "has'single quote"))
@ -226,6 +227,7 @@ def test_detect_fts_different_table_names(table):
conn = utils.sqlite3.connect(":memory:") conn = utils.sqlite3.connect(":memory:")
conn.executescript(sql) conn.executescript(sql)
assert "{table}_fts".format(table=table) == utils.detect_fts(conn, table) assert "{table}_fts".format(table=table) == utils.detect_fts(conn, table)
conn.close()
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -369,6 +371,7 @@ def test_table_columns():
""" """
) )
assert ["id", "name", "bob"] == utils.table_columns(conn, "places") assert ["id", "name", "bob"] == utils.table_columns(conn, "places")
conn.close()
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -443,11 +446,13 @@ def test_check_connection_spatialite_raises():
conn = sqlite3.connect(path) conn = sqlite3.connect(path)
with pytest.raises(utils.SpatialiteConnectionProblem): with pytest.raises(utils.SpatialiteConnectionProblem):
utils.check_connection(conn) utils.check_connection(conn)
conn.close()
def test_check_connection_passes(): def test_check_connection_passes():
conn = sqlite3.connect(":memory:") conn = sqlite3.connect(":memory:")
utils.check_connection(conn) utils.check_connection(conn)
conn.close()
def test_call_with_supported_arguments(): def test_call_with_supported_arguments():
@ -574,10 +579,14 @@ def test_display_actor(actor, expected):
async def test_initial_path_for_datasette(tmp_path_factory, dbs, expected_path): async def test_initial_path_for_datasette(tmp_path_factory, dbs, expected_path):
db_dir = tmp_path_factory.mktemp("dbs") db_dir = tmp_path_factory.mktemp("dbs")
one_table = str(db_dir / "one.db") one_table = str(db_dir / "one.db")
sqlite3.connect(one_table).execute("create table one (id integer primary key)") conn1 = sqlite3.connect(one_table)
conn1.execute("create table one (id integer primary key)")
conn1.close()
two_tables = str(db_dir / "two.db") two_tables = str(db_dir / "two.db")
sqlite3.connect(two_tables).execute("create table two (id integer primary key)") conn2 = sqlite3.connect(two_tables)
sqlite3.connect(two_tables).execute("create table three (id integer primary key)") conn2.execute("create table two (id integer primary key)")
conn2.execute("create table three (id integer primary key)")
conn2.close()
datasette = Datasette( datasette = Datasette(
[{"one_table": one_table, "two_tables": two_tables}[db] for db in dbs] [{"one_table": one_table, "two_tables": two_tables}[db] for db in dbs]
) )

View file

@ -34,7 +34,9 @@ def inner_html(soup):
def has_load_extension(): def has_load_extension():
conn = sqlite3.connect(":memory:") conn = sqlite3.connect(":memory:")
return hasattr(conn, "enable_load_extension") result = hasattr(conn, "enable_load_extension")
conn.close()
return result
def cookie_was_deleted(response, cookie): def cookie_was_deleted(response, cookie):