From b647b5efc29300f715ba656e41b0591f342938e1 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Fri, 18 Oct 2019 15:51:07 -0700 Subject: [PATCH] Fix for /foo v.s. /foo-bar issue, closes #597 Pull request #599 --- datasette/views/base.py | 16 ++++++++-------- tests/fixtures.py | 7 +++++++ tests/test_api.py | 18 ++++++++++++++++++ 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/datasette/views/base.py b/datasette/views/base.py index db1d69d9..219630af 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -193,14 +193,14 @@ class DataView(BaseView): async def resolve_db_name(self, request, db_name, **kwargs): hash = None name = None - if "-" in db_name: - # Might be name-and-hash, or might just be - # a name with a hyphen in it - name, hash = db_name.rsplit("-", 1) - if name not in self.ds.databases: - # Try the whole name - name = db_name - hash = None + if db_name not in self.ds.databases and "-" in db_name: + # No matching DB found, maybe it's a name-hash? + name_bit, hash_bit = db_name.rsplit("-", 1) + if name_bit not in self.ds.databases: + raise NotFound("Database not found: {}".format(name)) + else: + name = name_bit + hash = hash_bit else: name = db_name # Verify the hash diff --git a/tests/fixtures.py b/tests/fixtures.py index dac28dc0..a4c32f36 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -178,6 +178,13 @@ def app_client_two_attached_databases(): ) +@pytest.fixture(scope="session") +def app_client_conflicting_database_names(): + yield from make_app_client( + extra_databases={"foo.db": EXTRA_DATABASE_SQL, "foo-bar.db": EXTRA_DATABASE_SQL} + ) + + @pytest.fixture(scope="session") def app_client_two_attached_databases_one_immutable(): yield from make_app_client( diff --git a/tests/test_api.py b/tests/test_api.py index cc00b780..826c00f3 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -7,6 +7,7 @@ from .fixtures import ( # noqa app_client_larger_cache_size, app_client_returned_rows_matches_page_size, app_client_two_attached_databases_one_immutable, + app_client_conflicting_database_names, app_client_with_cors, app_client_with_dot, generate_compound_rows, @@ -1652,3 +1653,20 @@ def test_cors(app_client_with_cors, path, status_code): response = app_client_with_cors.get(path) assert response.status == status_code assert "*" == response.headers["Access-Control-Allow-Origin"] + + +def test_common_prefix_database_names(app_client_conflicting_database_names): + # https://github.com/simonw/datasette/issues/597 + assert ["fixtures", "foo", "foo-bar"] == [ + d["name"] + for d in json.loads( + app_client_conflicting_database_names.get("/-/databases.json").body.decode( + "utf8" + ) + ) + ] + for db_name, path in (("foo", "/foo.json"), ("foo-bar", "/foo-bar.json")): + data = json.loads( + app_client_conflicting_database_names.get(path).body.decode("utf8") + ) + assert db_name == data["database"]