diff --git a/datasette/app.py b/datasette/app.py index d5994e5e..bc5aaa74 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -116,6 +116,43 @@ async def favicon(request): return response.text("") +class ConnectedDatabase: + def __init__(self, path=None, is_mutable=False, is_memory=False): + self.path = path + self.is_mutable = is_mutable + self.is_memory = is_memory + self.hash = None + self.size = None + if not self.is_mutable: + p = Path(path) + self.hash = inspect_hash(p) + self.size = p.stat().st_size + + @property + def name(self): + if self.is_memory: + return ":memory:" + else: + return Path(self.path).stem + + def __repr__(self): + tags = [] + if self.is_mutable: + tags.append("mutable") + if self.is_memory: + tags.append("memory") + if self.hash: + tags.append("hash={}".format(self.hash)) + if self.size is not None: + tags.append("size={}".format(self.size)) + tags_str = "" + if tags: + tags_str = " ({})".format(", ".join(tags)) + return "".format( + self.name, tags_str + ) + + class Datasette: def __init__( @@ -141,6 +178,18 @@ class Datasette: self.files = [MEMORY] elif memory: self.files = (MEMORY,) + self.files + self.databases = {} + for file in self.files: + path = file + is_memory = False + if file is MEMORY: + path = None + is_memory = True + db = ConnectedDatabase(path, is_mutable=path not in self.immutables, is_memory=is_memory) + if db.name in self.databases: + raise Exception("Multiple files with same stem: {}".format(db.name)) + self.databases[db.name] = db + print(self.databases) self.cache_headers = cache_headers self.cors = cors self._inspect = inspect_data @@ -419,17 +468,17 @@ class Datasette: def in_thread(): conn = getattr(connections, db_name, None) if not conn: - info = self.inspect()[db_name] - if info["file"] == ":memory:": + db = self.databases[db_name] + if db.is_memory: conn = sqlite3.connect(":memory:") else: # mode=ro or immutable=1? - if info["file"] in self.immutables: - qs = "immutable=1" - else: + if db.is_mutable: qs = "mode=ro" + else: + qs = "immutable=1" conn = sqlite3.connect( - "file:{}?{}".format(info["file"], qs), + "file:{}?{}".format(db.path, qs), uri=True, check_same_thread=False, ) diff --git a/datasette/views/base.py b/datasette/views/base.py index e79a46ec..8da51d65 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -156,14 +156,13 @@ class BaseView(RenderMixin): return r async def resolve_db_name(self, request, db_name, **kwargs): - databases = self.ds.inspect() hash = None name = None if "-" in db_name: # Might be name-and-hash, or might just be # a name with a hyphen in it name, hash = db_name.rsplit("-", 1) - if name not in databases: + if name not in self.ds.databases: # Try the whole name name = db_name hash = None @@ -171,11 +170,13 @@ class BaseView(RenderMixin): name = db_name # Verify the hash try: - info = databases[name] + db = self.ds.databases[name] except KeyError: raise NotFound("Database not found: {}".format(name)) - expected = info["hash"][:HASH_LENGTH] + expected = "000" + if db.hash is not None: + expected = db.hash[:HASH_LENGTH] correct_hash_provided = (expected == hash) if not correct_hash_provided: diff --git a/tests/fixtures.py b/tests/fixtures.py index 81432e30..c3883fe7 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -29,6 +29,7 @@ def make_app_client( cors=False, config=None, filename="fixtures.db", + is_immutable=False, ): with tempfile.TemporaryDirectory() as tmpdir: filepath = os.path.join(tmpdir, filename) @@ -48,7 +49,8 @@ def make_app_client( } ) ds = Datasette( - [filepath], + [] if is_immutable else [filepath], + immutables=[filepath] if is_immutable else [], cors=cors, metadata=METADATA, plugins_dir=plugins_dir, @@ -76,8 +78,8 @@ def app_client_no_files(): @pytest.fixture(scope="session") def app_client_with_hash(): yield from make_app_client(config={ - 'hash_urls': True - }) + 'hash_urls': True, + }, is_immutable=True) @pytest.fixture(scope='session') diff --git a/tests/test_api.py b/tests/test_api.py index b92b9ffb..6fdcb115 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1317,10 +1317,10 @@ def test_ttl_parameter(app_client, path, expected_cache_control): ("/fixtures/facetable.json?_hash=1", "/fixtures-HASH/facetable.json"), ("/fixtures/facetable.json?city_id=1&_hash=1", "/fixtures-HASH/facetable.json?city_id=1"), ]) -def test_hash_parameter(app_client, path, expected_redirect): +def test_hash_parameter(app_client_with_hash, path, expected_redirect): # First get the current hash for the fixtures database - current_hash = app_client.get("/-/inspect.json").json["fixtures"]["hash"][:7] - response = app_client.get(path, allow_redirects=False) + current_hash = app_client_with_hash.get("/-/inspect.json").json["fixtures"]["hash"][:7] + response = app_client_with_hash.get(path, allow_redirects=False) assert response.status == 302 location = response.headers["Location"] assert expected_redirect.replace("HASH", current_hash) == location