Refactor to use new datasette.config(key) method

This commit is contained in:
Simon Willison 2018-08-11 13:06:45 -07:00
commit 2189be1440
No known key found for this signature in database
GPG key ID: 17E2DEA2588B7F52
4 changed files with 33 additions and 23 deletions

View file

@@ -137,14 +137,14 @@ class Datasette:
self.template_dir = template_dir self.template_dir = template_dir
self.plugins_dir = plugins_dir self.plugins_dir = plugins_dir
self.static_mounts = static_mounts or [] self.static_mounts = static_mounts or []
self.config = dict(DEFAULT_CONFIG, **(config or {})) self._config = dict(DEFAULT_CONFIG, **(config or {}))
self.version_note = version_note self.version_note = version_note
self.executor = futures.ThreadPoolExecutor( self.executor = futures.ThreadPoolExecutor(
max_workers=self.config["num_sql_threads"] max_workers=self.config("num_sql_threads")
) )
self.max_returned_rows = self.config["max_returned_rows"] self.max_returned_rows = self.config("max_returned_rows")
self.sql_time_limit_ms = self.config["sql_time_limit_ms"] self.sql_time_limit_ms = self.config("sql_time_limit_ms")
self.page_size = self.config["default_page_size"] self.page_size = self.config("default_page_size")
# Execute plugins in constructor, to ensure they are available # Execute plugins in constructor, to ensure they are available
# when the rest of `datasette inspect` executes # when the rest of `datasette inspect` executes
if self.plugins_dir: if self.plugins_dir:
@@ -157,6 +157,16 @@ class Datasette:
# Plugin already registered # Plugin already registered
pass pass
def config(self, key):
return self._config.get(key, None)
def config_dict(self):
# Returns a fully resolved config dictionary, useful for templates
return {
option.name: self.config(option.name)
for option in CONFIG_OPTIONS
}
def app_css_hash(self): def app_css_hash(self):
if not hasattr(self, "_app_css_hash"): if not hasattr(self, "_app_css_hash"):
self._app_css_hash = hashlib.sha1( self._app_css_hash = hashlib.sha1(
@@ -254,8 +264,8 @@ class Datasette:
conn.enable_load_extension(True) conn.enable_load_extension(True)
for extension in self.sqlite_extensions: for extension in self.sqlite_extensions:
conn.execute("SELECT load_extension('{}')".format(extension)) conn.execute("SELECT load_extension('{}')".format(extension))
if self.config["cache_size_kb"]: if self.config("cache_size_kb"):
conn.execute('PRAGMA cache_size=-{}'.format(self.config["cache_size_kb"])) conn.execute('PRAGMA cache_size=-{}'.format(self.config("cache_size_kb")))
pm.hook.prepare_connection(conn=conn) pm.hook.prepare_connection(conn=conn)
def table_exists(self, database, table): def table_exists(self, database, table):
@@ -471,7 +481,7 @@ class Datasette:
"/-/plugins<as_format:(\.json)?$>", "/-/plugins<as_format:(\.json)?$>",
) )
app.add_route( app.add_route(
JsonDataView.as_view(self, "config.json", lambda: self.config), JsonDataView.as_view(self, "config.json", lambda: self._config),
"/-/config<as_format:(\.json)?$>", "/-/config<as_format:(\.json)?$>",
) )
app.add_route( app.add_route(

View file

@@ -149,7 +149,7 @@ class BaseView(RenderMixin):
def absolute_url(self, request, path): def absolute_url(self, request, path):
url = urllib.parse.urljoin(request.url, path) url = urllib.parse.urljoin(request.url, path)
if url.startswith("http://") and self.ds.config["force_https_urls"]: if url.startswith("http://") and self.ds.config("force_https_urls"):
url = "https://" + url[len("http://"):] url = "https://" + url[len("http://"):]
return url return url
@@ -167,7 +167,7 @@ class BaseView(RenderMixin):
stream = request.args.get("_stream") stream = request.args.get("_stream")
if stream: if stream:
# Some quick sanity checks # Some quick sanity checks
if not self.ds.config["allow_csv_stream"]: if not self.ds.config("allow_csv_stream"):
raise DatasetteError("CSV streaming is disabled", status=400) raise DatasetteError("CSV streaming is disabled", status=400)
if request.args.get("_next"): if request.args.get("_next"):
raise DatasetteError( raise DatasetteError(
@@ -205,7 +205,7 @@ class BaseView(RenderMixin):
async def stream_fn(r): async def stream_fn(r):
nonlocal data nonlocal data
writer = csv.writer(LimitedWriter(r, self.ds.config["max_csv_mb"])) writer = csv.writer(LimitedWriter(r, self.ds.config("max_csv_mb")))
first = True first = True
next = None next = None
while first or (next and stream): while first or (next and stream):
@@ -426,7 +426,7 @@ class BaseView(RenderMixin):
"extra_css_urls": self.ds.extra_css_urls(), "extra_css_urls": self.ds.extra_css_urls(),
"extra_js_urls": self.ds.extra_js_urls(), "extra_js_urls": self.ds.extra_js_urls(),
"datasette_version": __version__, "datasette_version": __version__,
"config": self.ds.config, "config": self.ds.config_dict(),
} }
} }
if "metadata" not in context: if "metadata" not in context:
@@ -437,7 +437,7 @@ class BaseView(RenderMixin):
if self.ds.cache_headers: if self.ds.cache_headers:
ttl = request.args.get("_ttl", None) ttl = request.args.get("_ttl", None)
if ttl is None or not ttl.isdigit(): if ttl is None or not ttl.isdigit():
ttl = self.ds.config["default_cache_ttl"] ttl = self.ds.config("default_cache_ttl")
else: else:
ttl = int(ttl) ttl = int(ttl)
if ttl == 0: if ttl == 0:
@@ -517,7 +517,7 @@ class BaseView(RenderMixin):
"editable": editable, "editable": editable,
"canned_query": canned_query, "canned_query": canned_query,
"metadata": metadata, "metadata": metadata,
"config": self.ds.config, "config": self.ds.config_dict(),
} }
return { return {

View file

@@ -11,7 +11,7 @@ class DatabaseView(BaseView):
async def data(self, request, name, hash, default_labels=False, _size=None): async def data(self, request, name, hash, default_labels=False, _size=None):
if request.args.get("sql"): if request.args.get("sql"):
if not self.ds.config["allow_sql"]: if not self.ds.config("allow_sql"):
raise DatasetteError("sql= is not allowed", status=400) raise DatasetteError("sql= is not allowed", status=400)
sql = request.raw_args.pop("sql") sql = request.raw_args.pop("sql")
validate_sql_select(sql) validate_sql_select(sql)
@@ -41,7 +41,7 @@ class DatabaseView(BaseView):
class DatabaseDownload(BaseView): class DatabaseDownload(BaseView):
async def view_get(self, request, name, hash, **kwargs): async def view_get(self, request, name, hash, **kwargs):
if not self.ds.config["allow_download"]: if not self.ds.config("allow_download"):
raise DatasetteError("Database download is forbidden", status=403) raise DatasetteError("Database download is forbidden", status=403)
filepath = self.ds.inspect()[name]["file"] filepath = self.ds.inspect()[name]["file"]
return await response.file_stream( return await response.file_stream(

View file

@@ -513,10 +513,10 @@ class TableView(RowTableShared):
) )
# facets support # facets support
facet_size = self.ds.config["default_facet_size"] facet_size = self.ds.config("default_facet_size")
metadata_facets = table_metadata.get("facets", []) metadata_facets = table_metadata.get("facets", [])
facets = metadata_facets[:] facets = metadata_facets[:]
if request.args.get("_facet") and not self.ds.config["allow_facet"]: if request.args.get("_facet") and not self.ds.config("allow_facet"):
raise DatasetteError("_facet= is not allowed", status=400) raise DatasetteError("_facet= is not allowed", status=400)
try: try:
facets.extend(request.args["_facet"]) facets.extend(request.args["_facet"])
@@ -541,7 +541,7 @@ class TableView(RowTableShared):
facet_rows_results = await self.ds.execute( facet_rows_results = await self.ds.execute(
name, facet_sql, params, name, facet_sql, params,
truncate=False, truncate=False,
custom_time_limit=self.ds.config["facet_time_limit_ms"], custom_time_limit=self.ds.config("facet_time_limit_ms"),
) )
facet_results_values = [] facet_results_values = []
facet_results[column] = { facet_results[column] = {
@@ -674,13 +674,13 @@ class TableView(RowTableShared):
# Detect suggested facets # Detect suggested facets
suggested_facets = [] suggested_facets = []
if self.ds.config["suggest_facets"] and self.ds.config["allow_facet"]: if self.ds.config("suggest_facets") and self.ds.config("allow_facet"):
for facet_column in columns: for facet_column in columns:
if facet_column in facets: if facet_column in facets:
continue continue
if _next: if _next:
continue continue
if not self.ds.config["suggest_facets"]: if not self.ds.config("suggest_facets"):
continue continue
suggested_facet_sql = ''' suggested_facet_sql = '''
select distinct {column} {from_sql} select distinct {column} {from_sql}
@@ -697,7 +697,7 @@ class TableView(RowTableShared):
distinct_values = await self.ds.execute( distinct_values = await self.ds.execute(
name, suggested_facet_sql, from_sql_params, name, suggested_facet_sql, from_sql_params,
truncate=False, truncate=False,
custom_time_limit=self.ds.config["facet_suggest_time_limit_ms"], custom_time_limit=self.ds.config("facet_suggest_time_limit_ms"),
) )
num_distinct_values = len(distinct_values) num_distinct_values = len(distinct_values)
if ( if (
@@ -735,7 +735,7 @@ class TableView(RowTableShared):
results.description, results.description,
rows, rows,
link_column=not is_view, link_column=not is_view,
truncate_cells=self.ds.config["truncate_cells_html"], truncate_cells=self.ds.config("truncate_cells_html"),
) )
metadata = self.ds.metadata.get("databases", {}).get(name, {}).get( metadata = self.ds.metadata.get("databases", {}).get(name, {}).get(
"tables", {} "tables", {}