Apply black to everything, enforce via unit tests (#449)

I've run the black code formatting tool against everything:

    black tests datasette setup.py

I also added a new unit test, in tests/test_black.py, which will fail if the code does not
conform to black's exacting standards.

This unit test only runs on Python 3.6 or higher, because black itself doesn't run on 3.5.
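
For illustration, a test along these lines can be built on black's --check mode,
which reformats nothing and simply exits non-zero if any file would be changed.
A minimal sketch of the approach (not necessarily the exact contents of
tests/test_black.py in this commit):

    # Sketch: fail the test suite if the codebase is not black-formatted.
    # Assumes black is installed in the test environment.
    import subprocess
    import sys
    from pathlib import Path

    import pytest

    code_root = Path(__file__).parent.parent


    @pytest.mark.skipif(
        sys.version_info[:2] < (3, 6), reason="black requires Python 3.6 or higher"
    )
    def test_code_is_formatted_with_black():
        # "python -m black --check PATHS" sets a non-zero exit code if any
        # file under PATHS would be reformatted.
        result = subprocess.run(
            [
                sys.executable,
                "-m",
                "black",
                "--check",
                str(code_root / "tests"),
                str(code_root / "datasette"),
                str(code_root / "setup.py"),
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )
        assert result.returncode == 0, result.stdout.decode("utf-8")
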
Commit 35d6ee2790 by Simon Willison, 2019-05-03 22:15:14 -04:00, committed by GitHub.
31 changed files with 2758 additions and 2702 deletions

--- a/datasette/app.py
+++ b/datasette/app.py

@@ -17,10 +17,7 @@ from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader
 from sanic import Sanic, response
 from sanic.exceptions import InvalidUsage, NotFound
-from .views.base import (
-    DatasetteError,
-    ureg
-)
+from .views.base import DatasetteError, ureg
 from .views.database import DatabaseDownload, DatabaseView
 from .views.index import IndexView
 from .views.special import JsonDataView
@@ -39,7 +36,7 @@ from .utils import (
     sqlite3,
     sqlite_timelimit,
     table_columns,
-    to_css_class
+    to_css_class,
 )
 from .inspect import inspect_hash, inspect_views, inspect_tables
 from .tracer import capture_traces, trace
@@ -51,72 +48,85 @@ app_root = Path(__file__).parent.parent
 connections = threading.local()
 MEMORY = object()
 
-ConfigOption = collections.namedtuple(
-    "ConfigOption", ("name", "default", "help")
-)
+ConfigOption = collections.namedtuple("ConfigOption", ("name", "default", "help"))
 CONFIG_OPTIONS = (
-    ConfigOption("default_page_size", 100, """
-        Default page size for the table view
-    """.strip()),
-    ConfigOption("max_returned_rows", 1000, """
-        Maximum rows that can be returned from a table or custom query
-    """.strip()),
-    ConfigOption("num_sql_threads", 3, """
-        Number of threads in the thread pool for executing SQLite queries
-    """.strip()),
-    ConfigOption("sql_time_limit_ms", 1000, """
-        Time limit for a SQL query in milliseconds
-    """.strip()),
-    ConfigOption("default_facet_size", 30, """
-        Number of values to return for requested facets
-    """.strip()),
-    ConfigOption("facet_time_limit_ms", 200, """
-        Time limit for calculating a requested facet
-    """.strip()),
-    ConfigOption("facet_suggest_time_limit_ms", 50, """
-        Time limit for calculating a suggested facet
-    """.strip()),
-    ConfigOption("hash_urls", False, """
-        Include DB file contents hash in URLs, for far-future caching
-    """.strip()),
-    ConfigOption("allow_facet", True, """
-        Allow users to specify columns to facet using ?_facet= parameter
-    """.strip()),
-    ConfigOption("allow_download", True, """
-        Allow users to download the original SQLite database files
-    """.strip()),
-    ConfigOption("suggest_facets", True, """
-        Calculate and display suggested facets
-    """.strip()),
-    ConfigOption("allow_sql", True, """
-        Allow arbitrary SQL queries via ?sql= parameter
-    """.strip()),
-    ConfigOption("default_cache_ttl", 5, """
-        Default HTTP cache TTL (used in Cache-Control: max-age= header)
-    """.strip()),
-    ConfigOption("default_cache_ttl_hashed", 365 * 24 * 60 * 60, """
-        Default HTTP cache TTL for hashed URL pages
-    """.strip()),
-    ConfigOption("cache_size_kb", 0, """
-        SQLite cache size in KB (0 == use SQLite default)
-    """.strip()),
-    ConfigOption("allow_csv_stream", True, """
-        Allow .csv?_stream=1 to download all rows (ignoring max_returned_rows)
-    """.strip()),
-    ConfigOption("max_csv_mb", 100, """
-        Maximum size allowed for CSV export in MB - set 0 to disable this limit
-    """.strip()),
-    ConfigOption("truncate_cells_html", 2048, """
-        Truncate cells longer than this in HTML table view - set 0 to disable
-    """.strip()),
-    ConfigOption("force_https_urls", False, """
-        Force URLs in API output to always use https:// protocol
-    """.strip()),
+    ConfigOption("default_page_size", 100, "Default page size for the table view"),
+    ConfigOption(
+        "max_returned_rows",
+        1000,
+        "Maximum rows that can be returned from a table or custom query",
+    ),
+    ConfigOption(
+        "num_sql_threads",
+        3,
+        "Number of threads in the thread pool for executing SQLite queries",
+    ),
+    ConfigOption(
+        "sql_time_limit_ms", 1000, "Time limit for a SQL query in milliseconds"
+    ),
+    ConfigOption(
+        "default_facet_size", 30, "Number of values to return for requested facets"
+    ),
+    ConfigOption(
+        "facet_time_limit_ms", 200, "Time limit for calculating a requested facet"
+    ),
+    ConfigOption(
+        "facet_suggest_time_limit_ms",
+        50,
+        "Time limit for calculating a suggested facet",
+    ),
+    ConfigOption(
+        "hash_urls",
+        False,
+        "Include DB file contents hash in URLs, for far-future caching",
+    ),
+    ConfigOption(
+        "allow_facet",
+        True,
+        "Allow users to specify columns to facet using ?_facet= parameter",
+    ),
+    ConfigOption(
+        "allow_download",
+        True,
+        "Allow users to download the original SQLite database files",
+    ),
+    ConfigOption("suggest_facets", True, "Calculate and display suggested facets"),
+    ConfigOption("allow_sql", True, "Allow arbitrary SQL queries via ?sql= parameter"),
+    ConfigOption(
+        "default_cache_ttl",
+        5,
+        "Default HTTP cache TTL (used in Cache-Control: max-age= header)",
+    ),
+    ConfigOption(
+        "default_cache_ttl_hashed",
+        365 * 24 * 60 * 60,
+        "Default HTTP cache TTL for hashed URL pages",
+    ),
+    ConfigOption(
+        "cache_size_kb", 0, "SQLite cache size in KB (0 == use SQLite default)"
+    ),
+    ConfigOption(
+        "allow_csv_stream",
+        True,
+        "Allow .csv?_stream=1 to download all rows (ignoring max_returned_rows)",
+    ),
+    ConfigOption(
+        "max_csv_mb",
+        100,
+        "Maximum size allowed for CSV export in MB - set 0 to disable this limit",
+    ),
+    ConfigOption(
+        "truncate_cells_html",
+        2048,
+        "Truncate cells longer than this in HTML table view - set 0 to disable",
+    ),
+    ConfigOption(
+        "force_https_urls",
+        False,
+        "Force URLs in API output to always use https:// protocol",
+    ),
 )
-DEFAULT_CONFIG = {
-    option.name: option.default
-    for option in CONFIG_OPTIONS
-}
+DEFAULT_CONFIG = {option.name: option.default for option in CONFIG_OPTIONS}
 
 
 async def favicon(request):
@@ -151,11 +161,13 @@ class ConnectedDatabase:
         counts = {}
         for table in await self.table_names():
             try:
-                table_count = (await self.ds.execute(
-                    self.name,
-                    "select count(*) from [{}]".format(table),
-                    custom_time_limit=limit,
-                )).rows[0][0]
+                table_count = (
+                    await self.ds.execute(
+                        self.name,
+                        "select count(*) from [{}]".format(table),
+                        custom_time_limit=limit,
+                    )
+                ).rows[0][0]
                 counts[table] = table_count
             except InterruptedError:
                 counts[table] = None
@@ -175,18 +187,26 @@ class ConnectedDatabase:
         return Path(self.path).stem
 
     async def table_names(self):
-        results = await self.ds.execute(self.name, "select name from sqlite_master where type='table'")
+        results = await self.ds.execute(
+            self.name, "select name from sqlite_master where type='table'"
+        )
         return [r[0] for r in results.rows]
 
     async def hidden_table_names(self):
         # Mark tables 'hidden' if they relate to FTS virtual tables
-        hidden_tables = [r[0] for r in (
-            await self.ds.execute(self.name, """
+        hidden_tables = [
+            r[0]
+            for r in (
+                await self.ds.execute(
+                    self.name,
+                    """
                 select name from sqlite_master
                 where rootpage = 0
                 and sql like '%VIRTUAL TABLE%USING FTS%'
-            """)
-        ).rows]
+            """,
+                )
+            ).rows
+        ]
         has_spatialite = await self.ds.execute_against_connection_in_thread(
             self.name, detect_spatialite
         )
@@ -205,18 +225,23 @@ class ConnectedDatabase:
         ] + [
             r[0]
             for r in (
-                await self.ds.execute(self.name, """
+                await self.ds.execute(
+                    self.name,
+                    """
                     select name from sqlite_master
                     where name like "idx_%"
                     and type = "table"
-                """)
+                """,
+                )
             ).rows
         ]
         # Add any from metadata.json
         db_metadata = self.ds.metadata(database=self.name)
         if "tables" in db_metadata:
             hidden_tables += [
-                t for t in db_metadata["tables"] if db_metadata["tables"][t].get("hidden")
+                t
+                for t in db_metadata["tables"]
+                if db_metadata["tables"][t].get("hidden")
             ]
         # Also mark as hidden any tables which start with the name of a hidden table
         # e.g. "searchable_fts" implies "searchable_fts_content" should be hidden
@@ -229,7 +254,9 @@ class ConnectedDatabase:
         return hidden_tables
 
     async def view_names(self):
-        results = await self.ds.execute(self.name, "select name from sqlite_master where type='view'")
+        results = await self.ds.execute(
+            self.name, "select name from sqlite_master where type='view'"
+        )
         return [r[0] for r in results.rows]
 
     def __repr__(self):
@@ -245,13 +272,10 @@ class ConnectedDatabase:
         tags_str = ""
         if tags:
             tags_str = " ({})".format(", ".join(tags))
-        return "<ConnectedDatabase: {}{}>".format(
-            self.name, tags_str
-        )
+        return "<ConnectedDatabase: {}{}>".format(self.name, tags_str)
 
 
 class Datasette:
     def __init__(
         self,
         files,
@@ -283,7 +307,9 @@ class Datasette:
                 path = None
                 is_memory = True
             is_mutable = path not in self.immutables
-            db = ConnectedDatabase(self, path, is_mutable=is_mutable, is_memory=is_memory)
+            db = ConnectedDatabase(
+                self, path, is_mutable=is_mutable, is_memory=is_memory
+            )
             if db.name in self.databases:
                 raise Exception("Multiple files with same stem: {}".format(db.name))
             self.databases[db.name] = db
@@ -322,26 +348,24 @@ class Datasette:
     def config_dict(self):
         # Returns a fully resolved config dictionary, useful for templates
-        return {
-            option.name: self.config(option.name)
-            for option in CONFIG_OPTIONS
-        }
+        return {option.name: self.config(option.name) for option in CONFIG_OPTIONS}
 
     def metadata(self, key=None, database=None, table=None, fallback=True):
         """
         Looks up metadata, cascading backwards from specified level.
         Returns None if metadata value is not found.
         """
-        assert not (database is None and table is not None), \
-            "Cannot call metadata() with table= specified but not database="
+        assert not (
+            database is None and table is not None
+        ), "Cannot call metadata() with table= specified but not database="
         databases = self._metadata.get("databases") or {}
         search_list = []
         if database is not None:
             search_list.append(databases.get(database) or {})
         if table is not None:
-            table_metadata = (
-                (databases.get(database) or {}).get("tables") or {}
-            ).get(table) or {}
+            table_metadata = ((databases.get(database) or {}).get("tables") or {}).get(
+                table
+            ) or {}
             search_list.insert(0, table_metadata)
         search_list.append(self._metadata)
         if not fallback:
@@ -359,9 +383,7 @@ class Datasette:
             m.update(item)
         return m
 
-    def plugin_config(
-        self, plugin_name, database=None, table=None, fallback=True
-    ):
+    def plugin_config(self, plugin_name, database=None, table=None, fallback=True):
         "Return config for plugin, falling back from specified database/table"
         plugins = self.metadata(
             "plugins", database=database, table=table, fallback=fallback
@@ -373,29 +395,19 @@ class Datasette:
     def app_css_hash(self):
         if not hasattr(self, "_app_css_hash"):
             self._app_css_hash = hashlib.sha1(
-                open(
-                    os.path.join(str(app_root), "datasette/static/app.css")
-                ).read().encode(
-                    "utf8"
-                )
-            ).hexdigest()[
-                :6
-            ]
+                open(os.path.join(str(app_root), "datasette/static/app.css"))
+                .read()
+                .encode("utf8")
+            ).hexdigest()[:6]
         return self._app_css_hash
 
     def get_canned_queries(self, database_name):
-        queries = self.metadata(
-            "queries", database=database_name, fallback=False
-        ) or {}
+        queries = self.metadata("queries", database=database_name, fallback=False) or {}
         names = queries.keys()
-        return [
-            self.get_canned_query(database_name, name) for name in names
-        ]
+        return [self.get_canned_query(database_name, name) for name in names]
 
     def get_canned_query(self, database_name, query_name):
-        queries = self.metadata(
-            "queries", database=database_name, fallback=False
-        ) or {}
+        queries = self.metadata("queries", database=database_name, fallback=False) or {}
         query = queries.get(query_name)
         if query:
             if not isinstance(query, dict):
@@ -407,7 +419,7 @@ class Datasette:
         table_definition_rows = list(
             await self.execute(
                 database_name,
-                'select sql from sqlite_master where name = :n and type=:t',
+                "select sql from sqlite_master where name = :n and type=:t",
                 {"n": table, "t": type_},
             )
         )
@@ -416,21 +428,19 @@ class Datasette:
         return table_definition_rows[0][0]
 
     def get_view_definition(self, database_name, view):
-        return self.get_table_definition(database_name, view, 'view')
+        return self.get_table_definition(database_name, view, "view")
 
     def update_with_inherited_metadata(self, metadata):
         # Fills in source/license with defaults, if available
         metadata.update(
             {
                 "source": metadata.get("source") or self.metadata("source"),
-                "source_url": metadata.get("source_url")
-                or self.metadata("source_url"),
+                "source_url": metadata.get("source_url") or self.metadata("source_url"),
                 "license": metadata.get("license") or self.metadata("license"),
                 "license_url": metadata.get("license_url")
                 or self.metadata("license_url"),
                 "about": metadata.get("about") or self.metadata("about"),
-                "about_url": metadata.get("about_url")
-                or self.metadata("about_url"),
+                "about_url": metadata.get("about_url") or self.metadata("about_url"),
             }
         )
@@ -444,7 +454,7 @@ class Datasette:
         for extension in self.sqlite_extensions:
             conn.execute("SELECT load_extension('{}')".format(extension))
         if self.config("cache_size_kb"):
-            conn.execute('PRAGMA cache_size=-{}'.format(self.config("cache_size_kb")))
+            conn.execute("PRAGMA cache_size=-{}".format(self.config("cache_size_kb")))
         # pylint: disable=no-member
         pm.hook.prepare_connection(conn=conn)
@@ -452,7 +462,7 @@ class Datasette:
         results = await self.execute(
             database,
             "select 1 from sqlite_master where type='table' and name=?",
-            params=(table,)
+            params=(table,),
         )
         return bool(results.rows)
@@ -463,32 +473,28 @@ class Datasette:
         # Find the foreign_key for this column
         try:
             fk = [
-                foreign_key for foreign_key in foreign_keys
+                foreign_key
+                for foreign_key in foreign_keys
                 if foreign_key["column"] == column
             ][0]
         except IndexError:
             return {}
         label_column = await self.label_column_for_table(database, fk["other_table"])
         if not label_column:
-            return {
-                (fk["column"], value): str(value)
-                for value in values
-            }
+            return {(fk["column"], value): str(value) for value in values}
         labeled_fks = {}
-        sql = '''
+        sql = """
             select {other_column}, {label_column}
            from {other_table}
            where {other_column} in ({placeholders})
-        '''.format(
+        """.format(
            other_column=escape_sqlite(fk["other_column"]),
            label_column=escape_sqlite(label_column),
            other_table=escape_sqlite(fk["other_table"]),
            placeholders=", ".join(["?"] * len(set(values))),
         )
         try:
-            results = await self.execute(
-                database, sql, list(set(values))
-            )
+            results = await self.execute(database, sql, list(set(values)))
         except InterruptedError:
             pass
         else:
@@ -499,7 +505,7 @@ class Datasette:
     def absolute_url(self, request, path):
         url = urllib.parse.urljoin(request.url, path)
         if url.startswith("http://") and self.config("force_https_urls"):
-            url = "https://" + url[len("http://"):]
+            url = "https://" + url[len("http://") :]
         return url
 
     def inspect(self):
@@ -532,10 +538,12 @@ class Datasette:
                     "file": str(path),
                     "size": path.stat().st_size,
                     "views": inspect_views(conn),
-                    "tables": inspect_tables(conn, (self.metadata("databases") or {}).get(name, {}))
+                    "tables": inspect_tables(
+                        conn, (self.metadata("databases") or {}).get(name, {})
+                    ),
                 }
             except sqlite3.OperationalError as e:
-                if (e.args[0] == 'no such module: VirtualSpatialIndex'):
+                if e.args[0] == "no such module: VirtualSpatialIndex":
                     raise click.UsageError(
                         "It looks like you're trying to load a SpatiaLite"
                         " database without first loading the SpatiaLite module."
@@ -582,7 +590,8 @@ class Datasette:
             datasette_version["note"] = self.version_note
         return {
             "python": {
-                "version": ".".join(map(str, sys.version_info[:3])), "full": sys.version
+                "version": ".".join(map(str, sys.version_info[:3])),
+                "full": sys.version,
             },
             "datasette": datasette_version,
             "sqlite": {
@@ -611,10 +620,11 @@ class Datasette:
     def table_metadata(self, database, table):
         "Fetch table-specific metadata."
-        return (self.metadata("databases") or {}).get(database, {}).get(
-            "tables", {}
-        ).get(
-            table, {}
+        return (
+            (self.metadata("databases") or {})
+            .get(database, {})
+            .get("tables", {})
+            .get(table, {})
         )
 
     async def table_columns(self, db_name, table):
@@ -628,16 +638,12 @@ class Datasette:
         )
 
     async def label_column_for_table(self, db_name, table):
-        explicit_label_column = (
-            self.table_metadata(
-                db_name, table
-            ).get("label_column")
-        )
+        explicit_label_column = self.table_metadata(db_name, table).get("label_column")
         if explicit_label_column:
             return explicit_label_column
         # If a table has two columns, one of which is ID, then label_column is the other one
         column_names = await self.table_columns(db_name, table)
-        if (column_names and len(column_names) == 2 and "id" in column_names):
+        if column_names and len(column_names) == 2 and "id" in column_names:
             return [c for c in column_names if c != "id"][0]
         # Couldn't find a label:
         return None
@@ -664,9 +670,7 @@ class Datasette:
                 setattr(connections, db_name, conn)
             return fn(conn)
 
-        return await asyncio.get_event_loop().run_in_executor(
-            self.executor, in_thread
-        )
+        return await asyncio.get_event_loop().run_in_executor(self.executor, in_thread)
 
     async def execute(
         self,
@@ -701,7 +705,7 @@ class Datasette:
                         rows = cursor.fetchall()
                         truncated = False
                 except sqlite3.OperationalError as e:
-                    if e.args == ('interrupted',):
+                    if e.args == ("interrupted",):
                         raise InterruptedError(e, sql, params)
                     if log_sql_errors:
                         print(
@@ -726,7 +730,7 @@ class Datasette:
     def register_renderers(self):
         """ Register output renderers which output data in custom formats. """
         # Built-in renderers
-        self.renderers['json'] = json_renderer
+        self.renderers["json"] = json_renderer
 
         # Hooks
         hook_renderers = []
@@ -737,19 +741,22 @@
                 hook_renderers.append(hook)
 
         for renderer in hook_renderers:
-            self.renderers[renderer['extension']] = renderer['callback']
+            self.renderers[renderer["extension"]] = renderer["callback"]
 
     def app(self):
         class TracingSanic(Sanic):
             async def handle_request(self, request, write_callback, stream_callback):
                 if request.args.get("_trace"):
                     request["traces"] = []
                     request["trace_start"] = time.time()
                     with capture_traces(request["traces"]):
-                        res = await super().handle_request(request, write_callback, stream_callback)
+                        res = await super().handle_request(
+                            request, write_callback, stream_callback
+                        )
                 else:
-                    res = await super().handle_request(request, write_callback, stream_callback)
+                    res = await super().handle_request(
+                        request, write_callback, stream_callback
+                    )
                 return res
 
         app = TracingSanic(__name__)
@@ -822,15 +829,16 @@ class Datasette:
         )
         app.add_route(
             DatabaseView.as_view(self),
-            r"/<db_name:[^/]+?><as_format:(" + renderer_regex + r"|.jsono|\.csv)?$>"
+            r"/<db_name:[^/]+?><as_format:(" + renderer_regex + r"|.jsono|\.csv)?$>",
         )
         app.add_route(
-            TableView.as_view(self),
-            r"/<db_name:[^/]+>/<table_and_format:[^/]+?$>",
+            TableView.as_view(self), r"/<db_name:[^/]+>/<table_and_format:[^/]+?$>"
         )
         app.add_route(
             RowView.as_view(self),
-            r"/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_format:(" + renderer_regex + r")?$>",
+            r"/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_format:("
+            + renderer_regex
+            + r")?$>",
         )
 
         self.register_custom_units()
@@ -852,7 +860,7 @@ class Datasette:
                    "duration": time.time() - request["trace_start"],
                    "queries": request["traces"],
                }
-                if "text/html" in response.content_type and b'</body>' in response.body:
+                if "text/html" in response.content_type and b"</body>" in response.body:
                    extra = json.dumps(traces, indent=2)
                    extra_html = "<pre>{}</pre></body>".format(extra).encode("utf8")
                    response.body = response.body.replace(b"</body>", extra_html)
@@ -908,6 +916,6 @@ class Datasette:
         async def setup_db(app, loop):
             for dbname, database in self.databases.items():
                 if not database.is_mutable:
-                    await database.table_counts(limit=60*60*1000)
+                    await database.table_counts(limit=60 * 60 * 1000)
 
         return app