diff --git a/datasette/app.py b/datasette/app.py index b986756f..51953ac0 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -131,7 +131,7 @@ class Datasette: self.cache_headers = cache_headers self.cors = cors self._inspect = inspect_data - self.metadata = metadata or {} + self._metadata = metadata or {} self.sqlite_functions = [] self.sqlite_extensions = sqlite_extensions or [] self.template_dir = template_dir @@ -167,6 +167,38 @@ class Datasette: for option in CONFIG_OPTIONS } + def metadata(self, key=None, database=None, table=None, fallback=True): + """ + Looks up metadata, cascading backwards from specified level. + Returns None if metadata value is not foundself. + """ + assert not (database is None and table is not None), \ + "Cannot call metadata() with table= specified but not database=" + databases = self._metadata.get("databases") or {} + search_list = [] + if database is not None: + search_list.append(databases.get(database) or {}) + if table is not None: + table_metadata = ( + (databases.get(database) or {}).get("tables") or {} + ).get(table) or {} + search_list.insert(0, table_metadata) + search_list.append(self._metadata) + if not fallback: + # No fallback allowed, so just use the first one in the list + search_list = search_list[:1] + if key is not None: + for item in search_list: + if key in item: + return item[key] + return None + else: + # Return the merged list + m = {} + for item in search_list: + m.update(item) + return m + def app_css_hash(self): if not hasattr(self, "_app_css_hash"): self._app_css_hash = hashlib.sha1( @@ -181,19 +213,19 @@ class Datasette: return self._app_css_hash def get_canned_queries(self, database_name): - names = self.metadata.get("databases", {}).get(database_name, {}).get( - "queries", {} - ).keys() + queries = self.metadata( + "queries", database=database_name, fallback=False + ) or {} + names = queries.keys() return [ self.get_canned_query(database_name, name) for name in names ] def get_canned_query(self, database_name, query_name): - query = self.metadata.get("databases", {}).get(database_name, {}).get( - "queries", {} - ).get( - query_name - ) + queries = self.metadata( + "queries", database=database_name, fallback=False + ) or {} + query = queries.get(query_name) if query: if not isinstance(query, dict): query = {"sql": query} @@ -220,7 +252,7 @@ class Datasette: seen_urls = set() for url_or_dict in itertools.chain( itertools.chain.from_iterable(getattr(pm.hook, key)()), - (self.metadata.get(key) or []) + (self.metadata(key) or []) ): if isinstance(url_or_dict, dict): url = url_or_dict["url"] @@ -246,12 +278,12 @@ class Datasette: # Fills in source/license with defaults, if available metadata.update( { - "source": metadata.get("source") or self.metadata.get("source"), + "source": metadata.get("source") or self.metadata("source"), "source_url": metadata.get("source_url") - or self.metadata.get("source_url"), - "license": metadata.get("license") or self.metadata.get("license"), + or self.metadata("source_url"), + "license": metadata.get("license") or self.metadata("license"), "license_url": metadata.get("license_url") - or self.metadata.get("license_url"), + or self.metadata("license_url"), } ) @@ -291,7 +323,7 @@ class Datasette: "hash": inspect_hash(path), "file": str(path), "views": inspect_views(conn), - "tables": inspect_tables(conn, self.metadata.get("databases", {}).get(name, {})) + "tables": inspect_tables(conn, (self.metadata("databases") or {}).get(name, {})) } except sqlite3.OperationalError as e: if (e.args[0] == 'no such module: VirtualSpatialIndex'): @@ -306,7 +338,7 @@ class Datasette: def register_custom_units(self): "Register any custom units defined in the metadata.json with Pint" - for unit in self.metadata.get("custom_units", []): + for unit in self.metadata("custom_units") or []: ureg.define(unit) def versions(self): @@ -469,7 +501,7 @@ class Datasette: "/-/inspect", ) app.add_route( - JsonDataView.as_view(self, "metadata.json", lambda: self.metadata), + JsonDataView.as_view(self, "metadata.json", lambda: self._metadata), "/-/metadata", ) app.add_route( diff --git a/datasette/views/base.py b/datasette/views/base.py index 7484c791..8b72d234 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -74,7 +74,7 @@ class BaseView(RenderMixin): def table_metadata(self, database, table): "Fetch table-specific metadata." - return self.ds.metadata.get("databases", {}).get(database, {}).get( + return (self.ds.metadata("databases") or {}).get(database, {}).get( "tables", {} ).get( table, {} @@ -314,7 +314,7 @@ class BaseView(RenderMixin): end = time.time() data["query_ms"] = (end - start) * 1000 for key in ("source", "source_url", "license", "license_url"): - value = self.ds.metadata.get(key) + value = self.ds.metadata(key) if value: data[key] = value if _format in ("json", "jsono"): diff --git a/datasette/views/database.py b/datasette/views/database.py index 23dc52c6..e2af1ca4 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -18,7 +18,7 @@ class DatabaseView(BaseView): return await self.custom_sql(request, name, hash, sql, _size=_size) info = self.ds.inspect()[name] - metadata = self.ds.metadata.get("databases", {}).get(name, {}) + metadata = (self.ds.metadata("databases") or {}).get(name, {}) self.ds.update_with_inherited_metadata(metadata) tables = list(info["tables"].values()) tables.sort(key=lambda t: (t["hidden"], t["name"])) diff --git a/datasette/views/index.py b/datasette/views/index.py index b529d03b..89888e4f 100644 --- a/datasette/views/index.py +++ b/datasette/views/index.py @@ -49,7 +49,7 @@ class IndexView(RenderMixin): return self.render( ["index.html"], databases=databases, - metadata=self.ds.metadata, + metadata=self.ds.metadata(), datasette_version=__version__, extra_css_urls=self.ds.extra_css_urls(), extra_js_urls=self.ds.extra_js_urls(), diff --git a/datasette/views/table.py b/datasette/views/table.py index bafe04be..bf6f2355 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -737,7 +737,7 @@ class TableView(RowTableShared): link_column=not is_view, truncate_cells=self.ds.config("truncate_cells_html"), ) - metadata = self.ds.metadata.get("databases", {}).get(name, {}).get( + metadata = (self.ds.metadata("databases") or {}).get(name, {}).get( "tables", {} ).get( table, {} @@ -860,7 +860,9 @@ class RowView(RowTableShared): ), "_rows_and_columns.html", ], - "metadata": self.ds.metadata.get("databases", {}).get(name, {}).get( + "metadata": ( + self.ds.metadata("databases") or {} + ).get(name, {}).get( "tables", {} ).get( table, {}