Apply black to everything, enforce via unit tests (#449)

I've run the black code formatting tool against everything:

    black tests datasette setup.py

I also added a new unit test, in tests/test_black.py, which will fail if the code does not
conform to black's exacting standards.

This unit test only runs on Python 3.6 or higher, because black itself doesn't run on 3.5.
This commit is contained in:
Simon Willison 2019-05-03 22:15:14 -04:00 committed by GitHub
commit 35d6ee2790
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 2758 additions and 2702 deletions

View file

@@ -33,8 +33,15 @@ HASH_LENGTH = 7
class DatasetteError(Exception):
def __init__(self, message, title=None, error_dict=None, status=500, template=None, messagge_is_html=False):
def __init__(
self,
message,
title=None,
error_dict=None,
status=500,
template=None,
messagge_is_html=False,
):
self.message = message
self.title = title
self.error_dict = error_dict or {}
@@ -43,18 +50,19 @@ class DatasetteError(Exception):
class RenderMixin(HTTPMethodView):
def _asset_urls(self, key, template, context):
# Flatten list-of-lists from plugins:
seen_urls = set()
for url_or_dict in itertools.chain(
itertools.chain.from_iterable(getattr(pm.hook, key)(
template=template.name,
database=context.get("database"),
table=context.get("table"),
datasette=self.ds
)),
(self.ds.metadata(key) or [])
itertools.chain.from_iterable(
getattr(pm.hook, key)(
template=template.name,
database=context.get("database"),
table=context.get("table"),
datasette=self.ds,
)
),
(self.ds.metadata(key) or []),
):
if isinstance(url_or_dict, dict):
url = url_or_dict["url"]
@@ -73,14 +81,12 @@ class RenderMixin(HTTPMethodView):
def database_url(self, database):
db = self.ds.databases[database]
if self.ds.config("hash_urls") and db.hash:
return "/{}-{}".format(
database, db.hash[:HASH_LENGTH]
)
return "/{}-{}".format(database, db.hash[:HASH_LENGTH])
else:
return "/{}".format(database)
def database_color(self, database):
return 'ff0000'
return "ff0000"
def render(self, templates, **context):
template = self.ds.jinja_env.select_template(templates)
@@ -95,7 +101,7 @@ class RenderMixin(HTTPMethodView):
database=context.get("database"),
table=context.get("table"),
view_name=self.name,
datasette=self.ds
datasette=self.ds,
):
body_scripts.append(jinja2.Markup(script))
return response.html(
@@ -116,14 +122,14 @@ class RenderMixin(HTTPMethodView):
"format_bytes": format_bytes,
"database_url": self.database_url,
"database_color": self.database_color,
}
},
}
)
)
class BaseView(RenderMixin):
name = ''
name = ""
re_named_parameter = re.compile(":([a-zA-Z0-9_]+)")
def __init__(self, datasette):
@@ -171,32 +177,30 @@ class BaseView(RenderMixin):
expected = "000"
if db.hash is not None:
expected = db.hash[:HASH_LENGTH]
correct_hash_provided = (expected == hash)
correct_hash_provided = expected == hash
if not correct_hash_provided:
if "table_and_format" in kwargs:
async def async_table_exists(t):
return await self.ds.table_exists(name, t)
table, _format = await resolve_table_and_format(
table_and_format=urllib.parse.unquote_plus(
kwargs["table_and_format"]
),
table_exists=async_table_exists,
allowed_formats=self.ds.renderers.keys()
allowed_formats=self.ds.renderers.keys(),
)
kwargs["table"] = table
if _format:
kwargs["as_format"] = ".{}".format(_format)
elif "table" in kwargs:
kwargs["table"] = urllib.parse.unquote_plus(
kwargs["table"]
)
kwargs["table"] = urllib.parse.unquote_plus(kwargs["table"])
should_redirect = "/{}-{}".format(name, expected)
if "table" in kwargs:
should_redirect += "/" + urllib.parse.quote_plus(
kwargs["table"]
)
should_redirect += "/" + urllib.parse.quote_plus(kwargs["table"])
if "pk_path" in kwargs:
should_redirect += "/" + kwargs["pk_path"]
if "as_format" in kwargs:
@@ -219,7 +223,9 @@ class BaseView(RenderMixin):
if should_redirect:
return self.redirect(request, should_redirect, remove_args={"_hash"})
return await self.view_get(request, database, hash, correct_hash_provided, **kwargs)
return await self.view_get(
request, database, hash, correct_hash_provided, **kwargs
)
async def as_csv(self, request, database, hash, **kwargs):
stream = request.args.get("_stream")
@@ -228,9 +234,7 @@ class BaseView(RenderMixin):
if not self.ds.config("allow_csv_stream"):
raise DatasetteError("CSV streaming is disabled", status=400)
if request.args.get("_next"):
raise DatasetteError(
"_next not allowed for CSV streaming", status=400
)
raise DatasetteError("_next not allowed for CSV streaming", status=400)
kwargs["_size"] = "max"
# Fetch the first page
try:
@@ -271,9 +275,7 @@ class BaseView(RenderMixin):
if next:
kwargs["_next"] = next
if not first:
data, _, _ = await self.data(
request, database, hash, **kwargs
)
data, _, _ = await self.data(request, database, hash, **kwargs)
if first:
writer.writerow(headings)
first = False
@@ -293,7 +295,7 @@ class BaseView(RenderMixin):
new_row.append(cell)
writer.writerow(new_row)
except Exception as e:
print('caught this', e)
print("caught this", e)
r.write(str(e))
return
@@ -304,15 +306,11 @@ class BaseView(RenderMixin):
if request.args.get("_dl", None):
content_type = "text/csv; charset=utf-8"
disposition = 'attachment; filename="{}.csv"'.format(
kwargs.get('table', database)
kwargs.get("table", database)
)
headers["Content-Disposition"] = disposition
return response.stream(
stream_fn,
headers=headers,
content_type=content_type
)
return response.stream(stream_fn, headers=headers, content_type=content_type)
async def get_format(self, request, database, args):
""" Determine the format of the response from the request, from URL
@@ -325,22 +323,20 @@ class BaseView(RenderMixin):
if not _format:
_format = (args.pop("as_format", None) or "").lstrip(".")
if "table_and_format" in args:
async def async_table_exists(t):
return await self.ds.table_exists(database, t)
table, _ext_format = await resolve_table_and_format(
table_and_format=urllib.parse.unquote_plus(
args["table_and_format"]
),
table_and_format=urllib.parse.unquote_plus(args["table_and_format"]),
table_exists=async_table_exists,
allowed_formats=self.ds.renderers.keys()
allowed_formats=self.ds.renderers.keys(),
)
_format = _format or _ext_format
args["table"] = table
del args["table_and_format"]
elif "table" in args:
args["table"] = urllib.parse.unquote_plus(
args["table"]
)
args["table"] = urllib.parse.unquote_plus(args["table"])
return _format, args
async def view_get(self, request, database, hash, correct_hash_provided, **kwargs):
@@ -351,7 +347,7 @@ class BaseView(RenderMixin):
if _format is None:
# HTML views default to expanding all foriegn key labels
kwargs['default_labels'] = True
kwargs["default_labels"] = True
extra_template_data = {}
start = time.time()
@@ -367,11 +363,16 @@ class BaseView(RenderMixin):
else:
data, extra_template_data, templates = response_or_template_contexts
except InterruptedError:
raise DatasetteError("""
raise DatasetteError(
"""
SQL query took too long. The time limit is controlled by the
<a href="https://datasette.readthedocs.io/en/stable/config.html#sql-time-limit-ms">sql_time_limit_ms</a>
configuration option.
""", title="SQL Interrupted", status=400, messagge_is_html=True)
""",
title="SQL Interrupted",
status=400,
messagge_is_html=True,
)
except (sqlite3.OperationalError, InvalidSql) as e:
raise DatasetteError(str(e), title="Invalid SQL", status=400)
@@ -408,14 +409,14 @@ class BaseView(RenderMixin):
raise NotFound("No data")
response_args = {
'content_type': result.get('content_type', 'text/plain'),
'status': result.get('status_code', 200)
"content_type": result.get("content_type", "text/plain"),
"status": result.get("status_code", 200),
}
if type(result.get('body')) == bytes:
response_args['body_bytes'] = result.get('body')
if type(result.get("body")) == bytes:
response_args["body_bytes"] = result.get("body")
else:
response_args['body'] = result.get('body')
response_args["body"] = result.get("body")
r = response.HTTPResponse(**response_args)
else:
@@ -431,14 +432,12 @@ class BaseView(RenderMixin):
url_labels_extra = {"_labels": "on"}
renderers = {
key: path_with_format(request, key, {**url_labels_extra}) for key in self.ds.renderers.keys()
}
url_csv_args = {
"_size": "max",
**url_labels_extra
key: path_with_format(request, key, {**url_labels_extra})
for key in self.ds.renderers.keys()
}
url_csv_args = {"_size": "max", **url_labels_extra}
url_csv = path_with_format(request, "csv", url_csv_args)
url_csv_path = url_csv.split('?')[0]
url_csv_path = url_csv.split("?")[0]
context = {
**data,
**extras,
@@ -450,10 +449,11 @@ class BaseView(RenderMixin):
(key, value)
for key, value in urllib.parse.parse_qsl(request.query_string)
if key not in ("_labels", "_facet", "_size")
] + [("_size", "max")],
]
+ [("_size", "max")],
"datasette_version": __version__,
"config": self.ds.config_dict(),
}
},
}
if "metadata" not in context:
context["metadata"] = self.ds.metadata
@@ -474,9 +474,9 @@ class BaseView(RenderMixin):
if self.ds.cache_headers and response.status == 200:
ttl = int(ttl)
if ttl == 0:
ttl_header = 'no-cache'
ttl_header = "no-cache"
else:
ttl_header = 'max-age={}'.format(ttl)
ttl_header = "max-age={}".format(ttl)
response.headers["Cache-Control"] = ttl_header
response.headers["Referrer-Policy"] = "no-referrer"
if self.ds.cors:
@@ -484,8 +484,15 @@ class BaseView(RenderMixin):
return response
async def custom_sql(
self, request, database, hash, sql, editable=True, canned_query=None,
metadata=None, _size=None
self,
request,
database,
hash,
sql,
editable=True,
canned_query=None,
metadata=None,
_size=None,
):
params = request.raw_args
if "sql" in params:
@@ -565,10 +572,14 @@ class BaseView(RenderMixin):
"hide_sql": "_hide_sql" in params,
}
return {
"database": database,
"rows": results.rows,
"truncated": results.truncated,
"columns": columns,
"query": {"sql": sql, "params": params},
}, extra_template, templates
return (
{
"database": database,
"rows": results.rows,
"truncated": results.truncated,
"columns": columns,
"query": {"sql": sql, "params": params},
},
extra_template,
templates,
)