From 7a6654a253dee243518dc542ce4c06dbb0d0801d Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sat, 19 Mar 2022 17:11:17 -0700 Subject: [PATCH] Databases can now have a .route separate from their .name, refs #1668 --- datasette/app.py | 13 ++++++-- datasette/database.py | 1 + datasette/views/base.py | 12 +++++-- datasette/views/database.py | 18 ++++++----- datasette/views/table.py | 29 ++++++++++++----- docs/internals.rst | 11 ++++--- tests/test_internals_datasette.py | 1 + tests/test_routes.py | 52 ++++++++++++++++++++++++++++++- 8 files changed, 111 insertions(+), 26 deletions(-) diff --git a/datasette/app.py b/datasette/app.py index edef34e9..5c8101a3 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -388,13 +388,18 @@ class Datasette: def unsign(self, signed, namespace="default"): return URLSafeSerializer(self._secret, namespace).loads(signed) - def get_database(self, name=None): + def get_database(self, name=None, route=None): + if route is not None: + matches = [db for db in self.databases.values() if db.route == route] + if not matches: + raise KeyError + return matches[0] if name is None: - # Return first no-_schemas database + # Return first database that isn't "_internal" name = [key for key in self.databases.keys() if key != "_internal"][0] return self.databases[name] - def add_database(self, db, name=None): + def add_database(self, db, name=None, route=None): new_databases = self.databases.copy() if name is None: # Pick a unique name for this database @@ -407,6 +412,7 @@ class Datasette: name = "{}_{}".format(suggestion, i) i += 1 db.name = name + db.route = route or name new_databases[name] = db # don't mutate! that causes race conditions with live import self.databases = new_databases @@ -693,6 +699,7 @@ class Datasette: return [ { "name": d.name, + "route": d.route, "path": d.path, "size": d.size, "is_mutable": d.is_mutable, diff --git a/datasette/database.py b/datasette/database.py index 6ce87215..ba594a8c 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -31,6 +31,7 @@ class Database: self, ds, path=None, is_mutable=False, is_memory=False, memory_name=None ): self.name = None + self.route = None self.ds = ds self.path = path self.is_mutable = is_mutable diff --git a/datasette/views/base.py b/datasette/views/base.py index 24e97d95..afa9eaa6 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -371,13 +371,19 @@ class DataView(BaseView): return AsgiStream(stream_fn, headers=headers, content_type=content_type) async def get(self, request): - db_name = request.url_vars["database"] - database = tilde_decode(db_name) + database_route = tilde_decode(request.url_vars["database"]) + + try: + db = self.ds.get_database(route=database_route) + except KeyError: + raise NotFound("Database not found: {}".format(database_route)) + database = db.name + _format = request.url_vars["format"] data_kwargs = {} if _format == "csv": - return await self.as_csv(request, database) + return await self.as_csv(request, database_route) if _format is None: # HTML views default to expanding all foreign key labels diff --git a/datasette/views/database.py b/datasette/views/database.py index 93bd1011..2563c5b2 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -32,7 +32,13 @@ class DatabaseView(DataView): name = "database" async def data(self, request, default_labels=False, _size=None): - database = tilde_decode(request.url_vars["database"]) + database_route = tilde_decode(request.url_vars["database"]) + try: + db = self.ds.get_database(route=database_route) + except KeyError: + raise NotFound("Database not found: {}".format(database_route)) + database = db.name + await self.check_permissions( request, [ @@ -50,11 +56,6 @@ class DatabaseView(DataView): request, sql, _size=_size, metadata=metadata ) - try: - db = self.ds.databases[database] - except KeyError: - raise NotFound("Database not found: {}".format(database)) - table_counts = await db.table_counts(5) hidden_table_names = set(await db.hidden_table_names()) all_foreign_keys = await db.get_all_foreign_keys() @@ -171,9 +172,10 @@ class DatabaseDownload(DataView): "view-instance", ], ) - if database not in self.ds.databases: + try: + db = self.ds.get_database(route=database) + except KeyError: raise DatasetteError("Invalid database", status=404) - db = self.ds.databases[database] if db.is_memory: raise DatasetteError("Cannot download in-memory databases", status=404) if not self.ds.setting("allow_download") or db.is_mutable: diff --git a/datasette/views/table.py b/datasette/views/table.py index ea4f24b7..7fa1da3a 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -272,10 +272,15 @@ class TableView(RowTableShared): name = "table" async def post(self, request): - db_name = tilde_decode(request.url_vars["database"]) + database_route = tilde_decode(request.url_vars["database"]) + try: + db = self.ds.get_database(route=database_route) + except KeyError: + raise NotFound("Database not found: {}".format(database_route)) + database = db.name table = tilde_decode(request.url_vars["table"]) # Handle POST to a canned query - canned_query = await self.ds.get_canned_query(db_name, table, request.actor) + canned_query = await self.ds.get_canned_query(database, table, request.actor) assert canned_query, "You may only POST to a canned query" return await QueryView(self.ds).data( request, @@ -327,12 +332,13 @@ class TableView(RowTableShared): _next=None, _size=None, ): - database = tilde_decode(request.url_vars["database"]) + database_route = tilde_decode(request.url_vars["database"]) table = tilde_decode(request.url_vars["table"]) try: - db = self.ds.databases[database] + db = self.ds.get_database(route=database_route) except KeyError: - raise NotFound("Database not found: {}".format(database)) + raise NotFound("Database not found: {}".format(database_route)) + database = db.name # If this is a canned query, not a table, then dispatch to QueryView instead canned_query = await self.ds.get_canned_query(database, table, request.actor) @@ -938,8 +944,13 @@ class RowView(RowTableShared): name = "row" async def data(self, request, default_labels=False): - database = tilde_decode(request.url_vars["database"]) + database_route = tilde_decode(request.url_vars["database"]) table = tilde_decode(request.url_vars["table"]) + try: + db = self.ds.get_database(route=database_route) + except KeyError: + raise NotFound("Database not found: {}".format(database_route)) + database = db.name await self.check_permissions( request, [ @@ -949,7 +960,11 @@ class RowView(RowTableShared): ], ) pk_values = urlsafe_components(request.url_vars["pks"]) - db = self.ds.databases[database] + try: + db = self.ds.get_database(route=database_route) + except KeyError: + raise NotFound("Database not found: {}".format(database_route)) + database = db.name sql, params, pks = await _sql_params_pks(db, table, pk_values) results = await db.execute(sql, params, truncate=True) columns = [r[0] for r in results.description] diff --git a/docs/internals.rst b/docs/internals.rst index 117cb95c..323256c7 100644 --- a/docs/internals.rst +++ b/docs/internals.rst @@ -307,14 +307,17 @@ Returns the specified database object. Raises a ``KeyError`` if the database doe .. _datasette_add_database: -.add_database(db, name=None) ----------------------------- +.add_database(db, name=None, route=None) +---------------------------------------- ``db`` - datasette.database.Database instance The database to be attached. ``name`` - string, optional - The name to be used for this database - this will be used in the URL path, e.g. ``/dbname``. If not specified Datasette will pick one based on the filename or memory name. + The name to be used for this database . If not specified Datasette will pick one based on the filename or memory name. + +``route`` - string, optional + This will be used in the URL path. If not specified, it will default to the same thing as the ``name``. The ``datasette.add_database(db)`` method lets you add a new database to the current Datasette instance. @@ -371,7 +374,7 @@ Using either of these pattern will result in the in-memory database being served ``name`` - string The name of the database to be removed. -This removes a database that has been previously added. ``name=`` is the unique name of that database, used in its URL path. +This removes a database that has been previously added. ``name=`` is the unique name of that database. .. _datasette_sign: diff --git a/tests/test_internals_datasette.py b/tests/test_internals_datasette.py index adf84be9..cc200a2d 100644 --- a/tests/test_internals_datasette.py +++ b/tests/test_internals_datasette.py @@ -55,6 +55,7 @@ async def test_datasette_constructor(): assert databases == [ { "name": "_memory", + "route": "_memory", "path": None, "size": 0, "is_mutable": False, diff --git a/tests/test_routes.py b/tests/test_routes.py index 1fa55018..dd3bc644 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -1,6 +1,7 @@ -from datasette.app import Datasette +from datasette.app import Datasette, Database from datasette.utils import resolve_routes import pytest +import pytest_asyncio @pytest.fixture(scope="session") @@ -53,3 +54,52 @@ def test_routes(routes, path, expected_class, expected_matches): else: assert view.view_class.__name__ == expected_class assert match.groupdict() == expected_matches + + +@pytest_asyncio.fixture +async def ds_with_route(): + ds = Datasette() + ds.remove_database("_memory") + db = Database(ds, is_memory=True, memory_name="route-name-db") + ds.add_database(db, name="name", route="route-name") + await db.execute_write_script( + """ + create table if not exists t (id integer primary key); + insert or replace into t (id) values (1); + """ + ) + return ds + + +@pytest.mark.asyncio +async def test_db_with_route_databases(ds_with_route): + response = await ds_with_route.client.get("/-/databases.json") + assert response.json()[0] == { + "name": "name", + "route": "route-name", + "path": None, + "size": 0, + "is_mutable": True, + "is_memory": True, + "hash": None, + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "path,expected_status", + ( + ("/", 200), + ("/name", 404), + ("/name/t", 404), + ("/name/t/1", 404), + ("/route-name", 200), + ("/route-name/t", 200), + ("/route-name/t/1", 200), + ), +) +async def test_db_with_route_that_does_not_match_name( + ds_with_route, path, expected_status +): + response = await ds_with_route.client.get(path) + assert response.status_code == expected_status