From 8919f99c2f7f245aca7f94bd53d5ac9d04aa42b5 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 22 Dec 2020 12:04:18 -0800 Subject: [PATCH] Improved .add_database() method design Closes #1155 - _internal now has a sensible name Closes #509 - Support opening multiple databases with the same stem --- datasette/app.py | 34 +++++++++++++++++--------- datasette/database.py | 42 +++++++++++++++++--------------- docs/internals.rst | 29 ++++++++++++++-------- tests/test_cli.py | 15 ++++++++++++ tests/test_internals_database.py | 12 ++++----- 5 files changed, 86 insertions(+), 46 deletions(-) diff --git a/datasette/app.py b/datasette/app.py index f995e79d..ad3ba07e 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -218,18 +218,18 @@ class Datasette: self.immutables = set(immutables or []) self.databases = collections.OrderedDict() if memory or not self.files: - self.add_database(":memory:", Database(self, ":memory:", is_memory=True)) + self.add_database(Database(self, is_memory=True), name=":memory:") # memory_name is a random string so that each Datasette instance gets its own # unique in-memory named database - otherwise unit tests can fail with weird # errors when different instances accidentally share an in-memory database - self.add_database("_internal", Database(self, memory_name=secrets.token_hex())) - self._interna_db_created = False + self.add_database( + Database(self, memory_name=secrets.token_hex()), name="_internal" + ) + self.internal_db_created = False for file in self.files: - path = file - db = Database(self, path, is_mutable=path not in self.immutables) - if db.name in self.databases: - raise Exception(f"Multiple files with same stem: {db.name}") - self.add_database(db.name, db) + self.add_database( + Database(self, file, is_mutable=file not in self.immutables) + ) self.cache_headers = cache_headers self.cors = cors metadata_files = [] @@ -325,9 +325,9 @@ class Datasette: async def refresh_schemas(self): internal_db = self.databases["_internal"] - if not self._interna_db_created: + if not self.internal_db_created: await init_internal_db(internal_db) - self._interna_db_created = True + self.internal_db_created = True current_schema_versions = { row["database_name"]: row["schema_version"] @@ -370,8 +370,20 @@ class Datasette: name = [key for key in self.databases.keys() if key != "_internal"][0] return self.databases[name] - def add_database(self, name, db): + def add_database(self, db, name=None): + if name is None: + # Pick a unique name for this database + suggestion = db.suggest_name() + name = suggestion + else: + suggestion = name + i = 2 + while name in self.databases: + name = "{}_{}".format(suggestion, i) + i += 1 + db.name = name self.databases[name] = db + return db def remove_database(self, name): self.databases.pop(name) diff --git a/datasette/database.py b/datasette/database.py index a977b362..cda36e6e 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -27,30 +27,44 @@ class Database: def __init__( self, ds, path=None, is_mutable=False, is_memory=False, memory_name=None ): + self.name = None self.ds = ds self.path = path self.is_mutable = is_mutable self.is_memory = is_memory self.memory_name = memory_name if memory_name is not None: - self.path = memory_name self.is_memory = True self.is_mutable = True self.hash = None self.cached_size = None - self.cached_table_counts = None + self._cached_table_counts = None self._write_thread = None self._write_queue = None if not self.is_mutable and not self.is_memory: p = Path(path) self.hash = inspect_hash(p) self.cached_size = p.stat().st_size - # Maybe use self.ds.inspect_data to populate cached_table_counts - if self.ds.inspect_data and self.ds.inspect_data.get(self.name): - self.cached_table_counts = { - key: value["count"] - for key, value in self.ds.inspect_data[self.name]["tables"].items() - } + + @property + def cached_table_counts(self): + if self._cached_table_counts is not None: + return self._cached_table_counts + # Maybe use self.ds.inspect_data to populate cached_table_counts + if self.ds.inspect_data and self.ds.inspect_data.get(self.name): + self._cached_table_counts = { + key: value["count"] + for key, value in self.ds.inspect_data[self.name]["tables"].items() + } + return self._cached_table_counts + + def suggest_name(self): + if self.path: + return Path(self.path).stem + elif self.memory_name: + return self.memory_name + else: + return "db" def connect(self, write=False): if self.memory_name: @@ -220,7 +234,7 @@ class Database: except (QueryInterrupted, sqlite3.OperationalError, sqlite3.DatabaseError): counts[table] = None if not self.is_mutable: - self.cached_table_counts = counts + self._cached_table_counts = counts return counts @property @@ -229,16 +243,6 @@ class Database: return None return Path(self.path).stat().st_mtime_ns - @property - def name(self): - if self.is_memory: - if self.memory_name: - return ":memory:{}".format(self.memory_name) - else: - return ":memory:" - else: - return Path(self.path).stem - async def table_exists(self, table): results = await self.execute( "select 1 from sqlite_master where type='table' and name=?", params=(table,) diff --git a/docs/internals.rst b/docs/internals.rst index b68a1d8a..05cb8bd7 100644 --- a/docs/internals.rst +++ b/docs/internals.rst @@ -245,16 +245,16 @@ Returns the specified database object. Raises a ``KeyError`` if the database doe .. _datasette_add_database: -.add_database(name, db) ------------------------ - -``name`` - string - The unique name to use for this database. Also used in the URL. +.add_database(db, name=None) +---------------------------- ``db`` - datasette.database.Database instance The database to be attached. -The ``datasette.add_database(name, db)`` method lets you add a new database to the current Datasette instance. This database will then be served at URL path that matches the ``name`` parameter, e.g. ``/mynewdb/``. +``name`` - string, optional + The name to be used for this database - this will be used in the URL path, e.g. ``/dbname``. If not specified Datasette will pick one based on the filename or memory name. + +The ``datasette.add_database(db)`` method lets you add a new database to the current Datasette instance. The ``db`` parameter should be an instance of the ``datasette.database.Database`` class. For example: @@ -262,13 +262,13 @@ The ``db`` parameter should be an instance of the ``datasette.database.Database` from datasette.database import Database - datasette.add_database("my-new-database", Database( + datasette.add_database(Database( datasette, path="path/to/my-new-database.db", is_mutable=True )) -This will add a mutable database from the provided file path. +This will add a mutable database and serve it at ``/my-new-database``. To create a shared in-memory database named ``statistics``, use the following: @@ -276,11 +276,20 @@ To create a shared in-memory database named ``statistics``, use the following: from datasette.database import Database - datasette.add_database("statistics", Database( + datasette.add_database(Database( datasette, memory_name="statistics" )) +This database will be served at ``/statistics``. + +``.add_database()`` returns the Database instance, with its name set as the ``database.name`` attribute. Any time you are working with a newly added database you should use the return value of ``.add_database()``, for example: + +.. code-block:: python + + db = datasette.add_database(Database(datasette, memory_name="statistics")) + await db.execute_write("CREATE TABLE foo(id integer primary key)", block=True) + .. _datasette_remove_database: .remove_database(name) @@ -289,7 +298,7 @@ To create a shared in-memory database named ``statistics``, use the following: ``name`` - string The name of the database to be removed. -This removes a database that has been previously added. ``name=`` is the unique name of that database, also used in the URL for it. +This removes a database that has been previously added. ``name=`` is the unique name of that database, used in its URL path. .. _datasette_sign: diff --git a/tests/test_cli.py b/tests/test_cli.py index 3f6b1840..ff46d76f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -8,6 +8,7 @@ import asyncio from datasette.plugins import DEFAULT_PLUGINS from datasette.cli import cli, serve from datasette.version import __version__ +from datasette.utils.sqlite import sqlite3 from click.testing import CliRunner import io import json @@ -240,3 +241,17 @@ def test_serve_create(ensure_eventloop, tmpdir): "hash": None, }.items() <= databases[0].items() assert db_path.exists() + + +def test_serve_duplicate_database_names(ensure_eventloop, tmpdir): + runner = CliRunner() + db_1_path = str(tmpdir / "db.db") + nested = tmpdir / "nested" + nested.mkdir() + db_2_path = str(tmpdir / "nested" / "db.db") + for path in (db_1_path, db_2_path): + sqlite3.connect(path).execute("vacuum") + result = runner.invoke(cli, [db_1_path, db_2_path, "--get", "/-/databases.json"]) + assert result.exit_code == 0, result.output + databases = json.loads(result.output) + assert {db["name"] for db in databases} == {"db", "db_2"} diff --git a/tests/test_internals_database.py b/tests/test_internals_database.py index dc1af48c..7eff9f7e 100644 --- a/tests/test_internals_database.py +++ b/tests/test_internals_database.py @@ -439,7 +439,7 @@ async def test_execute_write_fn_connection_exception(tmpdir, app_client): path = str(tmpdir / "immutable.db") sqlite3.connect(path).execute("vacuum") db = Database(app_client.ds, path=path, is_mutable=False) - app_client.ds.add_database("immutable-db", db) + app_client.ds.add_database(db, name="immutable-db") def write_fn(conn): assert False @@ -469,10 +469,10 @@ def test_is_mutable(app_client): @pytest.mark.asyncio async def test_database_memory_name(app_client): ds = app_client.ds - foo1 = Database(ds, memory_name="foo") - foo2 = Database(ds, memory_name="foo") - bar1 = Database(ds, memory_name="bar") - bar2 = Database(ds, memory_name="bar") + foo1 = ds.add_database(Database(ds, memory_name="foo")) + foo2 = ds.add_database(Database(ds, memory_name="foo")) + bar1 = ds.add_database(Database(ds, memory_name="bar")) + bar2 = ds.add_database(Database(ds, memory_name="bar")) for db in (foo1, foo2, bar1, bar2): table_names = await db.table_names() assert table_names == [] @@ -487,7 +487,7 @@ async def test_database_memory_name(app_client): @pytest.mark.asyncio async def test_in_memory_databases_forbid_writes(app_client): ds = app_client.ds - db = Database(ds, memory_name="test") + db = ds.add_database(Database(ds, memory_name="test")) with pytest.raises(sqlite3.OperationalError): await db.execute("create table foo (t text)") assert await db.table_names() == []