Compare commits

...

12 commits

Author SHA1 Message Date
Simon Willison
45c83b4c35 Respect --cors for error pages, closes #453 2019-05-05 07:59:45 -04:00
Simon Willison
9683aeb239 Do not attempt 'import black' for Python 3.5 2019-05-03 22:05:42 -04:00
Simon Willison
fcec3badd8 Tweak version check
Because 3.5.1 > 3.5
2019-05-03 22:01:32 -04:00
Simon Willison
d88d015581 Conditionally install black 2019-05-03 21:58:27 -04:00
Simon Willison
acb2c3ab24 Only install black if python > 3.5 2019-05-03 21:53:59 -04:00
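Note on the black version-gating commits above: "Because 3.5.1 > 3.5" refers to tuple comparison of sys.version_info, where (3, 5, 1) > (3, 5) is true, so a check written as sys.version_info > (3, 5) still matches Python 3.5.1. A minimal sketch of that kind of guard (illustrative only; not the repository's actual setup.py):

    # Illustrative sketch: install black for tests only on Python 3.6+,
    # since black itself does not support Python 3.5.
    import sys

    from setuptools import setup

    # sys.version_info > (3, 5) is too loose: (3, 5, 1) > (3, 5) is True,
    # so Python 3.5.1 would slip through. Compare against (3, 6) instead.
    maybe_black = ["black"] if sys.version_info >= (3, 6) else []

    setup(
        name="example-package",  # placeholder, not datasette's real setup.py
        version="0.1",
        tests_require=["pytest"] + maybe_black,
    )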
Simon Willison
fd8d377a34 More robust re-ordering of tests 2019-05-03 21:48:24 -04:00
Simon Willison
b968df2033 Only run black on Python 3.6 or higher 2019-05-03 21:43:30 -04:00
Simon Willison
ade6bae472 Ensure test_black.py executes first under pytest
Otherwise it throws an exception because it gets disrupted by the
asyncio activity in the other tests.
2019-05-03 21:25:49 -04:00
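One way to force test_black.py to run before the asyncio-heavy tests is a collection hook in conftest.py; the sketch below is illustrative and the repository's actual re-ordering may differ:

    # conftest.py (sketch): run the tests in test_black.py before any test
    # that starts asyncio machinery; list.sort is stable, so relative order
    # within each group is preserved.
    def pytest_collection_modifyitems(items):
        items.sort(
            key=lambda item: 0 if item.fspath.basename == "test_black.py" else 1
        )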
Simon Willison
f9193e7a18 Unit test enforcing black formatting
Only runs for Python 3.6 at the moment.

See https://github.com/python/black/issues/425
2019-05-03 21:06:47 -04:00
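The enforcement test could look roughly like the sketch below, which shells out to python -m black --check; the real test_black.py may use black's Python API instead:

    # test_black.py (sketch): fail the suite if black would reformat anything.
    import subprocess
    import sys

    import pytest


    @pytest.mark.skipif(
        sys.version_info[:2] < (3, 6), reason="black requires Python 3.6 or higher"
    )
    def test_code_is_black_formatted():
        # Avoid importing black at module level so Python 3.5 never sees it.
        pytest.importorskip("black")
        result = subprocess.run(
            [sys.executable, "-m", "black", "--check", "datasette", "tests"]
        )
        assert result.returncode == 0, "Run 'black datasette tests' to reformat"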
Simon Willison
9e054f5a84 Black against setup.py 2019-05-03 20:51:16 -04:00
Simon Willison
4935cc5c06 Neater formatting of config description text 2019-05-03 18:01:21 -04:00
Simon Willison
6300d3e269 Apply black to everything
I ran this:

    black datasette tests
2019-05-03 17:56:52 -04:00
31 changed files with 2779 additions and 2704 deletions

View file

@@ -1,4 +1,3 @@
 # This file helps to compute a version number in source trees obtained from
 # git-archive tarball (such as those provided by githubs download-from-tag
 # feature). Distribution tarballs (built by setup.py sdist) and build
@@ -58,17 +57,18 @@ HANDLERS = {}
 def register_vcs_handler(vcs, method):  # decorator
     """Decorator to mark a method as the handler for a particular VCS."""
     def decorate(f):
         """Store f in HANDLERS[vcs][method]."""
         if vcs not in HANDLERS:
             HANDLERS[vcs] = {}
         HANDLERS[vcs][method] = f
         return f
     return decorate
-def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
-                env=None):
+def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None):
     """Call the given command(s)."""
     assert isinstance(commands, list)
     p = None
@@ -76,10 +76,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
         try:
             dispcmd = str([c] + args)
             # remember shell=False, so use git.cmd on windows, not just git
-            p = subprocess.Popen([c] + args, cwd=cwd, env=env,
-                                 stdout=subprocess.PIPE,
-                                 stderr=(subprocess.PIPE if hide_stderr
-                                         else None))
+            p = subprocess.Popen(
+                [c] + args,
+                cwd=cwd,
+                env=env,
+                stdout=subprocess.PIPE,
+                stderr=(subprocess.PIPE if hide_stderr else None),
+            )
             break
         except EnvironmentError:
             e = sys.exc_info()[1]
@ -116,16 +119,22 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
for i in range(3): for i in range(3):
dirname = os.path.basename(root) dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix): if dirname.startswith(parentdir_prefix):
return {"version": dirname[len(parentdir_prefix):], return {
"version": dirname[len(parentdir_prefix) :],
"full-revisionid": None, "full-revisionid": None,
"dirty": False, "error": None, "date": None} "dirty": False,
"error": None,
"date": None,
}
else: else:
rootdirs.append(root) rootdirs.append(root)
root = os.path.dirname(root) # up a level root = os.path.dirname(root) # up a level
if verbose: if verbose:
print("Tried directories %s but none started with prefix %s" % print(
(str(rootdirs), parentdir_prefix)) "Tried directories %s but none started with prefix %s"
% (str(rootdirs), parentdir_prefix)
)
raise NotThisMethod("rootdir doesn't start with parentdir_prefix") raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
@@ -190,7 +199,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
     # between branches and tags. By ignoring refnames without digits, we
     # filter out many common branch names like "release" and
     # "stabilization", as well as "HEAD" and "master".
-    tags = set([r for r in refs if re.search(r'\d', r)])
+    tags = set([r for r in refs if re.search(r"\d", r)])
     if verbose:
         print("discarding '%s', no digits" % ",".join(refs - tags))
     if verbose:
@ -201,16 +210,23 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
r = ref[len(tag_prefix) :] r = ref[len(tag_prefix) :]
if verbose: if verbose:
print("picking %s" % r) print("picking %s" % r)
return {"version": r, return {
"version": r,
"full-revisionid": keywords["full"].strip(), "full-revisionid": keywords["full"].strip(),
"dirty": False, "error": None, "dirty": False,
"date": date} "error": None,
"date": date,
}
# no suitable tags, so version is "0+unknown", but full hex is still there # no suitable tags, so version is "0+unknown", but full hex is still there
if verbose: if verbose:
print("no suitable tags, using unknown + full revision id") print("no suitable tags, using unknown + full revision id")
return {"version": "0+unknown", return {
"version": "0+unknown",
"full-revisionid": keywords["full"].strip(), "full-revisionid": keywords["full"].strip(),
"dirty": False, "error": "no suitable tags", "date": None} "dirty": False,
"error": "no suitable tags",
"date": None,
}
@register_vcs_handler("git", "pieces_from_vcs") @register_vcs_handler("git", "pieces_from_vcs")
@ -225,8 +241,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if sys.platform == "win32": if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"] GITS = ["git.cmd", "git.exe"]
out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True)
hide_stderr=True)
if rc != 0: if rc != 0:
if verbose: if verbose:
print("Directory %s not under git control" % root) print("Directory %s not under git control" % root)
@ -234,10 +249,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM) # if there isn't one, this yields HEX[-dirty] (no NUM)
describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", describe_out, rc = run_command(
"--always", "--long", GITS,
"--match", "%s*" % tag_prefix], [
cwd=root) "describe",
"--tags",
"--dirty",
"--always",
"--long",
"--match",
"%s*" % tag_prefix,
],
cwd=root,
)
# --long was added in git-1.5.5 # --long was added in git-1.5.5
if describe_out is None: if describe_out is None:
raise NotThisMethod("'git describe' failed") raise NotThisMethod("'git describe' failed")
@ -266,11 +290,10 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if "-" in git_describe: if "-" in git_describe:
# TAG-NUM-gHEX # TAG-NUM-gHEX
mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
if not mo: if not mo:
# unparseable. Maybe git-describe is misbehaving? # unparseable. Maybe git-describe is misbehaving?
pieces["error"] = ("unable to parse git-describe output: '%s'" pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
% describe_out)
return pieces return pieces
# tag # tag
@ -279,8 +302,10 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if verbose: if verbose:
fmt = "tag '%s' doesn't start with prefix '%s'" fmt = "tag '%s' doesn't start with prefix '%s'"
print(fmt % (full_tag, tag_prefix)) print(fmt % (full_tag, tag_prefix))
pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (
% (full_tag, tag_prefix)) full_tag,
tag_prefix,
)
return pieces return pieces
pieces["closest-tag"] = full_tag[len(tag_prefix) :] pieces["closest-tag"] = full_tag[len(tag_prefix) :]
@ -293,13 +318,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
else: else:
# HEX: no tags # HEX: no tags
pieces["closest-tag"] = None pieces["closest-tag"] = None
count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root)
cwd=root)
pieces["distance"] = int(count_out) # total number of commits pieces["distance"] = int(count_out) # total number of commits
# commit date: see ISO-8601 comment in git_versions_from_keywords() # commit date: see ISO-8601 comment in git_versions_from_keywords()
date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[
cwd=root)[0].strip() 0
].strip()
pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
return pieces return pieces
@ -330,8 +355,7 @@ def render_pep440(pieces):
rendered += ".dirty" rendered += ".dirty"
else: else:
# exception #1 # exception #1
rendered = "0+untagged.%d.g%s" % (pieces["distance"], rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
pieces["short"])
if pieces["dirty"]: if pieces["dirty"]:
rendered += ".dirty" rendered += ".dirty"
return rendered return rendered
@ -445,11 +469,13 @@ def render_git_describe_long(pieces):
def render(pieces, style): def render(pieces, style):
"""Render the given version pieces into the requested style.""" """Render the given version pieces into the requested style."""
if pieces["error"]: if pieces["error"]:
return {"version": "unknown", return {
"version": "unknown",
"full-revisionid": pieces.get("long"), "full-revisionid": pieces.get("long"),
"dirty": None, "dirty": None,
"error": pieces["error"], "error": pieces["error"],
"date": None} "date": None,
}
if not style or style == "default": if not style or style == "default":
style = "pep440" # the default style = "pep440" # the default
@ -469,9 +495,13 @@ def render(pieces, style):
else: else:
raise ValueError("unknown style '%s'" % style) raise ValueError("unknown style '%s'" % style)
return {"version": rendered, "full-revisionid": pieces["long"], return {
"dirty": pieces["dirty"], "error": None, "version": rendered,
"date": pieces.get("date")} "full-revisionid": pieces["long"],
"dirty": pieces["dirty"],
"error": None,
"date": pieces.get("date"),
}
def get_versions(): def get_versions():
@ -485,8 +515,7 @@ def get_versions():
verbose = cfg.verbose verbose = cfg.verbose
try: try:
return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose)
verbose)
except NotThisMethod: except NotThisMethod:
pass pass
@ -495,13 +524,16 @@ def get_versions():
# versionfile_source is the relative path from the top of the source # versionfile_source is the relative path from the top of the source
# tree (where the .git directory might live) to this file. Invert # tree (where the .git directory might live) to this file. Invert
# this to find the root from __file__. # this to find the root from __file__.
for i in cfg.versionfile_source.split('/'): for i in cfg.versionfile_source.split("/"):
root = os.path.dirname(root) root = os.path.dirname(root)
except NameError: except NameError:
return {"version": "0+unknown", "full-revisionid": None, return {
"version": "0+unknown",
"full-revisionid": None,
"dirty": None, "dirty": None,
"error": "unable to find root of source tree", "error": "unable to find root of source tree",
"date": None} "date": None,
}
try: try:
pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
@ -515,6 +547,10 @@ def get_versions():
except NotThisMethod: except NotThisMethod:
pass pass
return {"version": "0+unknown", "full-revisionid": None, return {
"version": "0+unknown",
"full-revisionid": None,
"dirty": None, "dirty": None,
"error": "unable to compute version", "date": None} "error": "unable to compute version",
"date": None,
}

View file

@ -17,10 +17,7 @@ from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader
from sanic import Sanic, response from sanic import Sanic, response
from sanic.exceptions import InvalidUsage, NotFound from sanic.exceptions import InvalidUsage, NotFound
from .views.base import ( from .views.base import DatasetteError, ureg
DatasetteError,
ureg
)
from .views.database import DatabaseDownload, DatabaseView from .views.database import DatabaseDownload, DatabaseView
from .views.index import IndexView from .views.index import IndexView
from .views.special import JsonDataView from .views.special import JsonDataView
@ -39,7 +36,7 @@ from .utils import (
sqlite3, sqlite3,
sqlite_timelimit, sqlite_timelimit,
table_columns, table_columns,
to_css_class to_css_class,
) )
from .inspect import inspect_hash, inspect_views, inspect_tables from .inspect import inspect_hash, inspect_views, inspect_tables
from .tracer import capture_traces, trace from .tracer import capture_traces, trace
@ -51,72 +48,85 @@ app_root = Path(__file__).parent.parent
connections = threading.local() connections = threading.local()
MEMORY = object() MEMORY = object()
ConfigOption = collections.namedtuple( ConfigOption = collections.namedtuple("ConfigOption", ("name", "default", "help"))
"ConfigOption", ("name", "default", "help")
)
CONFIG_OPTIONS = ( CONFIG_OPTIONS = (
ConfigOption("default_page_size", 100, """ ConfigOption("default_page_size", 100, "Default page size for the table view"),
Default page size for the table view ConfigOption(
""".strip()), "max_returned_rows",
ConfigOption("max_returned_rows", 1000, """ 1000,
Maximum rows that can be returned from a table or custom query "Maximum rows that can be returned from a table or custom query",
""".strip()), ),
ConfigOption("num_sql_threads", 3, """ ConfigOption(
Number of threads in the thread pool for executing SQLite queries "num_sql_threads",
""".strip()), 3,
ConfigOption("sql_time_limit_ms", 1000, """ "Number of threads in the thread pool for executing SQLite queries",
Time limit for a SQL query in milliseconds ),
""".strip()), ConfigOption(
ConfigOption("default_facet_size", 30, """ "sql_time_limit_ms", 1000, "Time limit for a SQL query in milliseconds"
Number of values to return for requested facets ),
""".strip()), ConfigOption(
ConfigOption("facet_time_limit_ms", 200, """ "default_facet_size", 30, "Number of values to return for requested facets"
Time limit for calculating a requested facet ),
""".strip()), ConfigOption(
ConfigOption("facet_suggest_time_limit_ms", 50, """ "facet_time_limit_ms", 200, "Time limit for calculating a requested facet"
Time limit for calculating a suggested facet ),
""".strip()), ConfigOption(
ConfigOption("hash_urls", False, """ "facet_suggest_time_limit_ms",
Include DB file contents hash in URLs, for far-future caching 50,
""".strip()), "Time limit for calculating a suggested facet",
ConfigOption("allow_facet", True, """ ),
Allow users to specify columns to facet using ?_facet= parameter ConfigOption(
""".strip()), "hash_urls",
ConfigOption("allow_download", True, """ False,
Allow users to download the original SQLite database files "Include DB file contents hash in URLs, for far-future caching",
""".strip()), ),
ConfigOption("suggest_facets", True, """ ConfigOption(
Calculate and display suggested facets "allow_facet",
""".strip()), True,
ConfigOption("allow_sql", True, """ "Allow users to specify columns to facet using ?_facet= parameter",
Allow arbitrary SQL queries via ?sql= parameter ),
""".strip()), ConfigOption(
ConfigOption("default_cache_ttl", 5, """ "allow_download",
Default HTTP cache TTL (used in Cache-Control: max-age= header) True,
""".strip()), "Allow users to download the original SQLite database files",
ConfigOption("default_cache_ttl_hashed", 365 * 24 * 60 * 60, """ ),
Default HTTP cache TTL for hashed URL pages ConfigOption("suggest_facets", True, "Calculate and display suggested facets"),
""".strip()), ConfigOption("allow_sql", True, "Allow arbitrary SQL queries via ?sql= parameter"),
ConfigOption("cache_size_kb", 0, """ ConfigOption(
SQLite cache size in KB (0 == use SQLite default) "default_cache_ttl",
""".strip()), 5,
ConfigOption("allow_csv_stream", True, """ "Default HTTP cache TTL (used in Cache-Control: max-age= header)",
Allow .csv?_stream=1 to download all rows (ignoring max_returned_rows) ),
""".strip()), ConfigOption(
ConfigOption("max_csv_mb", 100, """ "default_cache_ttl_hashed",
Maximum size allowed for CSV export in MB - set 0 to disable this limit 365 * 24 * 60 * 60,
""".strip()), "Default HTTP cache TTL for hashed URL pages",
ConfigOption("truncate_cells_html", 2048, """ ),
Truncate cells longer than this in HTML table view - set 0 to disable ConfigOption(
""".strip()), "cache_size_kb", 0, "SQLite cache size in KB (0 == use SQLite default)"
ConfigOption("force_https_urls", False, """ ),
Force URLs in API output to always use https:// protocol ConfigOption(
""".strip()), "allow_csv_stream",
True,
"Allow .csv?_stream=1 to download all rows (ignoring max_returned_rows)",
),
ConfigOption(
"max_csv_mb",
100,
"Maximum size allowed for CSV export in MB - set 0 to disable this limit",
),
ConfigOption(
"truncate_cells_html",
2048,
"Truncate cells longer than this in HTML table view - set 0 to disable",
),
ConfigOption(
"force_https_urls",
False,
"Force URLs in API output to always use https:// protocol",
),
) )
DEFAULT_CONFIG = { DEFAULT_CONFIG = {option.name: option.default for option in CONFIG_OPTIONS}
option.name: option.default
for option in CONFIG_OPTIONS
}
async def favicon(request): async def favicon(request):
@ -151,11 +161,13 @@ class ConnectedDatabase:
counts = {} counts = {}
for table in await self.table_names(): for table in await self.table_names():
try: try:
table_count = (await self.ds.execute( table_count = (
await self.ds.execute(
self.name, self.name,
"select count(*) from [{}]".format(table), "select count(*) from [{}]".format(table),
custom_time_limit=limit, custom_time_limit=limit,
)).rows[0][0] )
).rows[0][0]
counts[table] = table_count counts[table] = table_count
except InterruptedError: except InterruptedError:
counts[table] = None counts[table] = None
@ -175,18 +187,26 @@ class ConnectedDatabase:
return Path(self.path).stem return Path(self.path).stem
async def table_names(self): async def table_names(self):
results = await self.ds.execute(self.name, "select name from sqlite_master where type='table'") results = await self.ds.execute(
self.name, "select name from sqlite_master where type='table'"
)
return [r[0] for r in results.rows] return [r[0] for r in results.rows]
async def hidden_table_names(self): async def hidden_table_names(self):
# Mark tables 'hidden' if they relate to FTS virtual tables # Mark tables 'hidden' if they relate to FTS virtual tables
hidden_tables = [r[0] for r in ( hidden_tables = [
await self.ds.execute(self.name, """ r[0]
for r in (
await self.ds.execute(
self.name,
"""
select name from sqlite_master select name from sqlite_master
where rootpage = 0 where rootpage = 0
and sql like '%VIRTUAL TABLE%USING FTS%' and sql like '%VIRTUAL TABLE%USING FTS%'
""") """,
).rows] )
).rows
]
has_spatialite = await self.ds.execute_against_connection_in_thread( has_spatialite = await self.ds.execute_against_connection_in_thread(
self.name, detect_spatialite self.name, detect_spatialite
) )
@ -205,18 +225,23 @@ class ConnectedDatabase:
] + [ ] + [
r[0] r[0]
for r in ( for r in (
await self.ds.execute(self.name, """ await self.ds.execute(
self.name,
"""
select name from sqlite_master select name from sqlite_master
where name like "idx_%" where name like "idx_%"
and type = "table" and type = "table"
""") """,
)
).rows ).rows
] ]
# Add any from metadata.json # Add any from metadata.json
db_metadata = self.ds.metadata(database=self.name) db_metadata = self.ds.metadata(database=self.name)
if "tables" in db_metadata: if "tables" in db_metadata:
hidden_tables += [ hidden_tables += [
t for t in db_metadata["tables"] if db_metadata["tables"][t].get("hidden") t
for t in db_metadata["tables"]
if db_metadata["tables"][t].get("hidden")
] ]
# Also mark as hidden any tables which start with the name of a hidden table # Also mark as hidden any tables which start with the name of a hidden table
# e.g. "searchable_fts" implies "searchable_fts_content" should be hidden # e.g. "searchable_fts" implies "searchable_fts_content" should be hidden
@ -229,7 +254,9 @@ class ConnectedDatabase:
return hidden_tables return hidden_tables
async def view_names(self): async def view_names(self):
results = await self.ds.execute(self.name, "select name from sqlite_master where type='view'") results = await self.ds.execute(
self.name, "select name from sqlite_master where type='view'"
)
return [r[0] for r in results.rows] return [r[0] for r in results.rows]
def __repr__(self): def __repr__(self):
@ -245,13 +272,10 @@ class ConnectedDatabase:
tags_str = "" tags_str = ""
if tags: if tags:
tags_str = " ({})".format(", ".join(tags)) tags_str = " ({})".format(", ".join(tags))
return "<ConnectedDatabase: {}{}>".format( return "<ConnectedDatabase: {}{}>".format(self.name, tags_str)
self.name, tags_str
)
class Datasette: class Datasette:
def __init__( def __init__(
self, self,
files, files,
@ -283,7 +307,9 @@ class Datasette:
path = None path = None
is_memory = True is_memory = True
is_mutable = path not in self.immutables is_mutable = path not in self.immutables
db = ConnectedDatabase(self, path, is_mutable=is_mutable, is_memory=is_memory) db = ConnectedDatabase(
self, path, is_mutable=is_mutable, is_memory=is_memory
)
if db.name in self.databases: if db.name in self.databases:
raise Exception("Multiple files with same stem: {}".format(db.name)) raise Exception("Multiple files with same stem: {}".format(db.name))
self.databases[db.name] = db self.databases[db.name] = db
@ -322,26 +348,24 @@ class Datasette:
def config_dict(self): def config_dict(self):
# Returns a fully resolved config dictionary, useful for templates # Returns a fully resolved config dictionary, useful for templates
return { return {option.name: self.config(option.name) for option in CONFIG_OPTIONS}
option.name: self.config(option.name)
for option in CONFIG_OPTIONS
}
def metadata(self, key=None, database=None, table=None, fallback=True): def metadata(self, key=None, database=None, table=None, fallback=True):
""" """
Looks up metadata, cascading backwards from specified level. Looks up metadata, cascading backwards from specified level.
Returns None if metadata value is not found. Returns None if metadata value is not found.
""" """
assert not (database is None and table is not None), \ assert not (
"Cannot call metadata() with table= specified but not database=" database is None and table is not None
), "Cannot call metadata() with table= specified but not database="
databases = self._metadata.get("databases") or {} databases = self._metadata.get("databases") or {}
search_list = [] search_list = []
if database is not None: if database is not None:
search_list.append(databases.get(database) or {}) search_list.append(databases.get(database) or {})
if table is not None: if table is not None:
table_metadata = ( table_metadata = ((databases.get(database) or {}).get("tables") or {}).get(
(databases.get(database) or {}).get("tables") or {} table
).get(table) or {} ) or {}
search_list.insert(0, table_metadata) search_list.insert(0, table_metadata)
search_list.append(self._metadata) search_list.append(self._metadata)
if not fallback: if not fallback:
@ -359,9 +383,7 @@ class Datasette:
m.update(item) m.update(item)
return m return m
def plugin_config( def plugin_config(self, plugin_name, database=None, table=None, fallback=True):
self, plugin_name, database=None, table=None, fallback=True
):
"Return config for plugin, falling back from specified database/table" "Return config for plugin, falling back from specified database/table"
plugins = self.metadata( plugins = self.metadata(
"plugins", database=database, table=table, fallback=fallback "plugins", database=database, table=table, fallback=fallback
@ -373,29 +395,19 @@ class Datasette:
def app_css_hash(self): def app_css_hash(self):
if not hasattr(self, "_app_css_hash"): if not hasattr(self, "_app_css_hash"):
self._app_css_hash = hashlib.sha1( self._app_css_hash = hashlib.sha1(
open( open(os.path.join(str(app_root), "datasette/static/app.css"))
os.path.join(str(app_root), "datasette/static/app.css") .read()
).read().encode( .encode("utf8")
"utf8" ).hexdigest()[:6]
)
).hexdigest()[
:6
]
return self._app_css_hash return self._app_css_hash
def get_canned_queries(self, database_name): def get_canned_queries(self, database_name):
queries = self.metadata( queries = self.metadata("queries", database=database_name, fallback=False) or {}
"queries", database=database_name, fallback=False
) or {}
names = queries.keys() names = queries.keys()
return [ return [self.get_canned_query(database_name, name) for name in names]
self.get_canned_query(database_name, name) for name in names
]
def get_canned_query(self, database_name, query_name): def get_canned_query(self, database_name, query_name):
queries = self.metadata( queries = self.metadata("queries", database=database_name, fallback=False) or {}
"queries", database=database_name, fallback=False
) or {}
query = queries.get(query_name) query = queries.get(query_name)
if query: if query:
if not isinstance(query, dict): if not isinstance(query, dict):
@ -407,7 +419,7 @@ class Datasette:
table_definition_rows = list( table_definition_rows = list(
await self.execute( await self.execute(
database_name, database_name,
'select sql from sqlite_master where name = :n and type=:t', "select sql from sqlite_master where name = :n and type=:t",
{"n": table, "t": type_}, {"n": table, "t": type_},
) )
) )
@ -416,21 +428,19 @@ class Datasette:
return table_definition_rows[0][0] return table_definition_rows[0][0]
def get_view_definition(self, database_name, view): def get_view_definition(self, database_name, view):
return self.get_table_definition(database_name, view, 'view') return self.get_table_definition(database_name, view, "view")
def update_with_inherited_metadata(self, metadata): def update_with_inherited_metadata(self, metadata):
# Fills in source/license with defaults, if available # Fills in source/license with defaults, if available
metadata.update( metadata.update(
{ {
"source": metadata.get("source") or self.metadata("source"), "source": metadata.get("source") or self.metadata("source"),
"source_url": metadata.get("source_url") "source_url": metadata.get("source_url") or self.metadata("source_url"),
or self.metadata("source_url"),
"license": metadata.get("license") or self.metadata("license"), "license": metadata.get("license") or self.metadata("license"),
"license_url": metadata.get("license_url") "license_url": metadata.get("license_url")
or self.metadata("license_url"), or self.metadata("license_url"),
"about": metadata.get("about") or self.metadata("about"), "about": metadata.get("about") or self.metadata("about"),
"about_url": metadata.get("about_url") "about_url": metadata.get("about_url") or self.metadata("about_url"),
or self.metadata("about_url"),
} }
) )
@ -444,7 +454,7 @@ class Datasette:
for extension in self.sqlite_extensions: for extension in self.sqlite_extensions:
conn.execute("SELECT load_extension('{}')".format(extension)) conn.execute("SELECT load_extension('{}')".format(extension))
if self.config("cache_size_kb"): if self.config("cache_size_kb"):
conn.execute('PRAGMA cache_size=-{}'.format(self.config("cache_size_kb"))) conn.execute("PRAGMA cache_size=-{}".format(self.config("cache_size_kb")))
# pylint: disable=no-member # pylint: disable=no-member
pm.hook.prepare_connection(conn=conn) pm.hook.prepare_connection(conn=conn)
@ -452,7 +462,7 @@ class Datasette:
results = await self.execute( results = await self.execute(
database, database,
"select 1 from sqlite_master where type='table' and name=?", "select 1 from sqlite_master where type='table' and name=?",
params=(table,) params=(table,),
) )
return bool(results.rows) return bool(results.rows)
@ -463,32 +473,28 @@ class Datasette:
# Find the foreign_key for this column # Find the foreign_key for this column
try: try:
fk = [ fk = [
foreign_key for foreign_key in foreign_keys foreign_key
for foreign_key in foreign_keys
if foreign_key["column"] == column if foreign_key["column"] == column
][0] ][0]
except IndexError: except IndexError:
return {} return {}
label_column = await self.label_column_for_table(database, fk["other_table"]) label_column = await self.label_column_for_table(database, fk["other_table"])
if not label_column: if not label_column:
return { return {(fk["column"], value): str(value) for value in values}
(fk["column"], value): str(value)
for value in values
}
labeled_fks = {} labeled_fks = {}
sql = ''' sql = """
select {other_column}, {label_column} select {other_column}, {label_column}
from {other_table} from {other_table}
where {other_column} in ({placeholders}) where {other_column} in ({placeholders})
'''.format( """.format(
other_column=escape_sqlite(fk["other_column"]), other_column=escape_sqlite(fk["other_column"]),
label_column=escape_sqlite(label_column), label_column=escape_sqlite(label_column),
other_table=escape_sqlite(fk["other_table"]), other_table=escape_sqlite(fk["other_table"]),
placeholders=", ".join(["?"] * len(set(values))), placeholders=", ".join(["?"] * len(set(values))),
) )
try: try:
results = await self.execute( results = await self.execute(database, sql, list(set(values)))
database, sql, list(set(values))
)
except InterruptedError: except InterruptedError:
pass pass
else: else:
@ -532,10 +538,12 @@ class Datasette:
"file": str(path), "file": str(path),
"size": path.stat().st_size, "size": path.stat().st_size,
"views": inspect_views(conn), "views": inspect_views(conn),
"tables": inspect_tables(conn, (self.metadata("databases") or {}).get(name, {})) "tables": inspect_tables(
conn, (self.metadata("databases") or {}).get(name, {})
),
} }
except sqlite3.OperationalError as e: except sqlite3.OperationalError as e:
if (e.args[0] == 'no such module: VirtualSpatialIndex'): if e.args[0] == "no such module: VirtualSpatialIndex":
raise click.UsageError( raise click.UsageError(
"It looks like you're trying to load a SpatiaLite" "It looks like you're trying to load a SpatiaLite"
" database without first loading the SpatiaLite module." " database without first loading the SpatiaLite module."
@ -582,7 +590,8 @@ class Datasette:
datasette_version["note"] = self.version_note datasette_version["note"] = self.version_note
return { return {
"python": { "python": {
"version": ".".join(map(str, sys.version_info[:3])), "full": sys.version "version": ".".join(map(str, sys.version_info[:3])),
"full": sys.version,
}, },
"datasette": datasette_version, "datasette": datasette_version,
"sqlite": { "sqlite": {
@ -611,10 +620,11 @@ class Datasette:
def table_metadata(self, database, table): def table_metadata(self, database, table):
"Fetch table-specific metadata." "Fetch table-specific metadata."
return (self.metadata("databases") or {}).get(database, {}).get( return (
"tables", {} (self.metadata("databases") or {})
).get( .get(database, {})
table, {} .get("tables", {})
.get(table, {})
) )
async def table_columns(self, db_name, table): async def table_columns(self, db_name, table):
@ -628,16 +638,12 @@ class Datasette:
) )
async def label_column_for_table(self, db_name, table): async def label_column_for_table(self, db_name, table):
explicit_label_column = ( explicit_label_column = self.table_metadata(db_name, table).get("label_column")
self.table_metadata(
db_name, table
).get("label_column")
)
if explicit_label_column: if explicit_label_column:
return explicit_label_column return explicit_label_column
# If a table has two columns, one of which is ID, then label_column is the other one # If a table has two columns, one of which is ID, then label_column is the other one
column_names = await self.table_columns(db_name, table) column_names = await self.table_columns(db_name, table)
if (column_names and len(column_names) == 2 and "id" in column_names): if column_names and len(column_names) == 2 and "id" in column_names:
return [c for c in column_names if c != "id"][0] return [c for c in column_names if c != "id"][0]
# Couldn't find a label: # Couldn't find a label:
return None return None
@ -664,9 +670,7 @@ class Datasette:
setattr(connections, db_name, conn) setattr(connections, db_name, conn)
return fn(conn) return fn(conn)
return await asyncio.get_event_loop().run_in_executor( return await asyncio.get_event_loop().run_in_executor(self.executor, in_thread)
self.executor, in_thread
)
async def execute( async def execute(
self, self,
@ -701,7 +705,7 @@ class Datasette:
rows = cursor.fetchall() rows = cursor.fetchall()
truncated = False truncated = False
except sqlite3.OperationalError as e: except sqlite3.OperationalError as e:
if e.args == ('interrupted',): if e.args == ("interrupted",):
raise InterruptedError(e, sql, params) raise InterruptedError(e, sql, params)
if log_sql_errors: if log_sql_errors:
print( print(
@ -726,7 +730,7 @@ class Datasette:
def register_renderers(self): def register_renderers(self):
""" Register output renderers which output data in custom formats. """ """ Register output renderers which output data in custom formats. """
# Built-in renderers # Built-in renderers
self.renderers['json'] = json_renderer self.renderers["json"] = json_renderer
# Hooks # Hooks
hook_renderers = [] hook_renderers = []
@ -737,19 +741,22 @@ class Datasette:
hook_renderers.append(hook) hook_renderers.append(hook)
for renderer in hook_renderers: for renderer in hook_renderers:
self.renderers[renderer['extension']] = renderer['callback'] self.renderers[renderer["extension"]] = renderer["callback"]
def app(self): def app(self):
class TracingSanic(Sanic): class TracingSanic(Sanic):
async def handle_request(self, request, write_callback, stream_callback): async def handle_request(self, request, write_callback, stream_callback):
if request.args.get("_trace"): if request.args.get("_trace"):
request["traces"] = [] request["traces"] = []
request["trace_start"] = time.time() request["trace_start"] = time.time()
with capture_traces(request["traces"]): with capture_traces(request["traces"]):
res = await super().handle_request(request, write_callback, stream_callback) res = await super().handle_request(
request, write_callback, stream_callback
)
else: else:
res = await super().handle_request(request, write_callback, stream_callback) res = await super().handle_request(
request, write_callback, stream_callback
)
return res return res
app = TracingSanic(__name__) app = TracingSanic(__name__)
@ -822,15 +829,16 @@ class Datasette:
) )
app.add_route( app.add_route(
DatabaseView.as_view(self), DatabaseView.as_view(self),
r"/<db_name:[^/]+?><as_format:(" + renderer_regex + r"|.jsono|\.csv)?$>" r"/<db_name:[^/]+?><as_format:(" + renderer_regex + r"|.jsono|\.csv)?$>",
) )
app.add_route( app.add_route(
TableView.as_view(self), TableView.as_view(self), r"/<db_name:[^/]+>/<table_and_format:[^/]+?$>"
r"/<db_name:[^/]+>/<table_and_format:[^/]+?$>",
) )
app.add_route( app.add_route(
RowView.as_view(self), RowView.as_view(self),
r"/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_format:(" + renderer_regex + r")?$>", r"/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_format:("
+ renderer_regex
+ r")?$>",
) )
self.register_custom_units() self.register_custom_units()
@ -852,7 +860,7 @@ class Datasette:
"duration": time.time() - request["trace_start"], "duration": time.time() - request["trace_start"],
"queries": request["traces"], "queries": request["traces"],
} }
if "text/html" in response.content_type and b'</body>' in response.body: if "text/html" in response.content_type and b"</body>" in response.body:
extra = json.dumps(traces, indent=2) extra = json.dumps(traces, indent=2)
extra_html = "<pre>{}</pre></body>".format(extra).encode("utf8") extra_html = "<pre>{}</pre></body>".format(extra).encode("utf8")
response.body = response.body.replace(b"</body>", extra_html) response.body = response.body.replace(b"</body>", extra_html)
@@ -897,11 +905,14 @@ class Datasette:
                 {"ok": False, "error": message, "status": status, "title": title}
             )
             if request is not None and request.path.split("?")[0].endswith(".json"):
-                return response.json(info, status=status)
+                r = response.json(info, status=status)
             else:
                 template = self.jinja_env.select_template(templates)
-                return response.html(template.render(info), status=status)
+                r = response.html(template.render(info), status=status)
+            if self.cors:
+                r.headers["Access-Control-Allow-Origin"] = "*"
+            return r

         # First time server starts up, calculate table counts for immutable databases
         @app.listener("before_server_start")

View file

@ -20,16 +20,14 @@ class Config(click.ParamType):
def convert(self, config, param, ctx): def convert(self, config, param, ctx):
if ":" not in config: if ":" not in config:
self.fail( self.fail('"{}" should be name:value'.format(config), param, ctx)
'"{}" should be name:value'.format(config), param, ctx
)
return return
name, value = config.split(":") name, value = config.split(":")
if name not in DEFAULT_CONFIG: if name not in DEFAULT_CONFIG:
self.fail( self.fail(
"{} is not a valid option (--help-config to see all)".format( "{} is not a valid option (--help-config to see all)".format(name),
name param,
), param, ctx ctx,
) )
return return
# Type checking # Type checking
@ -44,14 +42,12 @@ class Config(click.ParamType):
return return
elif isinstance(default, int): elif isinstance(default, int):
if not value.isdigit(): if not value.isdigit():
self.fail( self.fail('"{}" should be an integer'.format(name), param, ctx)
'"{}" should be an integer'.format(name), param, ctx
)
return return
return name, int(value) return name, int(value)
else: else:
# Should never happen: # Should never happen:
self.fail('Invalid option') self.fail("Invalid option")
@click.group(cls=DefaultGroup, default="serve", default_if_no_args=True) @click.group(cls=DefaultGroup, default="serve", default_if_no_args=True)
@ -204,13 +200,9 @@ def plugins(all, plugins_dir):
multiple=True, multiple=True,
) )
@click.option( @click.option(
"--install", "--install", help="Additional packages (e.g. plugins) to install", multiple=True
help="Additional packages (e.g. plugins) to install",
multiple=True,
)
@click.option(
"--spatialite", is_flag=True, help="Enable SpatialLite extension"
) )
@click.option("--spatialite", is_flag=True, help="Enable SpatialLite extension")
@click.option("--version-note", help="Additional note to show on /-/versions") @click.option("--version-note", help="Additional note to show on /-/versions")
@click.option("--title", help="Title for metadata") @click.option("--title", help="Title for metadata")
@click.option("--license", help="License label for metadata") @click.option("--license", help="License label for metadata")
@ -322,9 +314,7 @@ def package(
help="mountpoint:path-to-directory for serving static files", help="mountpoint:path-to-directory for serving static files",
multiple=True, multiple=True,
) )
@click.option( @click.option("--memory", is_flag=True, help="Make :memory: database available")
"--memory", is_flag=True, help="Make :memory: database available"
)
@click.option( @click.option(
"--config", "--config",
type=Config(), type=Config(),
@ -332,11 +322,7 @@ def package(
multiple=True, multiple=True,
) )
@click.option("--version-note", help="Additional note to show on /-/versions") @click.option("--version-note", help="Additional note to show on /-/versions")
@click.option( @click.option("--help-config", is_flag=True, help="Show available config options")
"--help-config",
is_flag=True,
help="Show available config options",
)
def serve( def serve(
files, files,
immutable, immutable,
@ -360,12 +346,12 @@ def serve(
if help_config: if help_config:
formatter = formatting.HelpFormatter() formatter = formatting.HelpFormatter()
with formatter.section("Config options"): with formatter.section("Config options"):
formatter.write_dl([ formatter.write_dl(
(option.name, '{} (default={})'.format( [
option.help, option.default (option.name, "{} (default={})".format(option.help, option.default))
))
for option in CONFIG_OPTIONS for option in CONFIG_OPTIONS
]) ]
)
click.echo(formatter.getvalue()) click.echo(formatter.getvalue())
sys.exit(0) sys.exit(0)
if reload: if reload:
@ -384,7 +370,9 @@ def serve(
if metadata: if metadata:
metadata_data = json.loads(metadata.read()) metadata_data = json.loads(metadata.read())
click.echo("Serve! files={} (immutables={}) on port {}".format(files, immutable, port)) click.echo(
"Serve! files={} (immutables={}) on port {}".format(files, immutable, port)
)
ds = Datasette( ds = Datasette(
files, files,
immutables=immutable, immutables=immutable,

View file

@ -31,14 +31,15 @@ def load_facet_configs(request, table_metadata):
metadata_config = {"simple": metadata_config} metadata_config = {"simple": metadata_config}
else: else:
# This should have a single key and a single value # This should have a single key and a single value
assert len(metadata_config.values()) == 1, "Metadata config dicts should be {type: config}" assert (
len(metadata_config.values()) == 1
), "Metadata config dicts should be {type: config}"
type, metadata_config = metadata_config.items()[0] type, metadata_config = metadata_config.items()[0]
if isinstance(metadata_config, str): if isinstance(metadata_config, str):
metadata_config = {"simple": metadata_config} metadata_config = {"simple": metadata_config}
facet_configs.setdefault(type, []).append({ facet_configs.setdefault(type, []).append(
"source": "metadata", {"source": "metadata", "config": metadata_config}
"config": metadata_config )
})
qs_pairs = urllib.parse.parse_qs(request.query_string, keep_blank_values=True) qs_pairs = urllib.parse.parse_qs(request.query_string, keep_blank_values=True)
for key, values in qs_pairs.items(): for key, values in qs_pairs.items():
if key.startswith("_facet"): if key.startswith("_facet"):
@ -53,10 +54,9 @@ def load_facet_configs(request, table_metadata):
config = json.loads(value) config = json.loads(value)
else: else:
config = {"simple": value} config = {"simple": value}
facet_configs.setdefault(type, []).append({ facet_configs.setdefault(type, []).append(
"source": "request", {"source": "request", "config": config}
"config": config )
})
return facet_configs return facet_configs
@ -214,7 +214,9 @@ class ColumnFacet(Facet):
"name": column, "name": column,
"type": self.type, "type": self.type,
"hideable": source != "metadata", "hideable": source != "metadata",
"toggle_url": path_with_removed_args(self.request, {"_facet": column}), "toggle_url": path_with_removed_args(
self.request, {"_facet": column}
),
"results": facet_results_values, "results": facet_results_values,
"truncated": len(facet_rows_results) > facet_size, "truncated": len(facet_rows_results) > facet_size,
} }
@ -269,30 +271,31 @@ class ArrayFacet(Facet):
select distinct json_type({column}) select distinct json_type({column})
from ({sql}) from ({sql})
""".format( """.format(
column=escape_sqlite(column), column=escape_sqlite(column), sql=self.sql
sql=self.sql,
) )
try: try:
results = await self.ds.execute( results = await self.ds.execute(
self.database, suggested_facet_sql, self.params, self.database,
suggested_facet_sql,
self.params,
truncate=False, truncate=False,
custom_time_limit=self.ds.config("facet_suggest_time_limit_ms"), custom_time_limit=self.ds.config("facet_suggest_time_limit_ms"),
log_sql_errors=False, log_sql_errors=False,
) )
types = tuple(r[0] for r in results.rows) types = tuple(r[0] for r in results.rows)
if types in ( if types in (("array",), ("array", None)):
("array",), suggested_facets.append(
("array", None) {
):
suggested_facets.append({
"name": column, "name": column,
"type": "array", "type": "array",
"toggle_url": self.ds.absolute_url( "toggle_url": self.ds.absolute_url(
self.request, path_with_added_args( self.request,
path_with_added_args(
self.request, {"_facet_array": column} self.request, {"_facet_array": column}
)
), ),
}) ),
}
)
except (InterruptedError, sqlite3.OperationalError): except (InterruptedError, sqlite3.OperationalError):
continue continue
return suggested_facets return suggested_facets
@ -314,13 +317,13 @@ class ArrayFacet(Facet):
) join json_each({col}) j ) join json_each({col}) j
group by j.value order by count desc limit {limit} group by j.value order by count desc limit {limit}
""".format( """.format(
col=escape_sqlite(column), col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1
sql=self.sql,
limit=facet_size+1,
) )
try: try:
facet_rows_results = await self.ds.execute( facet_rows_results = await self.ds.execute(
self.database, facet_sql, self.params, self.database,
facet_sql,
self.params,
truncate=False, truncate=False,
custom_time_limit=self.ds.config("facet_time_limit_ms"), custom_time_limit=self.ds.config("facet_time_limit_ms"),
) )
@ -330,7 +333,9 @@ class ArrayFacet(Facet):
"type": self.type, "type": self.type,
"results": facet_results_values, "results": facet_results_values,
"hideable": source != "metadata", "hideable": source != "metadata",
"toggle_url": path_with_removed_args(self.request, {"_facet_array": column}), "toggle_url": path_with_removed_args(
self.request, {"_facet_array": column}
),
"truncated": len(facet_rows_results) > facet_size, "truncated": len(facet_rows_results) > facet_size,
} }
facet_rows = facet_rows_results.rows[:facet_size] facet_rows = facet_rows_results.rows[:facet_size]
@ -346,13 +351,17 @@ class ArrayFacet(Facet):
toggle_path = path_with_added_args( toggle_path = path_with_added_args(
self.request, {"{}__arraycontains".format(column): value} self.request, {"{}__arraycontains".format(column): value}
) )
facet_results_values.append({ facet_results_values.append(
{
"value": value, "value": value,
"label": value, "label": value,
"count": row["count"], "count": row["count"],
"toggle_url": self.ds.absolute_url(self.request, toggle_path), "toggle_url": self.ds.absolute_url(
self.request, toggle_path
),
"selected": selected, "selected": selected,
}) }
)
except InterruptedError: except InterruptedError:
facets_timed_out.append(column) facets_timed_out.append(column)

View file

@ -1,10 +1,7 @@
import json import json
import numbers import numbers
from .utils import ( from .utils import detect_json1, escape_sqlite
detect_json1,
escape_sqlite,
)
class Filter: class Filter:
@ -20,7 +17,16 @@ class Filter:
class TemplatedFilter(Filter): class TemplatedFilter(Filter):
def __init__(self, key, display, sql_template, human_template, format='{}', numeric=False, no_argument=False): def __init__(
self,
key,
display,
sql_template,
human_template,
format="{}",
numeric=False,
no_argument=False,
):
self.key = key self.key = key
self.display = display self.display = display
self.sql_template = sql_template self.sql_template = sql_template
@ -34,16 +40,10 @@ class TemplatedFilter(Filter):
if self.numeric and converted.isdigit(): if self.numeric and converted.isdigit():
converted = int(converted) converted = int(converted)
if self.no_argument: if self.no_argument:
kwargs = { kwargs = {"c": column}
'c': column,
}
converted = None converted = None
else: else:
kwargs = { kwargs = {"c": column, "p": "p{}".format(param_counter), "t": table}
'c': column,
'p': 'p{}'.format(param_counter),
't': table,
}
return self.sql_template.format(**kwargs), converted return self.sql_template.format(**kwargs), converted
def human_clause(self, column, value): def human_clause(self, column, value):
@ -58,8 +58,8 @@ class TemplatedFilter(Filter):
class InFilter(Filter): class InFilter(Filter):
key = 'in' key = "in"
display = 'in' display = "in"
def __init__(self): def __init__(self):
pass pass
@ -81,34 +81,98 @@ class InFilter(Filter):
class Filters: class Filters:
_filters = [ _filters = (
[
# key, display, sql_template, human_template, format=, numeric=, no_argument= # key, display, sql_template, human_template, format=, numeric=, no_argument=
TemplatedFilter('exact', '=', '"{c}" = :{p}', lambda c, v: '{c} = {v}' if v.isdigit() else '{c} = "{v}"'), TemplatedFilter(
TemplatedFilter('not', '!=', '"{c}" != :{p}', lambda c, v: '{c} != {v}' if v.isdigit() else '{c} != "{v}"'), "exact",
TemplatedFilter('contains', 'contains', '"{c}" like :{p}', '{c} contains "{v}"', format='%{}%'), "=",
TemplatedFilter('endswith', 'ends with', '"{c}" like :{p}', '{c} ends with "{v}"', format='%{}'), '"{c}" = :{p}',
TemplatedFilter('startswith', 'starts with', '"{c}" like :{p}', '{c} starts with "{v}"', format='{}%'), lambda c, v: "{c} = {v}" if v.isdigit() else '{c} = "{v}"',
TemplatedFilter('gt', '>', '"{c}" > :{p}', '{c} > {v}', numeric=True), ),
TemplatedFilter('gte', '\u2265', '"{c}" >= :{p}', '{c} \u2265 {v}', numeric=True), TemplatedFilter(
TemplatedFilter('lt', '<', '"{c}" < :{p}', '{c} < {v}', numeric=True), "not",
TemplatedFilter('lte', '\u2264', '"{c}" <= :{p}', '{c} \u2264 {v}', numeric=True), "!=",
TemplatedFilter('like', 'like', '"{c}" like :{p}', '{c} like "{v}"'), '"{c}" != :{p}',
TemplatedFilter('glob', 'glob', '"{c}" glob :{p}', '{c} glob "{v}"'), lambda c, v: "{c} != {v}" if v.isdigit() else '{c} != "{v}"',
),
TemplatedFilter(
"contains",
"contains",
'"{c}" like :{p}',
'{c} contains "{v}"',
format="%{}%",
),
TemplatedFilter(
"endswith",
"ends with",
'"{c}" like :{p}',
'{c} ends with "{v}"',
format="%{}",
),
TemplatedFilter(
"startswith",
"starts with",
'"{c}" like :{p}',
'{c} starts with "{v}"',
format="{}%",
),
TemplatedFilter("gt", ">", '"{c}" > :{p}', "{c} > {v}", numeric=True),
TemplatedFilter(
"gte", "\u2265", '"{c}" >= :{p}', "{c} \u2265 {v}", numeric=True
),
TemplatedFilter("lt", "<", '"{c}" < :{p}', "{c} < {v}", numeric=True),
TemplatedFilter(
"lte", "\u2264", '"{c}" <= :{p}', "{c} \u2264 {v}", numeric=True
),
TemplatedFilter("like", "like", '"{c}" like :{p}', '{c} like "{v}"'),
TemplatedFilter("glob", "glob", '"{c}" glob :{p}', '{c} glob "{v}"'),
InFilter(), InFilter(),
] + ([TemplatedFilter('arraycontains', 'array contains', """rowid in ( ]
+ (
[
TemplatedFilter(
"arraycontains",
"array contains",
"""rowid in (
select {t}.rowid from {t}, json_each({t}.{c}) j select {t}.rowid from {t}, json_each({t}.{c}) j
where j.value = :{p} where j.value = :{p}
)""", '{c} contains "{v}"') )""",
] if detect_json1() else []) + [ '{c} contains "{v}"',
TemplatedFilter('date', 'date', 'date({c}) = :{p}', '"{c}" is on date {v}'), )
TemplatedFilter('isnull', 'is null', '"{c}" is null', '{c} is null', no_argument=True),
TemplatedFilter('notnull', 'is not null', '"{c}" is not null', '{c} is not null', no_argument=True),
TemplatedFilter('isblank', 'is blank', '("{c}" is null or "{c}" = "")', '{c} is blank', no_argument=True),
TemplatedFilter('notblank', 'is not blank', '("{c}" is not null and "{c}" != "")', '{c} is not blank', no_argument=True),
] ]
_filters_by_key = { if detect_json1()
f.key: f for f in _filters else []
} )
+ [
TemplatedFilter("date", "date", "date({c}) = :{p}", '"{c}" is on date {v}'),
TemplatedFilter(
"isnull", "is null", '"{c}" is null', "{c} is null", no_argument=True
),
TemplatedFilter(
"notnull",
"is not null",
'"{c}" is not null',
"{c} is not null",
no_argument=True,
),
TemplatedFilter(
"isblank",
"is blank",
'("{c}" is null or "{c}" = "")',
"{c} is blank",
no_argument=True,
),
TemplatedFilter(
"notblank",
"is not blank",
'("{c}" is not null and "{c}" != "")',
"{c} is not blank",
no_argument=True,
),
]
)
_filters_by_key = {f.key: f for f in _filters}
def __init__(self, pairs, units={}, ureg=None): def __init__(self, pairs, units={}, ureg=None):
self.pairs = pairs self.pairs = pairs
@ -132,22 +196,22 @@ class Filters:
and_bits = [] and_bits = []
commas, tail = bits[:-1], bits[-1:] commas, tail = bits[:-1], bits[-1:]
if commas: if commas:
and_bits.append(', '.join(commas)) and_bits.append(", ".join(commas))
if tail: if tail:
and_bits.append(tail[0]) and_bits.append(tail[0])
s = ' and '.join(and_bits) s = " and ".join(and_bits)
if not s: if not s:
return '' return ""
return 'where {}'.format(s) return "where {}".format(s)
def selections(self): def selections(self):
"Yields (column, lookup, value) tuples" "Yields (column, lookup, value) tuples"
for key, value in self.pairs: for key, value in self.pairs:
if '__' in key: if "__" in key:
column, lookup = key.rsplit('__', 1) column, lookup = key.rsplit("__", 1)
else: else:
column = key column = key
lookup = 'exact' lookup = "exact"
yield column, lookup, value yield column, lookup, value
def has_selections(self): def has_selections(self):
@ -174,13 +238,15 @@ class Filters:
for column, lookup, value in self.selections(): for column, lookup, value in self.selections():
filter = self._filters_by_key.get(lookup, None) filter = self._filters_by_key.get(lookup, None)
if filter: if filter:
sql_bit, param = filter.where_clause(table, column, self.convert_unit(column, value), i) sql_bit, param = filter.where_clause(
table, column, self.convert_unit(column, value), i
)
sql_bits.append(sql_bit) sql_bits.append(sql_bit)
if param is not None: if param is not None:
if not isinstance(param, list): if not isinstance(param, list):
param = [param] param = [param]
for individual_param in param: for individual_param in param:
param_id = 'p{}'.format(i) param_id = "p{}".format(i)
params[param_id] = individual_param params[param_id] = individual_param
i += 1 i += 1
return sql_bits, params return sql_bits, params

View file

@ -7,7 +7,7 @@ from .utils import (
escape_sqlite, escape_sqlite,
get_all_foreign_keys, get_all_foreign_keys,
table_columns, table_columns,
sqlite3 sqlite3,
) )
@@ -29,7 +29,9 @@ def inspect_hash(path):
 def inspect_views(conn):
     " List views in a database. "
-    return [v[0] for v in conn.execute('select name from sqlite_master where type = "view"')]
+    return [
+        v[0] for v in conn.execute('select name from sqlite_master where type = "view"')
+    ]
def inspect_tables(conn, database_metadata): def inspect_tables(conn, database_metadata):
@ -37,15 +39,11 @@ def inspect_tables(conn, database_metadata):
tables = {} tables = {}
table_names = [ table_names = [
r["name"] r["name"]
for r in conn.execute( for r in conn.execute('select * from sqlite_master where type="table"')
'select * from sqlite_master where type="table"'
)
] ]
for table in table_names: for table in table_names:
table_metadata = database_metadata.get("tables", {}).get( table_metadata = database_metadata.get("tables", {}).get(table, {})
table, {}
)
try: try:
count = conn.execute( count = conn.execute(

View file

@ -41,8 +41,12 @@ def publish_subcommand(publish):
name, name,
spatialite, spatialite,
): ):
fail_if_publish_binary_not_installed("gcloud", "Google Cloud", "https://cloud.google.com/sdk/") fail_if_publish_binary_not_installed(
project = check_output("gcloud config get-value project", shell=True, universal_newlines=True).strip() "gcloud", "Google Cloud", "https://cloud.google.com/sdk/"
)
project = check_output(
"gcloud config get-value project", shell=True, universal_newlines=True
).strip()
with temporary_docker_directory( with temporary_docker_directory(
files, files,
@ -68,4 +72,9 @@ def publish_subcommand(publish):
): ):
image_id = "gcr.io/{project}/{name}".format(project=project, name=name) image_id = "gcr.io/{project}/{name}".format(project=project, name=name)
check_call("gcloud builds submit --tag {}".format(image_id), shell=True) check_call("gcloud builds submit --tag {}".format(image_id), shell=True)
check_call("gcloud beta run deploy --allow-unauthenticated --image {}".format(image_id), shell=True) check_call(
"gcloud beta run deploy --allow-unauthenticated --image {}".format(
image_id
),
shell=True,
)
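The Cloud Run publish flow above boils down to three gcloud calls; a sketch of the same sequence outside the temporary Docker directory, using a hypothetical package name "mydb" (requires the gcloud CLI and a configured project):

    from subprocess import check_call, check_output

    project = check_output(
        "gcloud config get-value project", shell=True, universal_newlines=True
    ).strip()
    image_id = "gcr.io/{project}/{name}".format(project=project, name="mydb")
    check_call("gcloud builds submit --tag {}".format(image_id), shell=True)
    check_call(
        "gcloud beta run deploy --allow-unauthenticated --image {}".format(image_id),
        shell=True,
    )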


@ -5,7 +5,8 @@ import sys
def add_common_publish_arguments_and_options(subcommand): def add_common_publish_arguments_and_options(subcommand):
for decorator in reversed(( for decorator in reversed(
(
click.argument("files", type=click.Path(exists=True), nargs=-1), click.argument("files", type=click.Path(exists=True), nargs=-1),
click.option( click.option(
"-m", "-m",
@ -13,8 +14,12 @@ def add_common_publish_arguments_and_options(subcommand):
type=click.File(mode="r"), type=click.File(mode="r"),
help="Path to JSON file containing metadata to publish", help="Path to JSON file containing metadata to publish",
), ),
click.option("--extra-options", help="Extra options to pass to datasette serve"), click.option(
click.option("--branch", help="Install datasette from a GitHub branch e.g. master"), "--extra-options", help="Extra options to pass to datasette serve"
),
click.option(
"--branch", help="Install datasette from a GitHub branch e.g. master"
),
click.option( click.option(
"--template-dir", "--template-dir",
type=click.Path(exists=True, file_okay=False, dir_okay=True), type=click.Path(exists=True, file_okay=False, dir_okay=True),
@ -36,7 +41,9 @@ def add_common_publish_arguments_and_options(subcommand):
help="Additional packages (e.g. plugins) to install", help="Additional packages (e.g. plugins) to install",
multiple=True, multiple=True,
), ),
click.option("--version-note", help="Additional note to show on /-/versions"), click.option(
"--version-note", help="Additional note to show on /-/versions"
),
click.option("--title", help="Title for metadata"), click.option("--title", help="Title for metadata"),
click.option("--license", help="License label for metadata"), click.option("--license", help="License label for metadata"),
click.option("--license_url", help="License URL for metadata"), click.option("--license_url", help="License URL for metadata"),
@ -44,7 +51,8 @@ def add_common_publish_arguments_and_options(subcommand):
click.option("--source_url", help="Source URL for metadata"), click.option("--source_url", help="Source URL for metadata"),
click.option("--about", help="About label for metadata"), click.option("--about", help="About label for metadata"),
click.option("--about_url", help="About URL for metadata"), click.option("--about_url", help="About URL for metadata"),
)): )
):
subcommand = decorator(subcommand) subcommand = decorator(subcommand)
return subcommand return subcommand
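The reversed() in add_common_publish_arguments_and_options matters because applying decorators in a loop runs bottom-up; reversing the tuple reproduces the order of stacked @decorator syntax. A toy sketch with plain decorators rather than click, purely illustrative:

    def tag(label):
        def decorator(fn):
            fn.labels = getattr(fn, "labels", []) + [label]
            return fn
        return decorator

    def command():
        pass

    for decorator in reversed((tag("first"), tag("second"))):
        command = decorator(command)

    print(command.labels)
    # ['second', 'first'] - the same result as writing @tag("first") above @tag("second")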


@ -76,9 +76,7 @@ def publish_subcommand(publish):
"about_url": about_url, "about_url": about_url,
}, },
): ):
now_json = { now_json = {"version": 1}
"version": 1
}
if alias: if alias:
now_json["alias"] = alias now_json["alias"] = alias
open("now.json", "w").write(json.dumps(now_json)) open("now.json", "w").write(json.dumps(now_json))


@ -89,8 +89,4 @@ def json_renderer(args, data, view_name):
else: else:
body = json.dumps(data, cls=CustomJSONEncoder) body = json.dumps(data, cls=CustomJSONEncoder)
content_type = "application/json" content_type = "application/json"
return { return {"body": body, "status_code": status_code, "content_type": content_type}
"body": body,
"status_code": status_code,
"content_type": content_type
}


@ -21,27 +21,29 @@ except ImportError:
import sqlite3 import sqlite3
# From https://www.sqlite.org/lang_keywords.html # From https://www.sqlite.org/lang_keywords.html
reserved_words = set(( reserved_words = set(
'abort action add after all alter analyze and as asc attach autoincrement ' (
'before begin between by cascade case cast check collate column commit ' "abort action add after all alter analyze and as asc attach autoincrement "
'conflict constraint create cross current_date current_time ' "before begin between by cascade case cast check collate column commit "
'current_timestamp database default deferrable deferred delete desc detach ' "conflict constraint create cross current_date current_time "
'distinct drop each else end escape except exclusive exists explain fail ' "current_timestamp database default deferrable deferred delete desc detach "
'for foreign from full glob group having if ignore immediate in index ' "distinct drop each else end escape except exclusive exists explain fail "
'indexed initially inner insert instead intersect into is isnull join key ' "for foreign from full glob group having if ignore immediate in index "
'left like limit match natural no not notnull null of offset on or order ' "indexed initially inner insert instead intersect into is isnull join key "
'outer plan pragma primary query raise recursive references regexp reindex ' "left like limit match natural no not notnull null of offset on or order "
'release rename replace restrict right rollback row savepoint select set ' "outer plan pragma primary query raise recursive references regexp reindex "
'table temp temporary then to transaction trigger union unique update using ' "release rename replace restrict right rollback row savepoint select set "
'vacuum values view virtual when where with without' "table temp temporary then to transaction trigger union unique update using "
).split()) "vacuum values view virtual when where with without"
).split()
)
SPATIALITE_DOCKERFILE_EXTRAS = r''' SPATIALITE_DOCKERFILE_EXTRAS = r"""
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y python3-dev gcc libsqlite3-mod-spatialite && \ apt-get install -y python3-dev gcc libsqlite3-mod-spatialite && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
ENV SQLITE_EXTENSIONS /usr/lib/x86_64-linux-gnu/mod_spatialite.so ENV SQLITE_EXTENSIONS /usr/lib/x86_64-linux-gnu/mod_spatialite.so
''' """
class InterruptedError(Exception): class InterruptedError(Exception):
@ -67,27 +69,24 @@ class Results:
def urlsafe_components(token): def urlsafe_components(token):
"Splits token on commas and URL decodes each component" "Splits token on commas and URL decodes each component"
return [ return [urllib.parse.unquote_plus(b) for b in token.split(",")]
urllib.parse.unquote_plus(b) for b in token.split(',')
]
def path_from_row_pks(row, pks, use_rowid, quote=True): def path_from_row_pks(row, pks, use_rowid, quote=True):
""" Generate an optionally URL-quoted unique identifier """ Generate an optionally URL-quoted unique identifier
for a row from its primary keys.""" for a row from its primary keys."""
if use_rowid: if use_rowid:
bits = [row['rowid']] bits = [row["rowid"]]
else: else:
bits = [ bits = [
row[pk]["value"] if isinstance(row[pk], dict) else row[pk] row[pk]["value"] if isinstance(row[pk], dict) else row[pk] for pk in pks
for pk in pks
] ]
if quote: if quote:
bits = [urllib.parse.quote_plus(str(bit)) for bit in bits] bits = [urllib.parse.quote_plus(str(bit)) for bit in bits]
else: else:
bits = [str(bit) for bit in bits] bits = [str(bit) for bit in bits]
return ','.join(bits) return ",".join(bits)
def compound_keys_after_sql(pks, start_index=0): def compound_keys_after_sql(pks, start_index=0):
@ -106,16 +105,17 @@ def compound_keys_after_sql(pks, start_index=0):
and_clauses = [] and_clauses = []
last = pks_left[-1] last = pks_left[-1]
rest = pks_left[:-1] rest = pks_left[:-1]
and_clauses = ['{} = :p{}'.format( and_clauses = [
escape_sqlite(pk), (i + start_index) "{} = :p{}".format(escape_sqlite(pk), (i + start_index))
) for i, pk in enumerate(rest)] for i, pk in enumerate(rest)
and_clauses.append('{} > :p{}'.format( ]
escape_sqlite(last), (len(rest) + start_index) and_clauses.append(
)) "{} > :p{}".format(escape_sqlite(last), (len(rest) + start_index))
or_clauses.append('({})'.format(' and '.join(and_clauses))) )
or_clauses.append("({})".format(" and ".join(and_clauses)))
pks_left.pop() pks_left.pop()
or_clauses.reverse() or_clauses.reverse()
return '({})'.format('\n or\n'.join(or_clauses)) return "({})".format("\n or\n".join(or_clauses))
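For a compound primary key of two columns a and b, the keyset-pagination clause built above looks roughly like this (a sketch assuming datasette is importable; exact whitespace may differ):

    from datasette.utils import compound_keys_after_sql

    print(compound_keys_after_sql(["a", "b"]))
    # ((a > :p0)
    #   or
    # (a = :p0 and b > :p1))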
class CustomJSONEncoder(json.JSONEncoder): class CustomJSONEncoder(json.JSONEncoder):
@ -127,11 +127,11 @@ class CustomJSONEncoder(json.JSONEncoder):
if isinstance(obj, bytes): if isinstance(obj, bytes):
# Does it encode to utf8? # Does it encode to utf8?
try: try:
return obj.decode('utf8') return obj.decode("utf8")
except UnicodeDecodeError: except UnicodeDecodeError:
return { return {
'$base64': True, "$base64": True,
'encoded': base64.b64encode(obj).decode('latin1'), "encoded": base64.b64encode(obj).decode("latin1"),
} }
return json.JSONEncoder.default(self, obj) return json.JSONEncoder.default(self, obj)
@ -163,20 +163,18 @@ class InvalidSql(Exception):
allowed_sql_res = [ allowed_sql_res = [
re.compile(r'^select\b'), re.compile(r"^select\b"),
re.compile(r'^explain select\b'), re.compile(r"^explain select\b"),
re.compile(r'^explain query plan select\b'), re.compile(r"^explain query plan select\b"),
re.compile(r'^with\b'), re.compile(r"^with\b"),
]
disallawed_sql_res = [
(re.compile('pragma'), 'Statement may not contain PRAGMA'),
] ]
disallawed_sql_res = [(re.compile("pragma"), "Statement may not contain PRAGMA")]
def validate_sql_select(sql): def validate_sql_select(sql):
sql = sql.strip().lower() sql = sql.strip().lower()
if not any(r.match(sql) for r in allowed_sql_res): if not any(r.match(sql) for r in allowed_sql_res):
raise InvalidSql('Statement must be a SELECT') raise InvalidSql("Statement must be a SELECT")
for r, msg in disallawed_sql_res: for r, msg in disallawed_sql_res:
if r.search(sql): if r.search(sql):
raise InvalidSql(msg) raise InvalidSql(msg)
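How those allow/deny lists behave in practice, assuming datasette is importable:

    from datasette.utils import InvalidSql, validate_sql_select

    validate_sql_select("select * from places")  # passes
    validate_sql_select("with t as (select 1) select * from t")  # passes
    try:
        validate_sql_select("update places set name = 'x'")
    except InvalidSql as e:
        print(e)  # Statement must be a SELECT
    try:
        validate_sql_select("select * from pragma_table_info('places')")
    except InvalidSql as e:
        print(e)  # Statement may not contain PRAGMA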
@ -184,9 +182,7 @@ def validate_sql_select(sql):
def append_querystring(url, querystring): def append_querystring(url, querystring):
op = "&" if ("?" in url) else "?" op = "&" if ("?" in url) else "?"
return "{}{}{}".format( return "{}{}{}".format(url, op, querystring)
url, op, querystring
)
def path_with_added_args(request, args, path=None): def path_with_added_args(request, args, path=None):
@ -198,14 +194,10 @@ def path_with_added_args(request, args, path=None):
for key, value in urllib.parse.parse_qsl(request.query_string): for key, value in urllib.parse.parse_qsl(request.query_string):
if key not in args_to_remove: if key not in args_to_remove:
current.append((key, value)) current.append((key, value))
current.extend([ current.extend([(key, value) for key, value in args if value is not None])
(key, value)
for key, value in args
if value is not None
])
query_string = urllib.parse.urlencode(current) query_string = urllib.parse.urlencode(current)
if query_string: if query_string:
query_string = '?{}'.format(query_string) query_string = "?{}".format(query_string)
return path + query_string return path + query_string
@ -220,18 +212,21 @@ def path_with_removed_args(request, args, path=None):
# args can be a dict or a set # args can be a dict or a set
current = [] current = []
if isinstance(args, set): if isinstance(args, set):
def should_remove(key, value): def should_remove(key, value):
return key in args return key in args
elif isinstance(args, dict): elif isinstance(args, dict):
# Must match key AND value # Must match key AND value
def should_remove(key, value): def should_remove(key, value):
return args.get(key) == value return args.get(key) == value
for key, value in urllib.parse.parse_qsl(query_string): for key, value in urllib.parse.parse_qsl(query_string):
if not should_remove(key, value): if not should_remove(key, value):
current.append((key, value)) current.append((key, value))
query_string = urllib.parse.urlencode(current) query_string = urllib.parse.urlencode(current)
if query_string: if query_string:
query_string = '?{}'.format(query_string) query_string = "?{}".format(query_string)
return path + query_string return path + query_string
@ -247,54 +242,66 @@ def path_with_replaced_args(request, args, path=None):
current.extend([p for p in args if p[1] is not None]) current.extend([p for p in args if p[1] is not None])
query_string = urllib.parse.urlencode(current) query_string = urllib.parse.urlencode(current)
if query_string: if query_string:
query_string = '?{}'.format(query_string) query_string = "?{}".format(query_string)
return path + query_string return path + query_string
_css_re = re.compile(r'''['"\n\\]''') _css_re = re.compile(r"""['"\n\\]""")
_boring_keyword_re = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') _boring_keyword_re = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
def escape_css_string(s): def escape_css_string(s):
return _css_re.sub(lambda m: '\\{:X}'.format(ord(m.group())), s) return _css_re.sub(lambda m: "\\{:X}".format(ord(m.group())), s)
def escape_sqlite(s): def escape_sqlite(s):
if _boring_keyword_re.match(s) and (s.lower() not in reserved_words): if _boring_keyword_re.match(s) and (s.lower() not in reserved_words):
return s return s
else: else:
return '[{}]'.format(s) return "[{}]".format(s)
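escape_sqlite() in practice (assuming datasette is importable): plain identifiers pass through, while reserved words and anything unusual get bracketed:

    from datasette.utils import escape_sqlite

    print(escape_sqlite("age"))        # age
    print(escape_sqlite("select"))     # [select]
    print(escape_sqlite("my column"))  # [my column]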
def make_dockerfile(files, metadata_file, extra_options, branch, template_dir, plugins_dir, static, install, spatialite, version_note):
cmd = ['datasette', 'serve', '--host', '0.0.0.0'] def make_dockerfile(
files,
metadata_file,
extra_options,
branch,
template_dir,
plugins_dir,
static,
install,
spatialite,
version_note,
):
cmd = ["datasette", "serve", "--host", "0.0.0.0"]
cmd.append('", "'.join(files)) cmd.append('", "'.join(files))
cmd.extend(['--cors', '--inspect-file', 'inspect-data.json']) cmd.extend(["--cors", "--inspect-file", "inspect-data.json"])
if metadata_file: if metadata_file:
cmd.extend(['--metadata', '{}'.format(metadata_file)]) cmd.extend(["--metadata", "{}".format(metadata_file)])
if template_dir: if template_dir:
cmd.extend(['--template-dir', 'templates/']) cmd.extend(["--template-dir", "templates/"])
if plugins_dir: if plugins_dir:
cmd.extend(['--plugins-dir', 'plugins/']) cmd.extend(["--plugins-dir", "plugins/"])
if version_note: if version_note:
cmd.extend(['--version-note', '{}'.format(version_note)]) cmd.extend(["--version-note", "{}".format(version_note)])
if static: if static:
for mount_point, _ in static: for mount_point, _ in static:
cmd.extend(['--static', '{}:{}'.format(mount_point, mount_point)]) cmd.extend(["--static", "{}:{}".format(mount_point, mount_point)])
if extra_options: if extra_options:
for opt in extra_options.split(): for opt in extra_options.split():
cmd.append('{}'.format(opt)) cmd.append("{}".format(opt))
cmd = [shlex.quote(part) for part in cmd] cmd = [shlex.quote(part) for part in cmd]
# port attribute is a (fixed) env variable and should not be quoted # port attribute is a (fixed) env variable and should not be quoted
cmd.extend(['--port', '$PORT']) cmd.extend(["--port", "$PORT"])
cmd = ' '.join(cmd) cmd = " ".join(cmd)
if branch: if branch:
install = ['https://github.com/simonw/datasette/archive/{}.zip'.format( install = [
branch "https://github.com/simonw/datasette/archive/{}.zip".format(branch)
)] + list(install) ] + list(install)
else: else:
install = ['datasette'] + list(install) install = ["datasette"] + list(install)
return ''' return """
FROM python:3.6 FROM python:3.6
COPY . /app COPY . /app
WORKDIR /app WORKDIR /app
@ -303,11 +310,11 @@ RUN pip install -U {install_from}
RUN datasette inspect {files} --inspect-file inspect-data.json RUN datasette inspect {files} --inspect-file inspect-data.json
ENV PORT 8001 ENV PORT 8001
EXPOSE 8001 EXPOSE 8001
CMD {cmd}'''.format( CMD {cmd}""".format(
files=' '.join(files), files=" ".join(files),
cmd=cmd, cmd=cmd,
install_from=' '.join(install), install_from=" ".join(install),
spatialite_extras=SPATIALITE_DOCKERFILE_EXTRAS if spatialite else '', spatialite_extras=SPATIALITE_DOCKERFILE_EXTRAS if spatialite else "",
).strip() ).strip()
@ -324,7 +331,7 @@ def temporary_docker_directory(
install, install,
spatialite, spatialite,
version_note, version_note,
extra_metadata=None extra_metadata=None,
): ):
extra_metadata = extra_metadata or {} extra_metadata = extra_metadata or {}
tmp = tempfile.TemporaryDirectory() tmp = tempfile.TemporaryDirectory()
@ -332,10 +339,7 @@ def temporary_docker_directory(
datasette_dir = os.path.join(tmp.name, name) datasette_dir = os.path.join(tmp.name, name)
os.mkdir(datasette_dir) os.mkdir(datasette_dir)
saved_cwd = os.getcwd() saved_cwd = os.getcwd()
file_paths = [ file_paths = [os.path.join(saved_cwd, file_path) for file_path in files]
os.path.join(saved_cwd, file_path)
for file_path in files
]
file_names = [os.path.split(f)[-1] for f in files] file_names = [os.path.split(f)[-1] for f in files]
if metadata: if metadata:
metadata_content = json.load(metadata) metadata_content = json.load(metadata)
@ -347,7 +351,7 @@ def temporary_docker_directory(
try: try:
dockerfile = make_dockerfile( dockerfile = make_dockerfile(
file_names, file_names,
metadata_content and 'metadata.json', metadata_content and "metadata.json",
extra_options, extra_options,
branch, branch,
template_dir, template_dir,
@ -359,24 +363,23 @@ def temporary_docker_directory(
) )
os.chdir(datasette_dir) os.chdir(datasette_dir)
if metadata_content: if metadata_content:
open('metadata.json', 'w').write(json.dumps(metadata_content, indent=2)) open("metadata.json", "w").write(json.dumps(metadata_content, indent=2))
open('Dockerfile', 'w').write(dockerfile) open("Dockerfile", "w").write(dockerfile)
for path, filename in zip(file_paths, file_names): for path, filename in zip(file_paths, file_names):
link_or_copy(path, os.path.join(datasette_dir, filename)) link_or_copy(path, os.path.join(datasette_dir, filename))
if template_dir: if template_dir:
link_or_copy_directory( link_or_copy_directory(
os.path.join(saved_cwd, template_dir), os.path.join(saved_cwd, template_dir),
os.path.join(datasette_dir, 'templates') os.path.join(datasette_dir, "templates"),
) )
if plugins_dir: if plugins_dir:
link_or_copy_directory( link_or_copy_directory(
os.path.join(saved_cwd, plugins_dir), os.path.join(saved_cwd, plugins_dir),
os.path.join(datasette_dir, 'plugins') os.path.join(datasette_dir, "plugins"),
) )
for mount_point, path in static: for mount_point, path in static:
link_or_copy_directory( link_or_copy_directory(
os.path.join(saved_cwd, path), os.path.join(saved_cwd, path), os.path.join(datasette_dir, mount_point)
os.path.join(datasette_dir, mount_point)
) )
yield datasette_dir yield datasette_dir
finally: finally:
@ -396,7 +399,7 @@ def temporary_heroku_directory(
static, static,
install, install,
version_note, version_note,
extra_metadata=None extra_metadata=None,
): ):
# FIXME: lots of duplicated code from above # FIXME: lots of duplicated code from above
@ -404,10 +407,7 @@ def temporary_heroku_directory(
tmp = tempfile.TemporaryDirectory() tmp = tempfile.TemporaryDirectory()
saved_cwd = os.getcwd() saved_cwd = os.getcwd()
file_paths = [ file_paths = [os.path.join(saved_cwd, file_path) for file_path in files]
os.path.join(saved_cwd, file_path)
for file_path in files
]
file_names = [os.path.split(f)[-1] for f in files] file_names = [os.path.split(f)[-1] for f in files]
if metadata: if metadata:
@ -422,53 +422,54 @@ def temporary_heroku_directory(
os.chdir(tmp.name) os.chdir(tmp.name)
if metadata_content: if metadata_content:
open('metadata.json', 'w').write(json.dumps(metadata_content, indent=2)) open("metadata.json", "w").write(json.dumps(metadata_content, indent=2))
open('runtime.txt', 'w').write('python-3.6.7') open("runtime.txt", "w").write("python-3.6.7")
if branch: if branch:
install = ['https://github.com/simonw/datasette/archive/{branch}.zip'.format( install = [
"https://github.com/simonw/datasette/archive/{branch}.zip".format(
branch=branch branch=branch
)] + list(install) )
] + list(install)
else: else:
install = ['datasette'] + list(install) install = ["datasette"] + list(install)
open('requirements.txt', 'w').write('\n'.join(install)) open("requirements.txt", "w").write("\n".join(install))
os.mkdir('bin') os.mkdir("bin")
open('bin/post_compile', 'w').write('datasette inspect --inspect-file inspect-data.json') open("bin/post_compile", "w").write(
"datasette inspect --inspect-file inspect-data.json"
)
extras = [] extras = []
if template_dir: if template_dir:
link_or_copy_directory( link_or_copy_directory(
os.path.join(saved_cwd, template_dir), os.path.join(saved_cwd, template_dir),
os.path.join(tmp.name, 'templates') os.path.join(tmp.name, "templates"),
) )
extras.extend(['--template-dir', 'templates/']) extras.extend(["--template-dir", "templates/"])
if plugins_dir: if plugins_dir:
link_or_copy_directory( link_or_copy_directory(
os.path.join(saved_cwd, plugins_dir), os.path.join(saved_cwd, plugins_dir), os.path.join(tmp.name, "plugins")
os.path.join(tmp.name, 'plugins')
) )
extras.extend(['--plugins-dir', 'plugins/']) extras.extend(["--plugins-dir", "plugins/"])
if version_note: if version_note:
extras.extend(['--version-note', version_note]) extras.extend(["--version-note", version_note])
if metadata_content: if metadata_content:
extras.extend(['--metadata', 'metadata.json']) extras.extend(["--metadata", "metadata.json"])
if extra_options: if extra_options:
extras.extend(extra_options.split()) extras.extend(extra_options.split())
for mount_point, path in static: for mount_point, path in static:
link_or_copy_directory( link_or_copy_directory(
os.path.join(saved_cwd, path), os.path.join(saved_cwd, path), os.path.join(tmp.name, mount_point)
os.path.join(tmp.name, mount_point)
) )
extras.extend(['--static', '{}:{}'.format(mount_point, mount_point)]) extras.extend(["--static", "{}:{}".format(mount_point, mount_point)])
quoted_files = " ".join(map(shlex.quote, file_names)) quoted_files = " ".join(map(shlex.quote, file_names))
procfile_cmd = 'web: datasette serve --host 0.0.0.0 {quoted_files} --cors --port $PORT --inspect-file inspect-data.json {extras}'.format( procfile_cmd = "web: datasette serve --host 0.0.0.0 {quoted_files} --cors --port $PORT --inspect-file inspect-data.json {extras}".format(
quoted_files=quoted_files, quoted_files=quoted_files, extras=" ".join(extras)
extras=' '.join(extras),
) )
open('Procfile', 'w').write(procfile_cmd) open("Procfile", "w").write(procfile_cmd)
for path, filename in zip(file_paths, file_names): for path, filename in zip(file_paths, file_names):
link_or_copy(path, os.path.join(tmp.name, filename)) link_or_copy(path, os.path.join(tmp.name, filename))
@ -484,9 +485,7 @@ def detect_primary_keys(conn, table):
" Figure out primary keys for a table. " " Figure out primary keys for a table. "
table_info_rows = [ table_info_rows = [
row row
for row in conn.execute( for row in conn.execute('PRAGMA table_info("{}")'.format(table)).fetchall()
'PRAGMA table_info("{}")'.format(table)
).fetchall()
if row[-1] if row[-1]
] ]
table_info_rows.sort(key=lambda row: row[-1]) table_info_rows.sort(key=lambda row: row[-1])
@ -494,33 +493,26 @@ def detect_primary_keys(conn, table):
def get_outbound_foreign_keys(conn, table): def get_outbound_foreign_keys(conn, table):
infos = conn.execute( infos = conn.execute("PRAGMA foreign_key_list([{}])".format(table)).fetchall()
'PRAGMA foreign_key_list([{}])'.format(table)
).fetchall()
fks = [] fks = []
for info in infos: for info in infos:
if info is not None: if info is not None:
id, seq, table_name, from_, to_, on_update, on_delete, match = info id, seq, table_name, from_, to_, on_update, on_delete, match = info
fks.append({ fks.append(
'other_table': table_name, {"other_table": table_name, "column": from_, "other_column": to_}
'column': from_, )
'other_column': to_
})
return fks return fks
def get_all_foreign_keys(conn): def get_all_foreign_keys(conn):
tables = [r[0] for r in conn.execute('select name from sqlite_master where type="table"')] tables = [
r[0] for r in conn.execute('select name from sqlite_master where type="table"')
]
table_to_foreign_keys = {} table_to_foreign_keys = {}
for table in tables: for table in tables:
table_to_foreign_keys[table] = { table_to_foreign_keys[table] = {"incoming": [], "outgoing": []}
'incoming': [],
'outgoing': [],
}
for table in tables: for table in tables:
infos = conn.execute( infos = conn.execute("PRAGMA foreign_key_list([{}])".format(table)).fetchall()
'PRAGMA foreign_key_list([{}])'.format(table)
).fetchall()
for info in infos: for info in infos:
if info is not None: if info is not None:
id, seq, table_name, from_, to_, on_update, on_delete, match = info id, seq, table_name, from_, to_, on_update, on_delete, match = info
@ -528,22 +520,20 @@ def get_all_foreign_keys(conn):
# Weird edge case where something refers to a table that does # Weird edge case where something refers to a table that does
# not actually exist # not actually exist
continue continue
table_to_foreign_keys[table_name]['incoming'].append({ table_to_foreign_keys[table_name]["incoming"].append(
'other_table': table, {"other_table": table, "column": to_, "other_column": from_}
'column': to_, )
'other_column': from_ table_to_foreign_keys[table]["outgoing"].append(
}) {"other_table": table_name, "column": from_, "other_column": to_}
table_to_foreign_keys[table]['outgoing'].append({ )
'other_table': table_name,
'column': from_,
'other_column': to_
})
return table_to_foreign_keys return table_to_foreign_keys
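For a hypothetical database where places.country_id references countries.id, the mapping returned above has this shape (illustrative values only):

    expected = {
        "countries": {
            "incoming": [{"other_table": "places", "column": "id", "other_column": "country_id"}],
            "outgoing": [],
        },
        "places": {
            "incoming": [],
            "outgoing": [{"other_table": "countries", "column": "country_id", "other_column": "id"}],
        },
    }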
def detect_spatialite(conn): def detect_spatialite(conn):
rows = conn.execute('select 1 from sqlite_master where tbl_name = "geometry_columns"').fetchall() rows = conn.execute(
'select 1 from sqlite_master where tbl_name = "geometry_columns"'
).fetchall()
return len(rows) > 0 return len(rows) > 0
@ -557,7 +547,7 @@ def detect_fts(conn, table):
def detect_fts_sql(table): def detect_fts_sql(table):
return r''' return r"""
select name from sqlite_master select name from sqlite_master
where rootpage = 0 where rootpage = 0
and ( and (
@ -567,7 +557,9 @@ def detect_fts_sql(table):
and sql like '%VIRTUAL TABLE%USING FTS%' and sql like '%VIRTUAL TABLE%USING FTS%'
) )
) )
'''.format(table=table) """.format(
table=table
)
def detect_json1(conn=None): def detect_json1(conn=None):
@ -589,51 +581,53 @@ def table_columns(conn, table):
] ]
filter_column_re = re.compile(r'^_filter_column_\d+$') filter_column_re = re.compile(r"^_filter_column_\d+$")
def filters_should_redirect(special_args): def filters_should_redirect(special_args):
redirect_params = [] redirect_params = []
# Handle _filter_column=foo&_filter_op=exact&_filter_value=... # Handle _filter_column=foo&_filter_op=exact&_filter_value=...
filter_column = special_args.get('_filter_column') filter_column = special_args.get("_filter_column")
filter_op = special_args.get('_filter_op') or '' filter_op = special_args.get("_filter_op") or ""
filter_value = special_args.get('_filter_value') or '' filter_value = special_args.get("_filter_value") or ""
if '__' in filter_op: if "__" in filter_op:
filter_op, filter_value = filter_op.split('__', 1) filter_op, filter_value = filter_op.split("__", 1)
if filter_column: if filter_column:
redirect_params.append( redirect_params.append(
('{}__{}'.format(filter_column, filter_op), filter_value) ("{}__{}".format(filter_column, filter_op), filter_value)
) )
for key in ('_filter_column', '_filter_op', '_filter_value'): for key in ("_filter_column", "_filter_op", "_filter_value"):
if key in special_args: if key in special_args:
redirect_params.append((key, None)) redirect_params.append((key, None))
# Now handle _filter_column_1=name&_filter_op_1=contains&_filter_value_1=hello # Now handle _filter_column_1=name&_filter_op_1=contains&_filter_value_1=hello
column_keys = [k for k in special_args if filter_column_re.match(k)] column_keys = [k for k in special_args if filter_column_re.match(k)]
for column_key in column_keys: for column_key in column_keys:
number = column_key.split('_')[-1] number = column_key.split("_")[-1]
column = special_args[column_key] column = special_args[column_key]
op = special_args.get('_filter_op_{}'.format(number)) or 'exact' op = special_args.get("_filter_op_{}".format(number)) or "exact"
value = special_args.get('_filter_value_{}'.format(number)) or '' value = special_args.get("_filter_value_{}".format(number)) or ""
if '__' in op: if "__" in op:
op, value = op.split('__', 1) op, value = op.split("__", 1)
if column: if column:
redirect_params.append(('{}__{}'.format(column, op), value)) redirect_params.append(("{}__{}".format(column, op), value))
redirect_params.extend([ redirect_params.extend(
('_filter_column_{}'.format(number), None), [
('_filter_op_{}'.format(number), None), ("_filter_column_{}".format(number), None),
('_filter_value_{}'.format(number), None), ("_filter_op_{}".format(number), None),
]) ("_filter_value_{}".format(number), None),
]
)
return redirect_params return redirect_params
whitespace_re = re.compile(r'\s') whitespace_re = re.compile(r"\s")
def is_url(value): def is_url(value):
"Must start with http:// or https:// and contain JUST a URL" "Must start with http:// or https:// and contain JUST a URL"
if not isinstance(value, str): if not isinstance(value, str):
return False return False
if not value.startswith('http://') and not value.startswith('https://'): if not value.startswith("http://") and not value.startswith("https://"):
return False return False
# Any whitespace at all is invalid # Any whitespace at all is invalid
if whitespace_re.search(value): if whitespace_re.search(value):
@ -641,8 +635,8 @@ def is_url(value):
return True return True
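is_url() accepts only bare http(s) URLs with no whitespace; a few hypothetical inputs (assuming datasette is importable):

    from datasette.utils import is_url

    print(is_url("https://example.com/"))    # True
    print(is_url("example.com"))             # False - no scheme
    print(is_url("http://example.com/a b"))  # False - contains whitespace
    print(is_url(42))                        # False - not a string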
css_class_re = re.compile(r'^[a-zA-Z]+[_a-zA-Z0-9-]*$') css_class_re = re.compile(r"^[a-zA-Z]+[_a-zA-Z0-9-]*$")
css_invalid_chars_re = re.compile(r'[^a-zA-Z0-9_\-]') css_invalid_chars_re = re.compile(r"[^a-zA-Z0-9_\-]")
def to_css_class(s): def to_css_class(s):
@ -656,16 +650,16 @@ def to_css_class(s):
""" """
if css_class_re.match(s): if css_class_re.match(s):
return s return s
md5_suffix = hashlib.md5(s.encode('utf8')).hexdigest()[:6] md5_suffix = hashlib.md5(s.encode("utf8")).hexdigest()[:6]
# Strip leading _, - # Strip leading _, -
s = s.lstrip('_').lstrip('-') s = s.lstrip("_").lstrip("-")
# Replace any whitespace with hyphens # Replace any whitespace with hyphens
s = '-'.join(s.split()) s = "-".join(s.split())
# Remove any remaining invalid characters # Remove any remaining invalid characters
s = css_invalid_chars_re.sub('', s) s = css_invalid_chars_re.sub("", s)
# Attach the md5 suffix # Attach the md5 suffix
bits = [b for b in (s, md5_suffix) if b] bits = [b for b in (s, md5_suffix) if b]
return '-'.join(bits) return "-".join(bits)
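A sketch of to_css_class() behaviour (assuming datasette is importable); the six-character suffix is the start of the md5 of the original string, not spelled out here:

    from datasette.utils import to_css_class

    print(to_css_class("places"))    # places - already a valid CSS class, returned unchanged
    print(to_css_class("my table"))  # my-table-xxxxxx, where xxxxxx is the md5 prefix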
def link_or_copy(src, dst): def link_or_copy(src, dst):
@ -689,8 +683,8 @@ def module_from_path(path, name):
# Adapted from http://sayspy.blogspot.com/2011/07/how-to-import-module-from-just-file.html # Adapted from http://sayspy.blogspot.com/2011/07/how-to-import-module-from-just-file.html
mod = imp.new_module(name) mod = imp.new_module(name)
mod.__file__ = path mod.__file__ = path
with open(path, 'r') as file: with open(path, "r") as file:
code = compile(file.read(), path, 'exec', dont_inherit=True) code = compile(file.read(), path, "exec", dont_inherit=True)
exec(code, mod.__dict__) exec(code, mod.__dict__)
return mod return mod
@ -702,34 +696,36 @@ def get_plugins(pm):
static_path = None static_path = None
templates_path = None templates_path = None
try: try:
if pkg_resources.resource_isdir(plugin.__name__, 'static'): if pkg_resources.resource_isdir(plugin.__name__, "static"):
static_path = pkg_resources.resource_filename(plugin.__name__, 'static') static_path = pkg_resources.resource_filename(plugin.__name__, "static")
if pkg_resources.resource_isdir(plugin.__name__, 'templates'): if pkg_resources.resource_isdir(plugin.__name__, "templates"):
templates_path = pkg_resources.resource_filename(plugin.__name__, 'templates') templates_path = pkg_resources.resource_filename(
plugin.__name__, "templates"
)
except (KeyError, ImportError): except (KeyError, ImportError):
# Caused by --plugins_dir= plugins - KeyError/ImportError thrown in Py3.5 # Caused by --plugins_dir= plugins - KeyError/ImportError thrown in Py3.5
pass pass
plugin_info = { plugin_info = {
'name': plugin.__name__, "name": plugin.__name__,
'static_path': static_path, "static_path": static_path,
'templates_path': templates_path, "templates_path": templates_path,
} }
distinfo = plugin_to_distinfo.get(plugin) distinfo = plugin_to_distinfo.get(plugin)
if distinfo: if distinfo:
plugin_info['version'] = distinfo.version plugin_info["version"] = distinfo.version
plugins.append(plugin_info) plugins.append(plugin_info)
return plugins return plugins
async def resolve_table_and_format(table_and_format, table_exists, allowed_formats=[]): async def resolve_table_and_format(table_and_format, table_exists, allowed_formats=[]):
if '.' in table_and_format: if "." in table_and_format:
# Check if a table exists with this exact name # Check if a table exists with this exact name
it_exists = await table_exists(table_and_format) it_exists = await table_exists(table_and_format)
if it_exists: if it_exists:
return table_and_format, None return table_and_format, None
# Check if table ends with a known format # Check if table ends with a known format
formats = list(allowed_formats) + ['csv', 'jsono'] formats = list(allowed_formats) + ["csv", "jsono"]
for _format in formats: for _format in formats:
if table_and_format.endswith(".{}".format(_format)): if table_and_format.endswith(".{}".format(_format)):
table = table_and_format[: -(len(_format) + 1)] table = table_and_format[: -(len(_format) + 1)]
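A sketch of how that table-or-format resolution behaves, with a hypothetical table_exists callback that only knows about a "places" table (assumes datasette is importable):

    import asyncio

    from datasette.utils import resolve_table_and_format

    async def table_exists(name):
        return name == "places"

    async def demo():
        print(await resolve_table_and_format("places.csv", table_exists))  # ('places', 'csv')
        print(await resolve_table_and_format("places", table_exists))      # ('places', None)

    asyncio.run(demo())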
@ -747,9 +743,7 @@ def path_with_format(request, format, extra_qs=None):
if qs: if qs:
extra = urllib.parse.urlencode(sorted(qs.items())) extra = urllib.parse.urlencode(sorted(qs.items()))
if request.query_string: if request.query_string:
path = "{}?{}&{}".format( path = "{}?{}&{}".format(path, request.query_string, extra)
path, request.query_string, extra
)
else: else:
path = "{}?{}".format(path, extra) path = "{}?{}".format(path, extra)
elif request.query_string: elif request.query_string:
@ -777,9 +771,9 @@ class CustomRow(OrderedDict):
def value_as_boolean(value): def value_as_boolean(value):
if value.lower() not in ('on', 'off', 'true', 'false', '1', '0'): if value.lower() not in ("on", "off", "true", "false", "1", "0"):
raise ValueAsBooleanError raise ValueAsBooleanError
return value.lower() in ('on', 'true', '1') return value.lower() in ("on", "true", "1")
class ValueAsBooleanError(ValueError): class ValueAsBooleanError(ValueError):
@ -799,9 +793,9 @@ class LimitedWriter:
def write(self, bytes): def write(self, bytes):
self.bytes_count += len(bytes) self.bytes_count += len(bytes)
if self.limit_bytes and (self.bytes_count > self.limit_bytes): if self.limit_bytes and (self.bytes_count > self.limit_bytes):
raise WriteLimitExceeded("CSV contains more than {} bytes".format( raise WriteLimitExceeded(
self.limit_bytes "CSV contains more than {} bytes".format(self.limit_bytes)
)) )
self.writer.write(bytes) self.writer.write(bytes)
@ -810,10 +804,7 @@ _infinities = {float("inf"), float("-inf")}
def remove_infinites(row): def remove_infinites(row):
if any((c in _infinities) if isinstance(c, float) else 0 for c in row): if any((c in _infinities) if isinstance(c, float) else 0 for c in row):
return [ return [None if (isinstance(c, float) and c in _infinities) else c for c in row]
None if (isinstance(c, float) and c in _infinities) else c
for c in row
]
return row return row
@ -824,7 +815,8 @@ class StaticMount(click.ParamType):
if ":" not in value: if ":" not in value:
self.fail( self.fail(
'"{}" should be of format mountpoint:directory'.format(value), '"{}" should be of format mountpoint:directory'.format(value),
param, ctx param,
ctx,
) )
path, dirpath = value.split(":") path, dirpath = value.split(":")
if not os.path.exists(dirpath) or not os.path.isdir(dirpath): if not os.path.exists(dirpath) or not os.path.isdir(dirpath):


@ -1,6 +1,6 @@
from ._version import get_versions from ._version import get_versions
__version__ = get_versions()['version'] __version__ = get_versions()["version"]
del get_versions del get_versions
__version_info__ = tuple(__version__.split(".")) __version_info__ = tuple(__version__.split("."))


@ -33,8 +33,15 @@ HASH_LENGTH = 7
class DatasetteError(Exception): class DatasetteError(Exception):
def __init__(
def __init__(self, message, title=None, error_dict=None, status=500, template=None, messagge_is_html=False): self,
message,
title=None,
error_dict=None,
status=500,
template=None,
messagge_is_html=False,
):
self.message = message self.message = message
self.title = title self.title = title
self.error_dict = error_dict or {} self.error_dict = error_dict or {}
@ -43,18 +50,19 @@ class DatasetteError(Exception):
class RenderMixin(HTTPMethodView): class RenderMixin(HTTPMethodView):
def _asset_urls(self, key, template, context): def _asset_urls(self, key, template, context):
# Flatten list-of-lists from plugins: # Flatten list-of-lists from plugins:
seen_urls = set() seen_urls = set()
for url_or_dict in itertools.chain( for url_or_dict in itertools.chain(
itertools.chain.from_iterable(getattr(pm.hook, key)( itertools.chain.from_iterable(
getattr(pm.hook, key)(
template=template.name, template=template.name,
database=context.get("database"), database=context.get("database"),
table=context.get("table"), table=context.get("table"),
datasette=self.ds datasette=self.ds,
)), )
(self.ds.metadata(key) or []) ),
(self.ds.metadata(key) or []),
): ):
if isinstance(url_or_dict, dict): if isinstance(url_or_dict, dict):
url = url_or_dict["url"] url = url_or_dict["url"]
@ -73,14 +81,12 @@ class RenderMixin(HTTPMethodView):
def database_url(self, database): def database_url(self, database):
db = self.ds.databases[database] db = self.ds.databases[database]
if self.ds.config("hash_urls") and db.hash: if self.ds.config("hash_urls") and db.hash:
return "/{}-{}".format( return "/{}-{}".format(database, db.hash[:HASH_LENGTH])
database, db.hash[:HASH_LENGTH]
)
else: else:
return "/{}".format(database) return "/{}".format(database)
def database_color(self, database): def database_color(self, database):
return 'ff0000' return "ff0000"
def render(self, templates, **context): def render(self, templates, **context):
template = self.ds.jinja_env.select_template(templates) template = self.ds.jinja_env.select_template(templates)
@ -95,7 +101,7 @@ class RenderMixin(HTTPMethodView):
database=context.get("database"), database=context.get("database"),
table=context.get("table"), table=context.get("table"),
view_name=self.name, view_name=self.name,
datasette=self.ds datasette=self.ds,
): ):
body_scripts.append(jinja2.Markup(script)) body_scripts.append(jinja2.Markup(script))
return response.html( return response.html(
@ -116,14 +122,14 @@ class RenderMixin(HTTPMethodView):
"format_bytes": format_bytes, "format_bytes": format_bytes,
"database_url": self.database_url, "database_url": self.database_url,
"database_color": self.database_color, "database_color": self.database_color,
} },
} }
) )
) )
class BaseView(RenderMixin): class BaseView(RenderMixin):
name = '' name = ""
re_named_parameter = re.compile(":([a-zA-Z0-9_]+)") re_named_parameter = re.compile(":([a-zA-Z0-9_]+)")
def __init__(self, datasette): def __init__(self, datasette):
@ -171,32 +177,30 @@ class BaseView(RenderMixin):
expected = "000" expected = "000"
if db.hash is not None: if db.hash is not None:
expected = db.hash[:HASH_LENGTH] expected = db.hash[:HASH_LENGTH]
correct_hash_provided = (expected == hash) correct_hash_provided = expected == hash
if not correct_hash_provided: if not correct_hash_provided:
if "table_and_format" in kwargs: if "table_and_format" in kwargs:
async def async_table_exists(t): async def async_table_exists(t):
return await self.ds.table_exists(name, t) return await self.ds.table_exists(name, t)
table, _format = await resolve_table_and_format( table, _format = await resolve_table_and_format(
table_and_format=urllib.parse.unquote_plus( table_and_format=urllib.parse.unquote_plus(
kwargs["table_and_format"] kwargs["table_and_format"]
), ),
table_exists=async_table_exists, table_exists=async_table_exists,
allowed_formats=self.ds.renderers.keys() allowed_formats=self.ds.renderers.keys(),
) )
kwargs["table"] = table kwargs["table"] = table
if _format: if _format:
kwargs["as_format"] = ".{}".format(_format) kwargs["as_format"] = ".{}".format(_format)
elif "table" in kwargs: elif "table" in kwargs:
kwargs["table"] = urllib.parse.unquote_plus( kwargs["table"] = urllib.parse.unquote_plus(kwargs["table"])
kwargs["table"]
)
should_redirect = "/{}-{}".format(name, expected) should_redirect = "/{}-{}".format(name, expected)
if "table" in kwargs: if "table" in kwargs:
should_redirect += "/" + urllib.parse.quote_plus( should_redirect += "/" + urllib.parse.quote_plus(kwargs["table"])
kwargs["table"]
)
if "pk_path" in kwargs: if "pk_path" in kwargs:
should_redirect += "/" + kwargs["pk_path"] should_redirect += "/" + kwargs["pk_path"]
if "as_format" in kwargs: if "as_format" in kwargs:
@ -219,7 +223,9 @@ class BaseView(RenderMixin):
if should_redirect: if should_redirect:
return self.redirect(request, should_redirect, remove_args={"_hash"}) return self.redirect(request, should_redirect, remove_args={"_hash"})
return await self.view_get(request, database, hash, correct_hash_provided, **kwargs) return await self.view_get(
request, database, hash, correct_hash_provided, **kwargs
)
async def as_csv(self, request, database, hash, **kwargs): async def as_csv(self, request, database, hash, **kwargs):
stream = request.args.get("_stream") stream = request.args.get("_stream")
@ -228,9 +234,7 @@ class BaseView(RenderMixin):
if not self.ds.config("allow_csv_stream"): if not self.ds.config("allow_csv_stream"):
raise DatasetteError("CSV streaming is disabled", status=400) raise DatasetteError("CSV streaming is disabled", status=400)
if request.args.get("_next"): if request.args.get("_next"):
raise DatasetteError( raise DatasetteError("_next not allowed for CSV streaming", status=400)
"_next not allowed for CSV streaming", status=400
)
kwargs["_size"] = "max" kwargs["_size"] = "max"
# Fetch the first page # Fetch the first page
try: try:
@ -271,9 +275,7 @@ class BaseView(RenderMixin):
if next: if next:
kwargs["_next"] = next kwargs["_next"] = next
if not first: if not first:
data, _, _ = await self.data( data, _, _ = await self.data(request, database, hash, **kwargs)
request, database, hash, **kwargs
)
if first: if first:
writer.writerow(headings) writer.writerow(headings)
first = False first = False
@ -293,7 +295,7 @@ class BaseView(RenderMixin):
new_row.append(cell) new_row.append(cell)
writer.writerow(new_row) writer.writerow(new_row)
except Exception as e: except Exception as e:
print('caught this', e) print("caught this", e)
r.write(str(e)) r.write(str(e))
return return
@ -304,15 +306,11 @@ class BaseView(RenderMixin):
if request.args.get("_dl", None): if request.args.get("_dl", None):
content_type = "text/csv; charset=utf-8" content_type = "text/csv; charset=utf-8"
disposition = 'attachment; filename="{}.csv"'.format( disposition = 'attachment; filename="{}.csv"'.format(
kwargs.get('table', database) kwargs.get("table", database)
) )
headers["Content-Disposition"] = disposition headers["Content-Disposition"] = disposition
return response.stream( return response.stream(stream_fn, headers=headers, content_type=content_type)
stream_fn,
headers=headers,
content_type=content_type
)
async def get_format(self, request, database, args): async def get_format(self, request, database, args):
""" Determine the format of the response from the request, from URL """ Determine the format of the response from the request, from URL
@ -325,22 +323,20 @@ class BaseView(RenderMixin):
if not _format: if not _format:
_format = (args.pop("as_format", None) or "").lstrip(".") _format = (args.pop("as_format", None) or "").lstrip(".")
if "table_and_format" in args: if "table_and_format" in args:
async def async_table_exists(t): async def async_table_exists(t):
return await self.ds.table_exists(database, t) return await self.ds.table_exists(database, t)
table, _ext_format = await resolve_table_and_format( table, _ext_format = await resolve_table_and_format(
table_and_format=urllib.parse.unquote_plus( table_and_format=urllib.parse.unquote_plus(args["table_and_format"]),
args["table_and_format"]
),
table_exists=async_table_exists, table_exists=async_table_exists,
allowed_formats=self.ds.renderers.keys() allowed_formats=self.ds.renderers.keys(),
) )
_format = _format or _ext_format _format = _format or _ext_format
args["table"] = table args["table"] = table
del args["table_and_format"] del args["table_and_format"]
elif "table" in args: elif "table" in args:
args["table"] = urllib.parse.unquote_plus( args["table"] = urllib.parse.unquote_plus(args["table"])
args["table"]
)
return _format, args return _format, args
async def view_get(self, request, database, hash, correct_hash_provided, **kwargs): async def view_get(self, request, database, hash, correct_hash_provided, **kwargs):
@ -351,7 +347,7 @@ class BaseView(RenderMixin):
if _format is None: if _format is None:
# HTML views default to expanding all foreign key labels # HTML views default to expanding all foreign key labels
kwargs['default_labels'] = True kwargs["default_labels"] = True
extra_template_data = {} extra_template_data = {}
start = time.time() start = time.time()
@ -367,11 +363,16 @@ class BaseView(RenderMixin):
else: else:
data, extra_template_data, templates = response_or_template_contexts data, extra_template_data, templates = response_or_template_contexts
except InterruptedError: except InterruptedError:
raise DatasetteError(""" raise DatasetteError(
"""
SQL query took too long. The time limit is controlled by the SQL query took too long. The time limit is controlled by the
<a href="https://datasette.readthedocs.io/en/stable/config.html#sql-time-limit-ms">sql_time_limit_ms</a> <a href="https://datasette.readthedocs.io/en/stable/config.html#sql-time-limit-ms">sql_time_limit_ms</a>
configuration option. configuration option.
""", title="SQL Interrupted", status=400, messagge_is_html=True) """,
title="SQL Interrupted",
status=400,
messagge_is_html=True,
)
except (sqlite3.OperationalError, InvalidSql) as e: except (sqlite3.OperationalError, InvalidSql) as e:
raise DatasetteError(str(e), title="Invalid SQL", status=400) raise DatasetteError(str(e), title="Invalid SQL", status=400)
@ -408,14 +409,14 @@ class BaseView(RenderMixin):
raise NotFound("No data") raise NotFound("No data")
response_args = { response_args = {
'content_type': result.get('content_type', 'text/plain'), "content_type": result.get("content_type", "text/plain"),
'status': result.get('status_code', 200) "status": result.get("status_code", 200),
} }
if type(result.get('body')) == bytes: if type(result.get("body")) == bytes:
response_args['body_bytes'] = result.get('body') response_args["body_bytes"] = result.get("body")
else: else:
response_args['body'] = result.get('body') response_args["body"] = result.get("body")
r = response.HTTPResponse(**response_args) r = response.HTTPResponse(**response_args)
else: else:
@ -431,14 +432,12 @@ class BaseView(RenderMixin):
url_labels_extra = {"_labels": "on"} url_labels_extra = {"_labels": "on"}
renderers = { renderers = {
key: path_with_format(request, key, {**url_labels_extra}) for key in self.ds.renderers.keys() key: path_with_format(request, key, {**url_labels_extra})
} for key in self.ds.renderers.keys()
url_csv_args = {
"_size": "max",
**url_labels_extra
} }
url_csv_args = {"_size": "max", **url_labels_extra}
url_csv = path_with_format(request, "csv", url_csv_args) url_csv = path_with_format(request, "csv", url_csv_args)
url_csv_path = url_csv.split('?')[0] url_csv_path = url_csv.split("?")[0]
context = { context = {
**data, **data,
**extras, **extras,
@ -450,10 +449,11 @@ class BaseView(RenderMixin):
(key, value) (key, value)
for key, value in urllib.parse.parse_qsl(request.query_string) for key, value in urllib.parse.parse_qsl(request.query_string)
if key not in ("_labels", "_facet", "_size") if key not in ("_labels", "_facet", "_size")
] + [("_size", "max")], ]
+ [("_size", "max")],
"datasette_version": __version__, "datasette_version": __version__,
"config": self.ds.config_dict(), "config": self.ds.config_dict(),
} },
} }
if "metadata" not in context: if "metadata" not in context:
context["metadata"] = self.ds.metadata context["metadata"] = self.ds.metadata
@ -474,9 +474,9 @@ class BaseView(RenderMixin):
if self.ds.cache_headers and response.status == 200: if self.ds.cache_headers and response.status == 200:
ttl = int(ttl) ttl = int(ttl)
if ttl == 0: if ttl == 0:
ttl_header = 'no-cache' ttl_header = "no-cache"
else: else:
ttl_header = 'max-age={}'.format(ttl) ttl_header = "max-age={}".format(ttl)
response.headers["Cache-Control"] = ttl_header response.headers["Cache-Control"] = ttl_header
response.headers["Referrer-Policy"] = "no-referrer" response.headers["Referrer-Policy"] = "no-referrer"
if self.ds.cors: if self.ds.cors:
@ -484,8 +484,15 @@ class BaseView(RenderMixin):
return response return response
async def custom_sql( async def custom_sql(
self, request, database, hash, sql, editable=True, canned_query=None, self,
metadata=None, _size=None request,
database,
hash,
sql,
editable=True,
canned_query=None,
metadata=None,
_size=None,
): ):
params = request.raw_args params = request.raw_args
if "sql" in params: if "sql" in params:
@ -565,10 +572,14 @@ class BaseView(RenderMixin):
"hide_sql": "_hide_sql" in params, "hide_sql": "_hide_sql" in params,
} }
return { return (
{
"database": database, "database": database,
"rows": results.rows, "rows": results.rows,
"truncated": results.truncated, "truncated": results.truncated,
"columns": columns, "columns": columns,
"query": {"sql": sql, "params": params}, "query": {"sql": sql, "params": params},
}, extra_template, templates },
extra_template,
templates,
)


@ -15,7 +15,7 @@ from .base import HASH_LENGTH, RenderMixin
class IndexView(RenderMixin): class IndexView(RenderMixin):
name = 'index' name = "index"
def __init__(self, datasette): def __init__(self, datasette):
self.ds = datasette self.ds = datasette
@ -43,23 +43,25 @@ class IndexView(RenderMixin):
} }
hidden_tables = [t for t in tables.values() if t["hidden"]] hidden_tables = [t for t in tables.values() if t["hidden"]]
databases.append({ databases.append(
{
"name": name, "name": name,
"hash": db.hash, "hash": db.hash,
"color": db.hash[:6] if db.hash else hashlib.md5(name.encode("utf8")).hexdigest()[:6], "color": db.hash[:6]
if db.hash
else hashlib.md5(name.encode("utf8")).hexdigest()[:6],
"path": self.database_url(name), "path": self.database_url(name),
"tables_truncated": sorted( "tables_truncated": sorted(
tables.values(), key=lambda t: t["count"] or 0, reverse=True tables.values(), key=lambda t: t["count"] or 0, reverse=True
)[ )[:5],
:5
],
"tables_count": len(tables), "tables_count": len(tables),
"tables_more": len(tables) > 5, "tables_more": len(tables) > 5,
"table_rows_sum": sum((t["count"] or 0) for t in tables.values()), "table_rows_sum": sum((t["count"] or 0) for t in tables.values()),
"hidden_table_rows_sum": sum(t["count"] for t in hidden_tables), "hidden_table_rows_sum": sum(t["count"] for t in hidden_tables),
"hidden_tables_count": len(hidden_tables), "hidden_tables_count": len(hidden_tables),
"views_count": len(views), "views_count": len(views),
}) }
)
if as_format: if as_format:
headers = {} headers = {}
if self.ds.cors: if self.ds.cors:


@ -18,14 +18,8 @@ class JsonDataView(RenderMixin):
if self.ds.cors: if self.ds.cors:
headers["Access-Control-Allow-Origin"] = "*" headers["Access-Control-Allow-Origin"] = "*"
return response.HTTPResponse( return response.HTTPResponse(
json.dumps(data), json.dumps(data), content_type="application/json", headers=headers
content_type="application/json",
headers=headers
) )
else: else:
return self.render( return self.render(["show_json.html"], filename=self.filename, data=data)
["show_json.html"],
filename=self.filename,
data=data
)


@ -31,12 +31,13 @@ from datasette.utils import (
from datasette.filters import Filters from datasette.filters import Filters
from .base import BaseView, DatasetteError, ureg from .base import BaseView, DatasetteError, ureg
LINK_WITH_LABEL = '<a href="/{database}/{table}/{link_id}">{label}</a>&nbsp;<em>{id}</em>' LINK_WITH_LABEL = (
'<a href="/{database}/{table}/{link_id}">{label}</a>&nbsp;<em>{id}</em>'
)
LINK_WITH_VALUE = '<a href="/{database}/{table}/{link_id}">{id}</a>' LINK_WITH_VALUE = '<a href="/{database}/{table}/{link_id}">{id}</a>'
class RowTableShared(BaseView): class RowTableShared(BaseView):
async def sortable_columns_for_table(self, database, table, use_rowid): async def sortable_columns_for_table(self, database, table, use_rowid):
table_metadata = self.ds.table_metadata(database, table) table_metadata = self.ds.table_metadata(database, table)
if "sortable_columns" in table_metadata: if "sortable_columns" in table_metadata:
@ -51,18 +52,14 @@ class RowTableShared(BaseView):
# Returns list of (fk_dict, label_column-or-None) pairs for that table # Returns list of (fk_dict, label_column-or-None) pairs for that table
expandables = [] expandables = []
for fk in await self.ds.foreign_keys_for_table(database, table): for fk in await self.ds.foreign_keys_for_table(database, table):
label_column = await self.ds.label_column_for_table(database, fk["other_table"]) label_column = await self.ds.label_column_for_table(
database, fk["other_table"]
)
expandables.append((fk, label_column)) expandables.append((fk, label_column))
return expandables return expandables
async def display_columns_and_rows( async def display_columns_and_rows(
self, self, database, table, description, rows, link_column=False, truncate_cells=0
database,
table,
description,
rows,
link_column=False,
truncate_cells=0,
): ):
"Returns columns, rows for specified table - including fancy foreign key treatment" "Returns columns, rows for specified table - including fancy foreign key treatment"
table_metadata = self.ds.table_metadata(database, table) table_metadata = self.ds.table_metadata(database, table)
@ -121,8 +118,10 @@ class RowTableShared(BaseView):
if plugin_display_value is not None: if plugin_display_value is not None:
display_value = plugin_display_value display_value = plugin_display_value
elif isinstance(value, bytes): elif isinstance(value, bytes):
display_value = jinja2.Markup("&lt;Binary&nbsp;data:&nbsp;{}&nbsp;byte{}&gt;".format( display_value = jinja2.Markup(
len(value), "" if len(value) == 1 else "s") "&lt;Binary&nbsp;data:&nbsp;{}&nbsp;byte{}&gt;".format(
len(value), "" if len(value) == 1 else "s"
)
) )
elif isinstance(value, dict): elif isinstance(value, dict):
# It's an expanded foreign key - display link to other row # It's an expanded foreign key - display link to other row
@ -133,13 +132,15 @@ class RowTableShared(BaseView):
link_template = ( link_template = (
LINK_WITH_LABEL if (label != value) else LINK_WITH_VALUE LINK_WITH_LABEL if (label != value) else LINK_WITH_VALUE
) )
display_value = jinja2.Markup(link_template.format( display_value = jinja2.Markup(
link_template.format(
database=database, database=database,
table=urllib.parse.quote_plus(other_table), table=urllib.parse.quote_plus(other_table),
link_id=urllib.parse.quote_plus(str(value)), link_id=urllib.parse.quote_plus(str(value)),
id=str(jinja2.escape(value)), id=str(jinja2.escape(value)),
label=str(jinja2.escape(label)), label=str(jinja2.escape(label)),
)) )
)
elif value in ("", None): elif value in ("", None):
display_value = jinja2.Markup("&nbsp;") display_value = jinja2.Markup("&nbsp;")
elif is_url(str(value).strip()): elif is_url(str(value).strip()):
@ -180,9 +181,18 @@ class RowTableShared(BaseView):
class TableView(RowTableShared): class TableView(RowTableShared):
name = 'table' name = "table"
async def data(self, request, database, hash, table, default_labels=False, _next=None, _size=None): async def data(
self,
request,
database,
hash,
table,
default_labels=False,
_next=None,
_size=None,
):
canned_query = self.ds.get_canned_query(database, table) canned_query = self.ds.get_canned_query(database, table)
if canned_query is not None: if canned_query is not None:
return await self.custom_sql( return await self.custom_sql(
@ -271,12 +281,13 @@ class TableView(RowTableShared):
raise DatasetteError("_where= is not allowed", status=400) raise DatasetteError("_where= is not allowed", status=400)
else: else:
where_clauses.extend(request.args["_where"]) where_clauses.extend(request.args["_where"])
extra_wheres_for_ui = [{ extra_wheres_for_ui = [
{
"text": text, "text": text,
"remove_url": path_with_removed_args( "remove_url": path_with_removed_args(request, {"_where": text}),
request, {"_where": text} }
) for text in request.args["_where"]
} for text in request.args["_where"]] ]
# _search support: # _search support:
fts_table = special_args.get("_fts_table") fts_table = special_args.get("_fts_table")
@ -296,8 +307,7 @@ class TableView(RowTableShared):
search = search_args["_search"] search = search_args["_search"]
where_clauses.append( where_clauses.append(
"{fts_pk} in (select rowid from {fts_table} where {fts_table} match :search)".format( "{fts_pk} in (select rowid from {fts_table} where {fts_table} match :search)".format(
fts_table=escape_sqlite(fts_table), fts_table=escape_sqlite(fts_table), fts_pk=escape_sqlite(fts_pk)
fts_pk=escape_sqlite(fts_pk)
) )
) )
search_descriptions.append('search matches "{}"'.format(search)) search_descriptions.append('search matches "{}"'.format(search))
@ -306,14 +316,16 @@ class TableView(RowTableShared):
# More complex: search against specific columns # More complex: search against specific columns
for i, (key, search_text) in enumerate(search_args.items()): for i, (key, search_text) in enumerate(search_args.items()):
search_col = key.split("_search_", 1)[1] search_col = key.split("_search_", 1)[1]
if search_col not in await self.ds.table_columns(database, fts_table): if search_col not in await self.ds.table_columns(
database, fts_table
):
raise DatasetteError("Cannot search by that column", status=400) raise DatasetteError("Cannot search by that column", status=400)
where_clauses.append( where_clauses.append(
"rowid in (select rowid from {fts_table} where {search_col} match :search_{i})".format( "rowid in (select rowid from {fts_table} where {search_col} match :search_{i})".format(
fts_table=escape_sqlite(fts_table), fts_table=escape_sqlite(fts_table),
search_col=escape_sqlite(search_col), search_col=escape_sqlite(search_col),
i=i i=i,
) )
) )
search_descriptions.append( search_descriptions.append(
@ -325,7 +337,9 @@ class TableView(RowTableShared):
sortable_columns = set() sortable_columns = set()
sortable_columns = await self.sortable_columns_for_table(database, table, use_rowid) sortable_columns = await self.sortable_columns_for_table(
database, table, use_rowid
)
# Allow for custom sort order # Allow for custom sort order
sort = special_args.get("_sort") sort = special_args.get("_sort")
@ -346,9 +360,9 @@ class TableView(RowTableShared):
from_sql = "from {table_name} {where}".format( from_sql = "from {table_name} {where}".format(
table_name=escape_sqlite(table), table_name=escape_sqlite(table),
where=( where=("where {} ".format(" and ".join(where_clauses)))
"where {} ".format(" and ".join(where_clauses)) if where_clauses
) if where_clauses else "", else "",
) )
# Copy of params so we can mutate them later: # Copy of params so we can mutate them later:
from_sql_params = dict(**params) from_sql_params = dict(**params)
@ -410,7 +424,9 @@ class TableView(RowTableShared):
column=escape_sqlite(sort or sort_desc), column=escape_sqlite(sort or sort_desc),
op=">" if sort else "<", op=">" if sort else "<",
p=len(params), p=len(params),
extra_desc_only="" if sort else " or {column2} is null".format( extra_desc_only=""
if sort
else " or {column2} is null".format(
column2=escape_sqlite(sort or sort_desc) column2=escape_sqlite(sort or sort_desc)
), ),
next_clauses=" and ".join(next_by_pk_clauses), next_clauses=" and ".join(next_by_pk_clauses),
@ -470,9 +486,7 @@ class TableView(RowTableShared):
order_by=order_by, order_by=order_by,
) )
sql = "{sql_no_limit} limit {limit}{offset}".format( sql = "{sql_no_limit} limit {limit}{offset}".format(
sql_no_limit=sql_no_limit.rstrip(), sql_no_limit=sql_no_limit.rstrip(), limit=page_size + 1, offset=offset
limit=page_size + 1,
offset=offset,
) )
if request.raw_args.get("_timelimit"): if request.raw_args.get("_timelimit"):
@ -486,15 +500,17 @@ class TableView(RowTableShared):
filtered_table_rows_count = None filtered_table_rows_count = None
if count_sql: if count_sql:
try: try:
count_rows = list(await self.ds.execute( count_rows = list(
database, count_sql, from_sql_params await self.ds.execute(database, count_sql, from_sql_params)
)) )
filtered_table_rows_count = count_rows[0][0] filtered_table_rows_count = count_rows[0][0]
except InterruptedError: except InterruptedError:
pass pass
# facets support # facets support
if not self.ds.config("allow_facet") and any(arg.startswith("_facet") for arg in request.args): if not self.ds.config("allow_facet") and any(
arg.startswith("_facet") for arg in request.args
):
raise DatasetteError("_facet= is not allowed", status=400) raise DatasetteError("_facet= is not allowed", status=400)
# pylint: disable=no-member # pylint: disable=no-member
@ -505,7 +521,8 @@ class TableView(RowTableShared):
facets_timed_out = [] facets_timed_out = []
facet_instances = [] facet_instances = []
for klass in facet_classes: for klass in facet_classes:
facet_instances.append(klass( facet_instances.append(
klass(
self.ds, self.ds,
request, request,
database, database,
@ -514,10 +531,13 @@ class TableView(RowTableShared):
table=table, table=table,
metadata=table_metadata, metadata=table_metadata,
row_count=filtered_table_rows_count, row_count=filtered_table_rows_count,
)) )
)
for facet in facet_instances: for facet in facet_instances:
instance_facet_results, instance_facets_timed_out = await facet.facet_results() instance_facet_results, instance_facets_timed_out = (
await facet.facet_results()
)
facet_results.update(instance_facet_results) facet_results.update(instance_facet_results)
facets_timed_out.extend(instance_facets_timed_out) facets_timed_out.extend(instance_facets_timed_out)
@ -542,9 +562,7 @@ class TableView(RowTableShared):
columns_to_expand = request.args["_label"] columns_to_expand = request.args["_label"]
if columns_to_expand is None and all_labels: if columns_to_expand is None and all_labels:
# expand all columns with foreign keys # expand all columns with foreign keys
columns_to_expand = [ columns_to_expand = [fk["column"] for fk, _ in expandable_columns]
fk["column"] for fk, _ in expandable_columns
]
if columns_to_expand: if columns_to_expand:
expanded_labels = {} expanded_labels = {}
@ -557,9 +575,9 @@ class TableView(RowTableShared):
column_index = columns.index(column) column_index = columns.index(column)
values = [row[column_index] for row in rows] values = [row[column_index] for row in rows]
# Expand them # Expand them
expanded_labels.update(await self.ds.expand_foreign_keys( expanded_labels.update(
database, table, column, values await self.ds.expand_foreign_keys(database, table, column, values)
)) )
if expanded_labels: if expanded_labels:
# Rewrite the rows # Rewrite the rows
new_rows = [] new_rows = []
@ -569,8 +587,8 @@ class TableView(RowTableShared):
value = row[column] value = row[column]
if (column, value) in expanded_labels: if (column, value) in expanded_labels:
new_row[column] = { new_row[column] = {
'value': value, "value": value,
'label': expanded_labels[(column, value)] "label": expanded_labels[(column, value)],
} }
else: else:
new_row[column] = value new_row[column] = value
@ -608,7 +626,11 @@ class TableView(RowTableShared):
# Detect suggested facets # Detect suggested facets
suggested_facets = [] suggested_facets = []
if self.ds.config("suggest_facets") and self.ds.config("allow_facet") and not _next: if (
self.ds.config("suggest_facets")
and self.ds.config("allow_facet")
and not _next
):
for facet in facet_instances: for facet in facet_instances:
# TODO: ensure facet is not suggested if it is already active # TODO: ensure facet is not suggested if it is already active
# used to use 'if facet_column in facets' for this # used to use 'if facet_column in facets' for this
@ -634,10 +656,11 @@ class TableView(RowTableShared):
link_column=not is_view, link_column=not is_view,
truncate_cells=self.ds.config("truncate_cells_html"), truncate_cells=self.ds.config("truncate_cells_html"),
) )
metadata = (self.ds.metadata("databases") or {}).get(database, {}).get( metadata = (
"tables", {} (self.ds.metadata("databases") or {})
).get( .get(database, {})
table, {} .get("tables", {})
.get(table, {})
) )
self.ds.update_with_inherited_metadata(metadata) self.ds.update_with_inherited_metadata(metadata)
form_hidden_args = [] form_hidden_args = []
@ -656,7 +679,7 @@ class TableView(RowTableShared):
"sorted_facet_results": sorted( "sorted_facet_results": sorted(
facet_results.values(), facet_results.values(),
key=lambda f: (len(f["results"]), f["name"]), key=lambda f: (len(f["results"]), f["name"]),
reverse=True reverse=True,
), ),
"extra_wheres_for_ui": extra_wheres_for_ui, "extra_wheres_for_ui": extra_wheres_for_ui,
"form_hidden_args": form_hidden_args, "form_hidden_args": form_hidden_args,
@ -682,7 +705,8 @@ class TableView(RowTableShared):
"table_definition": await self.ds.get_table_definition(database, table), "table_definition": await self.ds.get_table_definition(database, table),
} }
return { return (
{
"database": database, "database": database,
"table": table, "table": table,
"is_view": is_view, "is_view": is_view,
@ -700,14 +724,17 @@ class TableView(RowTableShared):
"suggested_facets": suggested_facets, "suggested_facets": suggested_facets,
"next": next_value and str(next_value) or None, "next": next_value and str(next_value) or None,
"next_url": next_url, "next_url": next_url,
}, extra_template, ( },
extra_template,
(
"table-{}-{}.html".format(to_css_class(database), to_css_class(table)), "table-{}-{}.html".format(to_css_class(database), to_css_class(table)),
"table.html", "table.html",
),
) )
class RowView(RowTableShared): class RowView(RowTableShared):
name = 'row' name = "row"
async def data(self, request, database, hash, table, pk_path, default_labels=False): async def data(self, request, database, hash, table, pk_path, default_labels=False):
pk_values = urlsafe_components(pk_path) pk_values = urlsafe_components(pk_path)
@ -720,15 +747,13 @@ class RowView(RowTableShared):
select = "rowid, *" select = "rowid, *"
pks = ["rowid"] pks = ["rowid"]
wheres = ['"{}"=:p{}'.format(pk, i) for i, pk in enumerate(pks)] wheres = ['"{}"=:p{}'.format(pk, i) for i, pk in enumerate(pks)]
sql = 'select {} from {} where {}'.format( sql = "select {} from {} where {}".format(
select, escape_sqlite(table), " AND ".join(wheres) select, escape_sqlite(table), " AND ".join(wheres)
) )
params = {} params = {}
for i, pk_value in enumerate(pk_values): for i, pk_value in enumerate(pk_values):
params["p{}".format(i)] = pk_value params["p{}".format(i)] = pk_value
results = await self.ds.execute( results = await self.ds.execute(database, sql, params, truncate=True)
database, sql, params, truncate=True
)
columns = [r[0] for r in results.description] columns = [r[0] for r in results.description]
rows = list(results.rows) rows = list(results.rows)
if not rows: if not rows:
@ -760,13 +785,10 @@ class RowView(RowTableShared):
), ),
"_rows_and_columns.html", "_rows_and_columns.html",
], ],
"metadata": ( "metadata": (self.ds.metadata("databases") or {})
self.ds.metadata("databases") or {} .get(database, {})
).get(database, {}).get( .get("tables", {})
"tables", {} .get(table, {}),
).get(
table, {}
),
} }
data = { data = {
@ -784,8 +806,13 @@ class RowView(RowTableShared):
database, table, pk_values database, table, pk_values
) )
return data, template_data, ( return (
"row-{}-{}.html".format(to_css_class(database), to_css_class(table)), "row.html" data,
template_data,
(
"row-{}-{}.html".format(to_css_class(database), to_css_class(table)),
"row.html",
),
) )
async def foreign_key_tables(self, database, table, pk_values): async def foreign_key_tables(self, database, table, pk_values):
@ -801,7 +828,7 @@ class RowView(RowTableShared):
sql = "select " + ", ".join( sql = "select " + ", ".join(
[ [
'(select count(*) from {table} where {column}=:id)'.format( "(select count(*) from {table} where {column}=:id)".format(
table=escape_sqlite(fk["other_table"]), table=escape_sqlite(fk["other_table"]),
column=escape_sqlite(fk["other_column"]), column=escape_sqlite(fk["other_column"]),
) )
@ -822,8 +849,8 @@ class RowView(RowTableShared):
) )
foreign_key_tables = [] foreign_key_tables = []
for fk in foreign_keys: for fk in foreign_keys:
count = foreign_table_counts.get( count = (
(fk["other_table"], fk["other_column"]) foreign_table_counts.get((fk["other_table"], fk["other_column"])) or 0
) or 0 )
foreign_key_tables.append({**fk, **{"count": count}}) foreign_key_tables.append({**fk, **{"count": count}})
return foreign_key_tables return foreign_key_tables


@ -1,72 +1,78 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
import os import os
import sys
import versioneer import versioneer
def get_long_description(): def get_long_description():
with open(os.path.join( with open(
os.path.dirname(os.path.abspath(__file__)), 'README.md' os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md"),
), encoding='utf8') as fp: encoding="utf8",
) as fp:
return fp.read() return fp.read()
def get_version(): def get_version():
path = os.path.join( path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'datasette', 'version.py' os.path.dirname(os.path.abspath(__file__)), "datasette", "version.py"
) )
g = {} g = {}
exec(open(path).read(), g) exec(open(path).read(), g)
return g['__version__'] return g["__version__"]
# Only install black on Python 3.6 or higher
maybe_black = []
if sys.version_info > (3, 6):
maybe_black = ["black"]
setup( setup(
name='datasette', name="datasette",
version=versioneer.get_version(), version=versioneer.get_version(),
cmdclass=versioneer.get_cmdclass(), cmdclass=versioneer.get_cmdclass(),
description='An instant JSON API for your SQLite databases', description="An instant JSON API for your SQLite databases",
long_description=get_long_description(), long_description=get_long_description(),
long_description_content_type='text/markdown', long_description_content_type="text/markdown",
author='Simon Willison', author="Simon Willison",
license='Apache License, Version 2.0', license="Apache License, Version 2.0",
url='https://github.com/simonw/datasette', url="https://github.com/simonw/datasette",
packages=find_packages(), packages=find_packages(),
package_data={'datasette': ['templates/*.html']}, package_data={"datasette": ["templates/*.html"]},
include_package_data=True, include_package_data=True,
install_requires=[ install_requires=[
'click>=6.7', "click>=6.7",
'click-default-group==1.2', "click-default-group==1.2",
'Sanic==0.7.0', "Sanic==0.7.0",
'Jinja2==2.10.1', "Jinja2==2.10.1",
'hupper==1.0', "hupper==1.0",
'pint==0.8.1', "pint==0.8.1",
'pluggy>=0.7.1', "pluggy>=0.7.1",
], ],
entry_points=''' entry_points="""
[console_scripts] [console_scripts]
datasette=datasette.cli:cli datasette=datasette.cli:cli
''', """,
setup_requires=['pytest-runner'], setup_requires=["pytest-runner"],
extras_require={ extras_require={
'test': [ "test": [
'pytest==4.0.2', "pytest==4.0.2",
'pytest-asyncio==0.10.0', "pytest-asyncio==0.10.0",
'aiohttp==3.5.3', "aiohttp==3.5.3",
'beautifulsoup4==4.6.1', "beautifulsoup4==4.6.1",
] ]
+ maybe_black
}, },
tests_require=[ tests_require=["datasette[test]"],
'datasette[test]',
],
classifiers=[ classifiers=[
'Development Status :: 4 - Beta', "Development Status :: 4 - Beta",
'Intended Audience :: Developers', "Intended Audience :: Developers",
'Intended Audience :: Science/Research', "Intended Audience :: Science/Research",
'Intended Audience :: End Users/Desktop', "Intended Audience :: End Users/Desktop",
'Topic :: Database', "Topic :: Database",
'License :: OSI Approved :: Apache Software License', "License :: OSI Approved :: Apache Software License",
'Programming Language :: Python :: 3.7', "Programming Language :: Python :: 3.7",
'Programming Language :: Python :: 3.6', "Programming Language :: Python :: 3.6",
'Programming Language :: Python :: 3.5', "Programming Language :: Python :: 3.5",
], ],
) )


@ -8,3 +8,10 @@ def pytest_unconfigure(config):
import sys import sys
del sys._called_from_test del sys._called_from_test
def pytest_collection_modifyitems(items):
# Ensure test_black.py runs first before any asyncio code kicks in
test_black = [fn for fn in items if fn.name == "test_black"]
if test_black:
items.insert(0, items.pop(items.index(test_black[0])))


@ -17,9 +17,7 @@ class TestClient:
def get(self, path, allow_redirects=True): def get(self, path, allow_redirects=True):
return self.sanic_test_client.get( return self.sanic_test_client.get(
path, path, allow_redirects=allow_redirects, gather_request=False
allow_redirects=allow_redirects,
gather_request=False
) )
@ -79,39 +77,35 @@ def app_client_no_files():
client.ds = ds client.ds = ds
yield client yield client
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def app_client_with_memory(): def app_client_with_memory():
yield from make_app_client(memory=True) yield from make_app_client(memory=True)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def app_client_with_hash(): def app_client_with_hash():
yield from make_app_client(config={ yield from make_app_client(config={"hash_urls": True}, is_immutable=True)
'hash_urls': True,
}, is_immutable=True)
@pytest.fixture(scope='session') @pytest.fixture(scope="session")
def app_client_shorter_time_limit(): def app_client_shorter_time_limit():
yield from make_app_client(20) yield from make_app_client(20)
@pytest.fixture(scope='session') @pytest.fixture(scope="session")
def app_client_returned_rows_matches_page_size(): def app_client_returned_rows_matches_page_size():
yield from make_app_client(max_returned_rows=50) yield from make_app_client(max_returned_rows=50)
@pytest.fixture(scope='session') @pytest.fixture(scope="session")
def app_client_larger_cache_size(): def app_client_larger_cache_size():
yield from make_app_client(config={ yield from make_app_client(config={"cache_size_kb": 2500})
'cache_size_kb': 2500,
})
@pytest.fixture(scope='session') @pytest.fixture(scope="session")
def app_client_csv_max_mb_one(): def app_client_csv_max_mb_one():
yield from make_app_client(config={ yield from make_app_client(config={"max_csv_mb": 1})
'max_csv_mb': 1,
})
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -119,7 +113,7 @@ def app_client_with_dot():
yield from make_app_client(filename="fixtures.dot.db") yield from make_app_client(filename="fixtures.dot.db")
@pytest.fixture(scope='session') @pytest.fixture(scope="session")
def app_client_with_cors(): def app_client_with_cors():
yield from make_app_client(cors=True) yield from make_app_client(cors=True)
@ -128,7 +122,7 @@ def generate_compound_rows(num):
for a, b, c in itertools.islice( for a, b, c in itertools.islice(
itertools.product(string.ascii_lowercase, repeat=3), num itertools.product(string.ascii_lowercase, repeat=3), num
): ):
yield a, b, c, '{}-{}-{}'.format(a, b, c) yield a, b, c, "{}-{}-{}".format(a, b, c)
def generate_sortable_rows(num): def generate_sortable_rows(num):
@ -137,107 +131,81 @@ def generate_sortable_rows(num):
itertools.product(string.ascii_lowercase, repeat=2), num itertools.product(string.ascii_lowercase, repeat=2), num
): ):
yield { yield {
'pk1': a, "pk1": a,
'pk2': b, "pk2": b,
'content': '{}-{}'.format(a, b), "content": "{}-{}".format(a, b),
'sortable': rand.randint(-100, 100), "sortable": rand.randint(-100, 100),
'sortable_with_nulls': rand.choice([ "sortable_with_nulls": rand.choice([None, rand.random(), rand.random()]),
None, rand.random(), rand.random() "sortable_with_nulls_2": rand.choice([None, rand.random(), rand.random()]),
]), "text": rand.choice(["$null", "$blah"]),
'sortable_with_nulls_2': rand.choice([
None, rand.random(), rand.random()
]),
'text': rand.choice(['$null', '$blah']),
} }
METADATA = { METADATA = {
'title': 'Datasette Fixtures', "title": "Datasette Fixtures",
'description': 'An example SQLite database demonstrating Datasette', "description": "An example SQLite database demonstrating Datasette",
'license': 'Apache License 2.0', "license": "Apache License 2.0",
'license_url': 'https://github.com/simonw/datasette/blob/master/LICENSE', "license_url": "https://github.com/simonw/datasette/blob/master/LICENSE",
'source': 'tests/fixtures.py', "source": "tests/fixtures.py",
'source_url': 'https://github.com/simonw/datasette/blob/master/tests/fixtures.py', "source_url": "https://github.com/simonw/datasette/blob/master/tests/fixtures.py",
'about': 'About Datasette', "about": "About Datasette",
'about_url': 'https://github.com/simonw/datasette', "about_url": "https://github.com/simonw/datasette",
"plugins": { "plugins": {"name-of-plugin": {"depth": "root"}},
"name-of-plugin": { "databases": {
"depth": "root" "fixtures": {
} "description": "Test tables description",
}, "plugins": {"name-of-plugin": {"depth": "database"}},
'databases': { "tables": {
'fixtures': { "simple_primary_key": {
'description': 'Test tables description', "description_html": "Simple <em>primary</em> key",
"plugins": { "title": "This <em>HTML</em> is escaped",
"name-of-plugin": {
"depth": "database"
}
},
'tables': {
'simple_primary_key': {
'description_html': 'Simple <em>primary</em> key',
'title': 'This <em>HTML</em> is escaped',
"plugins": { "plugins": {
"name-of-plugin": { "name-of-plugin": {
"depth": "table", "depth": "table",
"special": "this-is-simple_primary_key" "special": "this-is-simple_primary_key",
}
} }
}, },
'sortable': { },
'sortable_columns': [ "sortable": {
'sortable', "sortable_columns": [
'sortable_with_nulls', "sortable",
'sortable_with_nulls_2', "sortable_with_nulls",
'text', "sortable_with_nulls_2",
"text",
], ],
"plugins": { "plugins": {"name-of-plugin": {"depth": "table"}},
"name-of-plugin": {
"depth": "table"
}
}
}, },
'no_primary_key': { "no_primary_key": {"sortable_columns": [], "hidden": True},
'sortable_columns': [], "units": {"units": {"distance": "m", "frequency": "Hz"}},
'hidden': True, "primary_key_multiple_columns_explicit_label": {
"label_column": "content2"
}, },
'units': { "simple_view": {"sortable_columns": ["content"]},
'units': { "searchable_view_configured_by_metadata": {
'distance': 'm', "fts_table": "searchable_fts",
'frequency': 'Hz' "fts_pk": "pk",
}
}, },
'primary_key_multiple_columns_explicit_label': {
'label_column': 'content2',
}, },
'simple_view': { "queries": {
'sortable_columns': ['content'], "pragma_cache_size": "PRAGMA cache_size;",
}, "neighborhood_search": {
'searchable_view_configured_by_metadata': { "sql": """
'fts_table': 'searchable_fts',
'fts_pk': 'pk'
}
},
'queries': {
'pragma_cache_size': 'PRAGMA cache_size;',
'neighborhood_search': {
'sql': '''
select neighborhood, facet_cities.name, state select neighborhood, facet_cities.name, state
from facetable from facetable
join facet_cities join facet_cities
on facetable.city_id = facet_cities.id on facetable.city_id = facet_cities.id
where neighborhood like '%' || :text || '%' where neighborhood like '%' || :text || '%'
order by neighborhood; order by neighborhood;
''', """,
'title': 'Search neighborhoods', "title": "Search neighborhoods",
'description_html': '<b>Demonstrating</b> simple like search', "description_html": "<b>Demonstrating</b> simple like search",
},
}, },
} }
}, },
} }
}
PLUGIN1 = ''' PLUGIN1 = """
from datasette import hookimpl from datasette import hookimpl
import base64 import base64
import pint import pint
@ -304,9 +272,9 @@ def render_cell(value, column, table, database, datasette):
table=table, table=table,
) )
}) })
''' """
PLUGIN2 = ''' PLUGIN2 = """
from datasette import hookimpl from datasette import hookimpl
import jinja2 import jinja2
import json import json
@ -349,9 +317,10 @@ def render_cell(value, database):
label=jinja2.escape(data["label"] or "") or "&nbsp;" label=jinja2.escape(data["label"] or "") or "&nbsp;"
) )
) )
''' """
TABLES = ''' TABLES = (
"""
CREATE TABLE simple_primary_key ( CREATE TABLE simple_primary_key (
id varchar(30) primary key, id varchar(30) primary key,
content text content text
@ -581,26 +550,42 @@ CREATE VIEW searchable_view AS
CREATE VIEW searchable_view_configured_by_metadata AS CREATE VIEW searchable_view_configured_by_metadata AS
SELECT * from searchable; SELECT * from searchable;
''' + '\n'.join([ """
'INSERT INTO no_primary_key VALUES ({i}, "a{i}", "b{i}", "c{i}");'.format(i=i + 1) + "\n".join(
[
'INSERT INTO no_primary_key VALUES ({i}, "a{i}", "b{i}", "c{i}");'.format(
i=i + 1
)
for i in range(201) for i in range(201)
]) + '\n'.join([ ]
)
+ "\n".join(
[
'INSERT INTO compound_three_primary_keys VALUES ("{a}", "{b}", "{c}", "{content}");'.format( 'INSERT INTO compound_three_primary_keys VALUES ("{a}", "{b}", "{c}", "{content}");'.format(
a=a, b=b, c=c, content=content a=a, b=b, c=c, content=content
) for a, b, c, content in generate_compound_rows(1001) )
]) + '\n'.join([ for a, b, c, content in generate_compound_rows(1001)
'''INSERT INTO sortable VALUES ( ]
)
+ "\n".join(
[
"""INSERT INTO sortable VALUES (
"{pk1}", "{pk2}", "{content}", {sortable}, "{pk1}", "{pk2}", "{content}", {sortable},
{sortable_with_nulls}, {sortable_with_nulls_2}, "{text}"); {sortable_with_nulls}, {sortable_with_nulls_2}, "{text}");
'''.format( """.format(
**row **row
).replace('None', 'null') for row in generate_sortable_rows(201) ).replace(
]) "None", "null"
TABLE_PARAMETERIZED_SQL = [( )
"insert into binary_data (data) values (?);", [b'this is binary data'] for row in generate_sortable_rows(201)
)] ]
)
)
TABLE_PARAMETERIZED_SQL = [
("insert into binary_data (data) values (?);", [b"this is binary data"])
]
if __name__ == '__main__': if __name__ == "__main__":
# Can be called with data.db OR data.db metadata.json # Can be called with data.db OR data.db metadata.json
db_filename = sys.argv[-1] db_filename = sys.argv[-1]
metadata_filename = None metadata_filename = None
@ -615,9 +600,7 @@ if __name__ == '__main__':
conn.execute(sql, params) conn.execute(sql, params)
print("Test tables written to {}".format(db_filename)) print("Test tables written to {}".format(db_filename))
if metadata_filename: if metadata_filename:
open(metadata_filename, 'w').write(json.dumps(METADATA)) open(metadata_filename, "w").write(json.dumps(METADATA))
print("- metadata written to {}".format(metadata_filename)) print("- metadata written to {}".format(metadata_filename))
else: else:
print("Usage: {} db_to_write.db [metadata_to_write.json]".format( print("Usage: {} db_to_write.db [metadata_to_write.json]".format(sys.argv[0]))
sys.argv[0]
))

File diff suppressed because it is too large

tests/test_black.py (new file, 20 lines)

@ -0,0 +1,20 @@
from click.testing import CliRunner
from pathlib import Path
import pytest
import sys
code_root = Path(__file__).parent.parent
@pytest.mark.skipif(
sys.version_info[:2] < (3, 6), reason="Black requires Python 3.6 or later"
)
def test_black():
# Do not import at top of module because Python 3.5 will not have it installed
import black
runner = CliRunner()
result = runner.invoke(
black.main, [str(code_root / "tests"), str(code_root / "datasette"), "--check"]
)
assert result.exit_code == 0, result.output


@ -1,22 +1,26 @@
from .fixtures import ( # noqa from .fixtures import ( # noqa
app_client, app_client,
app_client_csv_max_mb_one, app_client_csv_max_mb_one,
app_client_with_cors app_client_with_cors,
) )
EXPECTED_TABLE_CSV = '''id,content EXPECTED_TABLE_CSV = """id,content
1,hello 1,hello
2,world 2,world
3, 3,
4,RENDER_CELL_DEMO 4,RENDER_CELL_DEMO
'''.replace('\n', '\r\n') """.replace(
"\n", "\r\n"
)
EXPECTED_CUSTOM_CSV = '''content EXPECTED_CUSTOM_CSV = """content
hello hello
world world
'''.replace('\n', '\r\n') """.replace(
"\n", "\r\n"
)
EXPECTED_TABLE_WITH_LABELS_CSV = ''' EXPECTED_TABLE_WITH_LABELS_CSV = """
pk,planet_int,on_earth,state,city_id,city_id_label,neighborhood,tags pk,planet_int,on_earth,state,city_id,city_id_label,neighborhood,tags
1,1,1,CA,1,San Francisco,Mission,"[""tag1"", ""tag2""]" 1,1,1,CA,1,San Francisco,Mission,"[""tag1"", ""tag2""]"
2,1,1,CA,1,San Francisco,Dogpatch,"[""tag1"", ""tag3""]" 2,1,1,CA,1,San Francisco,Dogpatch,"[""tag1"", ""tag3""]"
@ -33,45 +37,47 @@ pk,planet_int,on_earth,state,city_id,city_id_label,neighborhood,tags
13,1,1,MI,3,Detroit,Corktown,[] 13,1,1,MI,3,Detroit,Corktown,[]
14,1,1,MI,3,Detroit,Mexicantown,[] 14,1,1,MI,3,Detroit,Mexicantown,[]
15,2,0,MC,4,Memnonia,Arcadia Planitia,[] 15,2,0,MC,4,Memnonia,Arcadia Planitia,[]
'''.lstrip().replace('\n', '\r\n') """.lstrip().replace(
"\n", "\r\n"
)
def test_table_csv(app_client): def test_table_csv(app_client):
response = app_client.get('/fixtures/simple_primary_key.csv') response = app_client.get("/fixtures/simple_primary_key.csv")
assert response.status == 200 assert response.status == 200
assert not response.headers.get("Access-Control-Allow-Origin") assert not response.headers.get("Access-Control-Allow-Origin")
assert 'text/plain; charset=utf-8' == response.headers['Content-Type'] assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
assert EXPECTED_TABLE_CSV == response.text assert EXPECTED_TABLE_CSV == response.text
def test_table_csv_cors_headers(app_client_with_cors): def test_table_csv_cors_headers(app_client_with_cors):
response = app_client_with_cors.get('/fixtures/simple_primary_key.csv') response = app_client_with_cors.get("/fixtures/simple_primary_key.csv")
assert response.status == 200 assert response.status == 200
assert "*" == response.headers["Access-Control-Allow-Origin"] assert "*" == response.headers["Access-Control-Allow-Origin"]
def test_table_csv_with_labels(app_client): def test_table_csv_with_labels(app_client):
response = app_client.get('/fixtures/facetable.csv?_labels=1') response = app_client.get("/fixtures/facetable.csv?_labels=1")
assert response.status == 200 assert response.status == 200
assert 'text/plain; charset=utf-8' == response.headers['Content-Type'] assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
assert EXPECTED_TABLE_WITH_LABELS_CSV == response.text assert EXPECTED_TABLE_WITH_LABELS_CSV == response.text
def test_custom_sql_csv(app_client): def test_custom_sql_csv(app_client):
response = app_client.get( response = app_client.get(
'/fixtures.csv?sql=select+content+from+simple_primary_key+limit+2' "/fixtures.csv?sql=select+content+from+simple_primary_key+limit+2"
) )
assert response.status == 200 assert response.status == 200
assert 'text/plain; charset=utf-8' == response.headers['Content-Type'] assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
assert EXPECTED_CUSTOM_CSV == response.text assert EXPECTED_CUSTOM_CSV == response.text
def test_table_csv_download(app_client): def test_table_csv_download(app_client):
response = app_client.get('/fixtures/simple_primary_key.csv?_dl=1') response = app_client.get("/fixtures/simple_primary_key.csv?_dl=1")
assert response.status == 200 assert response.status == 200
assert 'text/csv; charset=utf-8' == response.headers['Content-Type'] assert "text/csv; charset=utf-8" == response.headers["Content-Type"]
expected_disposition = 'attachment; filename="simple_primary_key.csv"' expected_disposition = 'attachment; filename="simple_primary_key.csv"'
assert expected_disposition == response.headers['Content-Disposition'] assert expected_disposition == response.headers["Content-Disposition"]
def test_max_csv_mb(app_client_csv_max_mb_one): def test_max_csv_mb(app_client_csv_max_mb_one):
@ -88,12 +94,8 @@ def test_max_csv_mb(app_client_csv_max_mb_one):
def test_table_csv_stream(app_client): def test_table_csv_stream(app_client):
# Without _stream should return header + 100 rows: # Without _stream should return header + 100 rows:
response = app_client.get( response = app_client.get("/fixtures/compound_three_primary_keys.csv?_size=max")
"/fixtures/compound_three_primary_keys.csv?_size=max"
)
assert 101 == len([b for b in response.body.split(b"\r\n") if b]) assert 101 == len([b for b in response.body.split(b"\r\n") if b])
# With _stream=1 should return header + 1001 rows # With _stream=1 should return header + 1001 rows
response = app_client.get( response = app_client.get("/fixtures/compound_three_primary_keys.csv?_stream=1")
"/fixtures/compound_three_primary_keys.csv?_stream=1"
)
assert 1002 == len([b for b in response.body.split(b"\r\n") if b]) assert 1002 == len([b for b in response.body.split(b"\r\n") if b])

View file

@ -9,13 +9,13 @@ from pathlib import Path
import pytest import pytest
import re import re
docs_path = Path(__file__).parent.parent / 'docs' docs_path = Path(__file__).parent.parent / "docs"
label_re = re.compile(r'\.\. _([^\s:]+):') label_re = re.compile(r"\.\. _([^\s:]+):")
def get_headings(filename, underline="-"): def get_headings(filename, underline="-"):
content = (docs_path / filename).open().read() content = (docs_path / filename).open().read()
heading_re = re.compile(r'(\w+)(\([^)]*\))?\n\{}+\n'.format(underline)) heading_re = re.compile(r"(\w+)(\([^)]*\))?\n\{}+\n".format(underline))
return set(h[0] for h in heading_re.findall(content)) return set(h[0] for h in heading_re.findall(content))
@ -24,38 +24,37 @@ def get_labels(filename):
return set(label_re.findall(content)) return set(label_re.findall(content))
@pytest.mark.parametrize('config', app.CONFIG_OPTIONS) @pytest.mark.parametrize("config", app.CONFIG_OPTIONS)
def test_config_options_are_documented(config): def test_config_options_are_documented(config):
assert config.name in get_headings("config.rst") assert config.name in get_headings("config.rst")
@pytest.mark.parametrize("name,filename", ( @pytest.mark.parametrize(
"name,filename",
(
("serve", "datasette-serve-help.txt"), ("serve", "datasette-serve-help.txt"),
("package", "datasette-package-help.txt"), ("package", "datasette-package-help.txt"),
("publish now", "datasette-publish-now-help.txt"), ("publish now", "datasette-publish-now-help.txt"),
("publish heroku", "datasette-publish-heroku-help.txt"), ("publish heroku", "datasette-publish-heroku-help.txt"),
("publish cloudrun", "datasette-publish-cloudrun-help.txt"), ("publish cloudrun", "datasette-publish-cloudrun-help.txt"),
)) ),
)
def test_help_includes(name, filename): def test_help_includes(name, filename):
expected = open(str(docs_path / filename)).read() expected = open(str(docs_path / filename)).read()
runner = CliRunner() runner = CliRunner()
result = runner.invoke(cli, name.split() + ["--help"], terminal_width=88) result = runner.invoke(cli, name.split() + ["--help"], terminal_width=88)
actual = "$ datasette {} --help\n\n{}".format( actual = "$ datasette {} --help\n\n{}".format(name, result.output)
name, result.output
)
# actual has "Usage: cli package [OPTIONS] FILES" # actual has "Usage: cli package [OPTIONS] FILES"
# because it doesn't know that cli will be aliased to datasette # because it doesn't know that cli will be aliased to datasette
expected = expected.replace("Usage: datasette", "Usage: cli") expected = expected.replace("Usage: datasette", "Usage: cli")
assert expected == actual assert expected == actual
@pytest.mark.parametrize('plugin', [ @pytest.mark.parametrize(
name for name in dir(app.pm.hook) if not name.startswith('_') "plugin", [name for name in dir(app.pm.hook) if not name.startswith("_")]
]) )
def test_plugin_hooks_are_documented(plugin): def test_plugin_hooks_are_documented(plugin):
headings = [ headings = [s.split("(")[0] for s in get_headings("plugins.rst", "~")]
s.split("(")[0] for s in get_headings("plugins.rst", "~")
]
assert plugin in headings assert plugin in headings


@ -2,102 +2,57 @@ from datasette.filters import Filters
import pytest import pytest
@pytest.mark.parametrize('args,expected_where,expected_params', [ @pytest.mark.parametrize(
"args,expected_where,expected_params",
[
((("name_english__contains", "foo"),), ['"name_english" like :p0'], ["%foo%"]),
( (
( (("foo", "bar"), ("bar__contains", "baz")),
('name_english__contains', 'foo'),
),
['"name_english" like :p0'],
['%foo%']
),
(
(
('foo', 'bar'),
('bar__contains', 'baz'),
),
['"bar" like :p0', '"foo" = :p1'], ['"bar" like :p0', '"foo" = :p1'],
['%baz%', 'bar'] ["%baz%", "bar"],
), ),
( (
( (("foo__startswith", "bar"), ("bar__endswith", "baz")),
('foo__startswith', 'bar'),
('bar__endswith', 'baz'),
),
['"bar" like :p0', '"foo" like :p1'], ['"bar" like :p0', '"foo" like :p1'],
['%baz', 'bar%'] ["%baz", "bar%"],
), ),
( (
( (("foo__lt", "1"), ("bar__gt", "2"), ("baz__gte", "3"), ("bax__lte", "4")),
('foo__lt', '1'),
('bar__gt', '2'),
('baz__gte', '3'),
('bax__lte', '4'),
),
['"bar" > :p0', '"bax" <= :p1', '"baz" >= :p2', '"foo" < :p3'], ['"bar" > :p0', '"bax" <= :p1', '"baz" >= :p2', '"foo" < :p3'],
[2, 4, 3, 1] [2, 4, 3, 1],
), ),
( (
( (("foo__like", "2%2"), ("zax__glob", "3*")),
('foo__like', '2%2'),
('zax__glob', '3*'),
),
['"foo" like :p0', '"zax" glob :p1'], ['"foo" like :p0', '"zax" glob :p1'],
['2%2', '3*'] ["2%2", "3*"],
), ),
# Multiple like arguments: # Multiple like arguments:
( (
( (("foo__like", "2%2"), ("foo__like", "3%3")),
('foo__like', '2%2'),
('foo__like', '3%3'),
),
['"foo" like :p0', '"foo" like :p1'], ['"foo" like :p0', '"foo" like :p1'],
['2%2', '3%3'] ["2%2", "3%3"],
), ),
( (
( (("foo__isnull", "1"), ("baz__isnull", "1"), ("bar__gt", "10")),
('foo__isnull', '1'),
('baz__isnull', '1'),
('bar__gt', '10'),
),
['"bar" > :p0', '"baz" is null', '"foo" is null'], ['"bar" > :p0', '"baz" is null', '"foo" is null'],
[10] [10],
),
(
(
('foo__in', '1,2,3'),
),
['foo in (:p0, :p1, :p2)'],
["1", "2", "3"]
), ),
((("foo__in", "1,2,3"),), ["foo in (:p0, :p1, :p2)"], ["1", "2", "3"]),
# date # date
( ((("foo__date", "1988-01-01"),), ["date(foo) = :p0"], ["1988-01-01"]),
(
("foo__date", "1988-01-01"),
),
["date(foo) = :p0"],
["1988-01-01"]
),
# JSON array variants of __in (useful for unexpected characters) # JSON array variants of __in (useful for unexpected characters)
((("foo__in", "[1,2,3]"),), ["foo in (:p0, :p1, :p2)"], [1, 2, 3]),
( (
( (("foo__in", '["dog,cat", "cat[dog]"]'),),
('foo__in', '[1,2,3]'), ["foo in (:p0, :p1)"],
["dog,cat", "cat[dog]"],
), ),
['foo in (:p0, :p1, :p2)'], ],
[1, 2, 3] )
),
(
(
('foo__in', '["dog,cat", "cat[dog]"]'),
),
['foo in (:p0, :p1)'],
["dog,cat", "cat[dog]"]
),
])
def test_build_where(args, expected_where, expected_params): def test_build_where(args, expected_where, expected_params):
f = Filters(sorted(args)) f = Filters(sorted(args))
sql_bits, actual_params = f.build_where_clauses("table") sql_bits, actual_params = f.build_where_clauses("table")
assert expected_where == sql_bits assert expected_where == sql_bits
assert { assert {
'p{}'.format(i): param "p{}".format(i): param for i, param in enumerate(expected_params)
for i, param in enumerate(expected_params)
} == actual_params } == actual_params

File diff suppressed because it is too large


@ -5,7 +5,7 @@ import pytest
import tempfile import tempfile
TABLES = ''' TABLES = """
CREATE TABLE "election_results" ( CREATE TABLE "election_results" (
"county" INTEGER, "county" INTEGER,
"party" INTEGER, "party" INTEGER,
@ -32,13 +32,13 @@ CREATE TABLE "office" (
"id" INTEGER PRIMARY KEY , "id" INTEGER PRIMARY KEY ,
"name" TEXT "name" TEXT
); );
''' """
@pytest.fixture(scope='session') @pytest.fixture(scope="session")
def ds_instance(): def ds_instance():
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
filepath = os.path.join(tmpdir, 'fixtures.db') filepath = os.path.join(tmpdir, "fixtures.db")
conn = sqlite3.connect(filepath) conn = sqlite3.connect(filepath)
conn.executescript(TABLES) conn.executescript(TABLES)
yield Datasette([filepath]) yield Datasette([filepath])
@ -46,58 +46,47 @@ def ds_instance():
def test_inspect_hidden_tables(ds_instance): def test_inspect_hidden_tables(ds_instance):
info = ds_instance.inspect() info = ds_instance.inspect()
tables = info['fixtures']['tables'] tables = info["fixtures"]["tables"]
expected_hidden = ( expected_hidden = (
'election_results_fts', "election_results_fts",
'election_results_fts_content', "election_results_fts_content",
'election_results_fts_docsize', "election_results_fts_docsize",
'election_results_fts_segdir', "election_results_fts_segdir",
'election_results_fts_segments', "election_results_fts_segments",
'election_results_fts_stat', "election_results_fts_stat",
)
expected_visible = (
'election_results',
'county',
'party',
'office',
) )
expected_visible = ("election_results", "county", "party", "office")
assert sorted(expected_hidden) == sorted( assert sorted(expected_hidden) == sorted(
[table for table in tables if tables[table]['hidden']] [table for table in tables if tables[table]["hidden"]]
) )
assert sorted(expected_visible) == sorted( assert sorted(expected_visible) == sorted(
[table for table in tables if not tables[table]['hidden']] [table for table in tables if not tables[table]["hidden"]]
) )
def test_inspect_foreign_keys(ds_instance): def test_inspect_foreign_keys(ds_instance):
info = ds_instance.inspect() info = ds_instance.inspect()
tables = info['fixtures']['tables'] tables = info["fixtures"]["tables"]
for table_name in ('county', 'party', 'office'): for table_name in ("county", "party", "office"):
assert 0 == tables[table_name]['count'] assert 0 == tables[table_name]["count"]
foreign_keys = tables[table_name]['foreign_keys'] foreign_keys = tables[table_name]["foreign_keys"]
assert [] == foreign_keys['outgoing'] assert [] == foreign_keys["outgoing"]
assert [{ assert [
'column': 'id', {
'other_column': table_name, "column": "id",
'other_table': 'election_results' "other_column": table_name,
}] == foreign_keys['incoming'] "other_table": "election_results",
}
] == foreign_keys["incoming"]
election_results = tables['election_results'] election_results = tables["election_results"]
assert 0 == election_results['count'] assert 0 == election_results["count"]
assert sorted([{ assert sorted(
'column': 'county', [
'other_column': 'id', {"column": "county", "other_column": "id", "other_table": "county"},
'other_table': 'county' {"column": "party", "other_column": "id", "other_table": "party"},
}, { {"column": "office", "other_column": "id", "other_table": "office"},
'column': 'party', ],
'other_column': 'id', key=lambda d: d["column"],
'other_table': 'party' ) == sorted(election_results["foreign_keys"]["outgoing"], key=lambda d: d["column"])
}, { assert [] == election_results["foreign_keys"]["incoming"]
'column': 'office',
'other_column': 'id',
'other_table': 'office'
}], key=lambda d: d['column']) == sorted(
election_results['foreign_keys']['outgoing'],
key=lambda d: d['column']
)
assert [] == election_results['foreign_keys']['incoming']

View file

@ -1,7 +1,5 @@
from bs4 import BeautifulSoup as Soup from bs4 import BeautifulSoup as Soup
from .fixtures import ( # noqa from .fixtures import app_client # noqa
app_client,
)
import base64 import base64
import json import json
import re import re
@ -13,41 +11,26 @@ def test_plugins_dir_plugin(app_client):
response = app_client.get( response = app_client.get(
"/fixtures.json?sql=select+convert_units(100%2C+'m'%2C+'ft')" "/fixtures.json?sql=select+convert_units(100%2C+'m'%2C+'ft')"
) )
assert pytest.approx(328.0839) == response.json['rows'][0][0] assert pytest.approx(328.0839) == response.json["rows"][0][0]
@pytest.mark.parametrize( @pytest.mark.parametrize(
"path,expected_decoded_object", "path,expected_decoded_object",
[ [
( ("/", {"template": "index.html", "database": None, "table": None}),
"/",
{
"template": "index.html",
"database": None,
"table": None,
},
),
( (
"/fixtures/", "/fixtures/",
{ {"template": "database.html", "database": "fixtures", "table": None},
"template": "database.html",
"database": "fixtures",
"table": None,
},
), ),
( (
"/fixtures/sortable", "/fixtures/sortable",
{ {"template": "table.html", "database": "fixtures", "table": "sortable"},
"template": "table.html",
"database": "fixtures",
"table": "sortable",
},
), ),
], ],
) )
def test_plugin_extra_css_urls(app_client, path, expected_decoded_object): def test_plugin_extra_css_urls(app_client, path, expected_decoded_object):
response = app_client.get(path) response = app_client.get(path)
links = Soup(response.body, 'html.parser').findAll('link') links = Soup(response.body, "html.parser").findAll("link")
special_href = [ special_href = [
l for l in links if l.attrs["href"].endswith("/extra-css-urls-demo.css") l for l in links if l.attrs["href"].endswith("/extra-css-urls-demo.css")
][0]["href"] ][0]["href"]
@ -59,47 +42,43 @@ def test_plugin_extra_css_urls(app_client, path, expected_decoded_object):
def test_plugin_extra_js_urls(app_client): def test_plugin_extra_js_urls(app_client):
response = app_client.get('/') response = app_client.get("/")
scripts = Soup(response.body, 'html.parser').findAll('script') scripts = Soup(response.body, "html.parser").findAll("script")
assert [ assert [
s for s in scripts s
if s.attrs == { for s in scripts
'integrity': 'SRIHASH', if s.attrs
'crossorigin': 'anonymous', == {
'src': 'https://example.com/jquery.js' "integrity": "SRIHASH",
"crossorigin": "anonymous",
"src": "https://example.com/jquery.js",
} }
] ]
def test_plugins_with_duplicate_js_urls(app_client): def test_plugins_with_duplicate_js_urls(app_client):
# If two plugins both require jQuery, jQuery should be loaded only once # If two plugins both require jQuery, jQuery should be loaded only once
response = app_client.get( response = app_client.get("/fixtures")
"/fixtures"
)
# This test is a little tricky, as if the user has any other plugins in # This test is a little tricky, as if the user has any other plugins in
# their current virtual environment those may affect what comes back too. # their current virtual environment those may affect what comes back too.
# What matters is that https://example.com/jquery.js is only there once # What matters is that https://example.com/jquery.js is only there once
# and it comes before plugin1.js and plugin2.js which could be in either # and it comes before plugin1.js and plugin2.js which could be in either
# order # order
scripts = Soup(response.body, 'html.parser').findAll('script') scripts = Soup(response.body, "html.parser").findAll("script")
srcs = [s['src'] for s in scripts if s.get('src')] srcs = [s["src"] for s in scripts if s.get("src")]
# No duplicates allowed: # No duplicates allowed:
assert len(srcs) == len(set(srcs)) assert len(srcs) == len(set(srcs))
# jquery.js loaded once: # jquery.js loaded once:
assert 1 == srcs.count('https://example.com/jquery.js') assert 1 == srcs.count("https://example.com/jquery.js")
# plugin1.js and plugin2.js are both there: # plugin1.js and plugin2.js are both there:
assert 1 == srcs.count('https://example.com/plugin1.js') assert 1 == srcs.count("https://example.com/plugin1.js")
assert 1 == srcs.count('https://example.com/plugin2.js') assert 1 == srcs.count("https://example.com/plugin2.js")
# jquery comes before them both # jquery comes before them both
assert srcs.index( assert srcs.index("https://example.com/jquery.js") < srcs.index(
'https://example.com/jquery.js' "https://example.com/plugin1.js"
) < srcs.index(
'https://example.com/plugin1.js'
) )
assert srcs.index( assert srcs.index("https://example.com/jquery.js") < srcs.index(
'https://example.com/jquery.js' "https://example.com/plugin2.js"
) < srcs.index(
'https://example.com/plugin2.js'
) )
@ -107,13 +86,9 @@ def test_plugins_render_cell_link_from_json(app_client):
sql = """ sql = """
select '{"href": "http://example.com/", "label":"Example"}' select '{"href": "http://example.com/", "label":"Example"}'
""".strip() """.strip()
path = "/fixtures?" + urllib.parse.urlencode({ path = "/fixtures?" + urllib.parse.urlencode({"sql": sql})
"sql": sql,
})
response = app_client.get(path) response = app_client.get(path)
td = Soup( td = Soup(response.body, "html.parser").find("table").find("tbody").find("td")
response.body, "html.parser"
).find("table").find("tbody").find("td")
a = td.find("a") a = td.find("a")
assert a is not None, str(a) assert a is not None, str(a)
assert a.attrs["href"] == "http://example.com/" assert a.attrs["href"] == "http://example.com/"
@ -129,10 +104,7 @@ def test_plugins_render_cell_demo(app_client):
"column": "content", "column": "content",
"table": "simple_primary_key", "table": "simple_primary_key",
"database": "fixtures", "database": "fixtures",
"config": { "config": {"depth": "table", "special": "this-is-simple_primary_key"},
"depth": "table",
"special": "this-is-simple_primary_key"
}
} == json.loads(td.string) } == json.loads(td.string)


@ -35,7 +35,14 @@ def test_publish_cloudrun(mock_call, mock_output, mock_which):
result = runner.invoke(cli.cli, ["publish", "cloudrun", "test.db"]) result = runner.invoke(cli.cli, ["publish", "cloudrun", "test.db"])
assert 0 == result.exit_code assert 0 == result.exit_code
tag = "gcr.io/{}/datasette".format(mock_output.return_value) tag = "gcr.io/{}/datasette".format(mock_output.return_value)
mock_call.assert_has_calls([ mock_call.assert_has_calls(
[
mock.call("gcloud builds submit --tag {}".format(tag), shell=True), mock.call("gcloud builds submit --tag {}".format(tag), shell=True),
mock.call("gcloud beta run deploy --allow-unauthenticated --image {}".format(tag), shell=True)]) mock.call(
"gcloud beta run deploy --allow-unauthenticated --image {}".format(
tag
),
shell=True,
),
]
)


@ -57,7 +57,9 @@ def test_publish_heroku(mock_call, mock_check_output, mock_which):
open("test.db", "w").write("data") open("test.db", "w").write("data")
result = runner.invoke(cli.cli, ["publish", "heroku", "test.db"]) result = runner.invoke(cli.cli, ["publish", "heroku", "test.db"])
assert 0 == result.exit_code, result.output assert 0 == result.exit_code, result.output
mock_call.assert_called_once_with(["heroku", "builds:create", "-a", "f", "--include-vcs-ignore"]) mock_call.assert_called_once_with(
["heroku", "builds:create", "-a", "f", "--include-vcs-ignore"]
)
@mock.patch("shutil.which") @mock.patch("shutil.which")


@ -13,72 +13,78 @@ import tempfile
from unittest.mock import patch from unittest.mock import patch
@pytest.mark.parametrize('path,expected', [ @pytest.mark.parametrize(
('foo', ['foo']), "path,expected",
('foo,bar', ['foo', 'bar']), [
('123,433,112', ['123', '433', '112']), ("foo", ["foo"]),
('123%2C433,112', ['123,433', '112']), ("foo,bar", ["foo", "bar"]),
('123%2F433%2F112', ['123/433/112']), ("123,433,112", ["123", "433", "112"]),
]) ("123%2C433,112", ["123,433", "112"]),
("123%2F433%2F112", ["123/433/112"]),
],
)
def test_urlsafe_components(path, expected): def test_urlsafe_components(path, expected):
assert expected == utils.urlsafe_components(path) assert expected == utils.urlsafe_components(path)
@pytest.mark.parametrize('path,added_args,expected', [ @pytest.mark.parametrize(
('/foo', {'bar': 1}, '/foo?bar=1'), "path,added_args,expected",
('/foo?bar=1', {'baz': 2}, '/foo?bar=1&baz=2'), [
('/foo?bar=1&bar=2', {'baz': 3}, '/foo?bar=1&bar=2&baz=3'), ("/foo", {"bar": 1}, "/foo?bar=1"),
('/foo?bar=1', {'bar': None}, '/foo'), ("/foo?bar=1", {"baz": 2}, "/foo?bar=1&baz=2"),
("/foo?bar=1&bar=2", {"baz": 3}, "/foo?bar=1&bar=2&baz=3"),
("/foo?bar=1", {"bar": None}, "/foo"),
# Test order is preserved # Test order is preserved
('/?_facet=prim_state&_facet=area_name', ( (
('prim_state', 'GA'), "/?_facet=prim_state&_facet=area_name",
), '/?_facet=prim_state&_facet=area_name&prim_state=GA'), (("prim_state", "GA"),),
('/?_facet=state&_facet=city&state=MI', ( "/?_facet=prim_state&_facet=area_name&prim_state=GA",
('city', 'Detroit'), ),
), '/?_facet=state&_facet=city&state=MI&city=Detroit'), (
('/?_facet=state&_facet=city', ( "/?_facet=state&_facet=city&state=MI",
('_facet', 'planet_int'), (("city", "Detroit"),),
), '/?_facet=state&_facet=city&_facet=planet_int'), "/?_facet=state&_facet=city&state=MI&city=Detroit",
]) ),
def test_path_with_added_args(path, added_args, expected): (
request = Request( "/?_facet=state&_facet=city",
path.encode('utf8'), (("_facet", "planet_int"),),
{}, '1.1', 'GET', None "/?_facet=state&_facet=city&_facet=planet_int",
),
],
) )
def test_path_with_added_args(path, added_args, expected):
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
actual = utils.path_with_added_args(request, added_args) actual = utils.path_with_added_args(request, added_args)
assert expected == actual assert expected == actual
@pytest.mark.parametrize('path,args,expected', [ @pytest.mark.parametrize(
('/foo?bar=1', {'bar'}, '/foo'), "path,args,expected",
('/foo?bar=1&baz=2', {'bar'}, '/foo?baz=2'), [
('/foo?bar=1&bar=2&bar=3', {'bar': '2'}, '/foo?bar=1&bar=3'), ("/foo?bar=1", {"bar"}, "/foo"),
]) ("/foo?bar=1&baz=2", {"bar"}, "/foo?baz=2"),
def test_path_with_removed_args(path, args, expected): ("/foo?bar=1&bar=2&bar=3", {"bar": "2"}, "/foo?bar=1&bar=3"),
request = Request( ],
path.encode('utf8'),
{}, '1.1', 'GET', None
) )
def test_path_with_removed_args(path, args, expected):
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
actual = utils.path_with_removed_args(request, args) actual = utils.path_with_removed_args(request, args)
assert expected == actual assert expected == actual
# Run the test again but this time use the path= argument # Run the test again but this time use the path= argument
request = Request( request = Request("/".encode("utf8"), {}, "1.1", "GET", None)
"/".encode('utf8'),
{}, '1.1', 'GET', None
)
actual = utils.path_with_removed_args(request, args, path=path) actual = utils.path_with_removed_args(request, args, path=path)
assert expected == actual assert expected == actual
@pytest.mark.parametrize('path,args,expected', [ @pytest.mark.parametrize(
('/foo?bar=1', {'bar': 2}, '/foo?bar=2'), "path,args,expected",
('/foo?bar=1&baz=2', {'bar': None}, '/foo?baz=2'), [
]) ("/foo?bar=1", {"bar": 2}, "/foo?bar=2"),
def test_path_with_replaced_args(path, args, expected): ("/foo?bar=1&baz=2", {"bar": None}, "/foo?baz=2"),
request = Request( ],
path.encode('utf8'),
{}, '1.1', 'GET', None
) )
def test_path_with_replaced_args(path, args, expected):
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
actual = utils.path_with_replaced_args(request, args) actual = utils.path_with_replaced_args(request, args)
assert expected == actual assert expected == actual
@ -93,17 +99,8 @@ def test_path_with_replaced_args(path, args, expected):
utils.CustomRow( utils.CustomRow(
["searchable_id", "tag"], ["searchable_id", "tag"],
[ [
( ("searchable_id", {"value": 1, "label": "1"}),
"searchable_id", ("tag", {"value": "feline", "label": "feline"}),
{"value": 1, "label": "1"},
),
(
"tag",
{
"value": "feline",
"label": "feline",
},
),
], ],
), ),
["searchable_id", "tag"], ["searchable_id", "tag"],
@ -116,47 +113,54 @@ def test_path_from_row_pks(row, pks, expected_path):
assert expected_path == actual_path assert expected_path == actual_path
@pytest.mark.parametrize('obj,expected', [ @pytest.mark.parametrize(
({ "obj,expected",
'Description': 'Soft drinks', [
'Picture': b"\x15\x1c\x02\xc7\xad\x05\xfe", (
'CategoryID': 1, {
}, """ "Description": "Soft drinks",
"Picture": b"\x15\x1c\x02\xc7\xad\x05\xfe",
"CategoryID": 1,
},
"""
{"CategoryID": 1, "Description": "Soft drinks", "Picture": {"$base64": true, "encoded": "FRwCx60F/g=="}} {"CategoryID": 1, "Description": "Soft drinks", "Picture": {"$base64": true, "encoded": "FRwCx60F/g=="}}
""".strip()), """.strip(),
])
def test_custom_json_encoder(obj, expected):
actual = json.dumps(
obj,
cls=utils.CustomJSONEncoder,
sort_keys=True
) )
],
)
def test_custom_json_encoder(obj, expected):
actual = json.dumps(obj, cls=utils.CustomJSONEncoder, sort_keys=True)
assert expected == actual assert expected == actual
@pytest.mark.parametrize('bad_sql', [ @pytest.mark.parametrize(
'update blah;', "bad_sql",
'PRAGMA case_sensitive_like = true' [
"SELECT * FROM pragma_index_info('idx52')", "update blah;",
]) "PRAGMA case_sensitive_like = true" "SELECT * FROM pragma_index_info('idx52')",
],
)
def test_validate_sql_select_bad(bad_sql): def test_validate_sql_select_bad(bad_sql):
with pytest.raises(utils.InvalidSql): with pytest.raises(utils.InvalidSql):
utils.validate_sql_select(bad_sql) utils.validate_sql_select(bad_sql)
@pytest.mark.parametrize('good_sql', [ @pytest.mark.parametrize(
'select count(*) from airports', "good_sql",
'select foo from bar', [
'select 1 + 1', "select count(*) from airports",
'SELECT\nblah FROM foo', "select foo from bar",
'WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x+1 FROM cnt LIMIT 10) SELECT x FROM cnt;' "select 1 + 1",
]) "SELECT\nblah FROM foo",
"WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x+1 FROM cnt LIMIT 10) SELECT x FROM cnt;",
],
)
def test_validate_sql_select_good(good_sql): def test_validate_sql_select_good(good_sql):
utils.validate_sql_select(good_sql) utils.validate_sql_select(good_sql)
def test_detect_fts(): def test_detect_fts():
sql = ''' sql = """
CREATE TABLE "Dumb_Table" ( CREATE TABLE "Dumb_Table" (
"TreeID" INTEGER, "TreeID" INTEGER,
"qSpecies" TEXT "qSpecies" TEXT
@@ -173,34 +177,40 @@ def test_detect_fts():
     CREATE VIEW Test_View AS SELECT * FROM Dumb_Table;
     CREATE VIRTUAL TABLE "Street_Tree_List_fts" USING FTS4 ("qAddress", "qCaretaker", "qSpecies", content="Street_Tree_List");
     CREATE VIRTUAL TABLE r USING rtree(a, b, c);
-    '''
-    conn = utils.sqlite3.connect(':memory:')
+    """
+    conn = utils.sqlite3.connect(":memory:")
     conn.executescript(sql)
-    assert None is utils.detect_fts(conn, 'Dumb_Table')
-    assert None is utils.detect_fts(conn, 'Test_View')
-    assert None is utils.detect_fts(conn, 'r')
-    assert 'Street_Tree_List_fts' == utils.detect_fts(conn, 'Street_Tree_List')
+    assert None is utils.detect_fts(conn, "Dumb_Table")
+    assert None is utils.detect_fts(conn, "Test_View")
+    assert None is utils.detect_fts(conn, "r")
+    assert "Street_Tree_List_fts" == utils.detect_fts(conn, "Street_Tree_List")


-@pytest.mark.parametrize('url,expected', [
-    ('http://www.google.com/', True),
-    ('https://example.com/', True),
-    ('www.google.com', False),
-    ('http://www.google.com/ is a search engine', False),
-])
+@pytest.mark.parametrize(
+    "url,expected",
+    [
+        ("http://www.google.com/", True),
+        ("https://example.com/", True),
+        ("www.google.com", False),
+        ("http://www.google.com/ is a search engine", False),
+    ],
+)
 def test_is_url(url, expected):
     assert expected == utils.is_url(url)


-@pytest.mark.parametrize('s,expected', [
-    ('simple', 'simple'),
-    ('MixedCase', 'MixedCase'),
-    ('-no-leading-hyphens', 'no-leading-hyphens-65bea6'),
-    ('_no-leading-underscores', 'no-leading-underscores-b921bc'),
-    ('no spaces', 'no-spaces-7088d7'),
-    ('-', '336d5e'),
-    ('no $ characters', 'no--characters-59e024'),
-])
+@pytest.mark.parametrize(
+    "s,expected",
+    [
+        ("simple", "simple"),
+        ("MixedCase", "MixedCase"),
+        ("-no-leading-hyphens", "no-leading-hyphens-65bea6"),
+        ("_no-leading-underscores", "no-leading-underscores-b921bc"),
+        ("no spaces", "no-spaces-7088d7"),
+        ("-", "336d5e"),
+        ("no $ characters", "no--characters-59e024"),
+    ],
+)
 def test_to_css_class(s, expected):
     assert expected == utils.to_css_class(s)
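The to_css_class expectations above follow one pattern: a string that is already a valid CSS class name passes through unchanged, while anything else is cleaned up and given a six-character suffix so that two different inputs cannot collapse to the same class. A minimal sketch of that pattern, assuming the suffix is the first six hex characters of an md5 digest of the original string (css_class_sketch is an illustrative name, not datasette's function, and the exact suffixes in the test come from the real utils.to_css_class):

    import hashlib
    import re

    def css_class_sketch(s):
        # Already a safe class name: return it unchanged
        if re.match(r"^[a-zA-Z]+[_a-zA-Z0-9-]*$", s):
            return s
        # Assumed suffix scheme: first 6 hex chars of md5 of the original input
        suffix = hashlib.md5(s.encode("utf8")).hexdigest()[:6]
        s = s.lstrip("_-")  # drop leading underscores and hyphens
        s = "-".join(s.split())  # whitespace becomes hyphens
        s = re.sub(r"[^a-zA-Z0-9_-]", "", s)  # strip remaining invalid characters
        return "-".join(bit for bit in (s, suffix) if bit)

    print(css_class_sketch("no spaces"))  # slug plus a 6-char hash suffix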
@@ -208,11 +218,11 @@ def test_to_css_class(s, expected):
 def test_temporary_docker_directory_uses_hard_link():
     with tempfile.TemporaryDirectory() as td:
         os.chdir(td)
-        open('hello', 'w').write('world')
+        open("hello", "w").write("world")
         # Default usage of this should use symlink
         with utils.temporary_docker_directory(
-            files=['hello'],
-            name='t',
+            files=["hello"],
+            name="t",
             metadata=None,
             extra_options=None,
             branch=None,
@@ -223,23 +233,23 @@ def test_temporary_docker_directory_uses_hard_link():
             spatialite=False,
             version_note=None,
         ) as temp_docker:
-            hello = os.path.join(temp_docker, 'hello')
-            assert 'world' == open(hello).read()
+            hello = os.path.join(temp_docker, "hello")
+            assert "world" == open(hello).read()
             # It should be a hard link
             assert 2 == os.stat(hello).st_nlink


-@patch('os.link')
+@patch("os.link")
 def test_temporary_docker_directory_uses_copy_if_hard_link_fails(mock_link):
     # Copy instead if os.link raises OSError (normally due to different device)
     mock_link.side_effect = OSError
     with tempfile.TemporaryDirectory() as td:
         os.chdir(td)
-        open('hello', 'w').write('world')
+        open("hello", "w").write("world")
         # Default usage of this should use symlink
         with utils.temporary_docker_directory(
-            files=['hello'],
-            name='t',
+            files=["hello"],
+            name="t",
             metadata=None,
             extra_options=None,
             branch=None,
@@ -250,8 +260,8 @@ def test_temporary_docker_directory_uses_copy_if_hard_link_fails(mock_link):
             spatialite=False,
             version_note=None,
         ) as temp_docker:
-            hello = os.path.join(temp_docker, 'hello')
-            assert 'world' == open(hello).read()
+            hello = os.path.join(temp_docker, "hello")
+            assert "world" == open(hello).read()
             # It should be a copy, not a hard link
             assert 1 == os.stat(hello).st_nlink

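The two tests above tell a hard link apart from a copy by looking at st_nlink: a file that shares its inode with the original reports a link count of 2, while an independent copy reports 1. A self-contained sketch of the same link-then-fall-back-to-copy behaviour, using a hypothetical link_or_copy helper rather than the code under test:

    import os
    import shutil
    import tempfile

    def link_or_copy(src, dest):
        # Prefer a hard link; fall back to a copy if os.link fails
        # (OSError is what you get when src and dest are on different devices)
        try:
            os.link(src, dest)
        except OSError:
            shutil.copyfile(src, dest)

    with tempfile.TemporaryDirectory() as td:
        src = os.path.join(td, "hello")
        open(src, "w").write("world")
        dest = os.path.join(td, "hello-2")
        link_or_copy(src, dest)
        # 2 when the hard link succeeded, 1 if we fell back to a copy
        print(os.stat(dest).st_nlink)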
@@ -259,40 +269,44 @@ def test_temporary_docker_directory_uses_copy_if_hard_link_fails(mock_link):
 def test_temporary_docker_directory_quotes_args():
     with tempfile.TemporaryDirectory() as td:
         os.chdir(td)
-        open('hello', 'w').write('world')
+        open("hello", "w").write("world")
         with utils.temporary_docker_directory(
-            files=['hello'],
-            name='t',
+            files=["hello"],
+            name="t",
             metadata=None,
-            extra_options='--$HOME',
+            extra_options="--$HOME",
             branch=None,
             template_dir=None,
             plugins_dir=None,
             static=[],
             install=[],
             spatialite=False,
-            version_note='$PWD',
+            version_note="$PWD",
         ) as temp_docker:
-            df = os.path.join(temp_docker, 'Dockerfile')
+            df = os.path.join(temp_docker, "Dockerfile")
             df_contents = open(df).read()
             assert "'$PWD'" in df_contents
             assert "'--$HOME'" in df_contents


 def test_compound_keys_after_sql():
-    assert '((a > :p0))' == utils.compound_keys_after_sql(['a'])
-    assert '''
+    assert "((a > :p0))" == utils.compound_keys_after_sql(["a"])
+    assert """
 ((a > :p0)
   or
 (a = :p0 and b > :p1))
-    '''.strip() == utils.compound_keys_after_sql(['a', 'b'])
-    assert '''
+    """.strip() == utils.compound_keys_after_sql(
+        ["a", "b"]
+    )
+    assert """
 ((a > :p0)
   or
 (a = :p0 and b > :p1)
   or
 (a = :p0 and b = :p1 and c > :p2))
-    '''.strip() == utils.compound_keys_after_sql(['a', 'b', 'c'])
+    """.strip() == utils.compound_keys_after_sql(
+        ["a", "b", "c"]
+    )


 async def table_exists(table):
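The strings asserted in test_compound_keys_after_sql are WHERE-clause fragments for paginating over a compound key: they match the rows that sort after a given (a, b). A hedged sketch of how such a fragment can be used, with a hypothetical `compound` table and hand-filled :p0/:p1 parameters:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        create table compound (a integer, b integer, primary key (a, b));
        insert into compound values (1, 1), (1, 2), (2, 1), (2, 2);
        """
    )
    # Same shape as the fragment asserted above for keys ["a", "b"]
    after = "((a > :p0)\n  or\n(a = :p0 and b > :p1))"
    sql = "select a, b from compound where {} order by a, b limit 2".format(after)
    # Rows that come after the key (1, 2)
    print(conn.execute(sql, {"p0": 1, "p1": 2}).fetchall())  # [(2, 1), (2, 2)]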
@@ -314,7 +328,7 @@ async def test_resolve_table_and_format(
     table_and_format, expected_table, expected_format
 ):
     actual_table, actual_format = await utils.resolve_table_and_format(
-        table_and_format, table_exists, ['json']
+        table_and_format, table_exists, ["json"]
     )
     assert expected_table == actual_table
     assert expected_format == actual_format
@@ -322,9 +336,11 @@ async def test_resolve_table_and_format(
 def test_table_columns():
     conn = sqlite3.connect(":memory:")
-    conn.executescript("""
+    conn.executescript(
+        """
     create table places (id integer primary key, name text, bob integer)
-    """)
+    """
+    )
     assert ["id", "name", "bob"] == utils.table_columns(conn, "places")
@@ -347,10 +363,7 @@ def test_table_columns():
     ],
 )
 def test_path_with_format(path, format, extra_qs, expected):
-    request = Request(
-        path.encode('utf8'),
-        {}, '1.1', 'GET', None
-    )
+    request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
     actual = utils.path_with_format(request, format, extra_qs)
     assert expected == actual
@@ -358,13 +371,13 @@ def test_path_with_format(path, format, extra_qs, expected):
 @pytest.mark.parametrize(
     "bytes,expected",
     [
-        (120, '120 bytes'),
-        (1024, '1.0 KB'),
-        (1024 * 1024, '1.0 MB'),
-        (1024 * 1024 * 1024, '1.0 GB'),
-        (1024 * 1024 * 1024 * 1.3, '1.3 GB'),
-        (1024 * 1024 * 1024 * 1024, '1.0 TB'),
-    ]
+        (120, "120 bytes"),
+        (1024, "1.0 KB"),
+        (1024 * 1024, "1.0 MB"),
+        (1024 * 1024 * 1024, "1.0 GB"),
+        (1024 * 1024 * 1024 * 1.3, "1.3 GB"),
+        (1024 * 1024 * 1024 * 1024, "1.0 TB"),
+    ],
 )
 def test_format_bytes(bytes, expected):
     assert expected == utils.format_bytes(bytes)
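The format_bytes cases divide by 1024 per unit and keep one decimal place above the bytes range. A sketch that reproduces the expected values in the parametrize list above (format_bytes_sketch is an illustrative name; the real utils.format_bytes may treat edge cases differently):

    def format_bytes_sketch(num_bytes):
        current = float(num_bytes)
        for unit in ("bytes", "KB", "MB", "GB", "TB", "PB"):
            if current < 1024:
                break
            current /= 1024
        if unit == "bytes":
            return "{} bytes".format(int(current))
        return "{:.1f} {}".format(current, unit)

    for value in (120, 1024, 1024 * 1024 * 1024 * 1.3):
        print(format_bytes_sketch(value))  # 120 bytes / 1.0 KB / 1.3 GB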