Apply black to everything

I ran this:

    black datasette tests
Simon Willison 2019-05-03 17:56:52 -04:00
commit 6300d3e269
28 changed files with 2725 additions and 2644 deletions
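
For anyone repeating this, black can also verify formatting without rewriting files. A minimal sketch using its --check and --diff flags (wiring these into CI is a suggestion here, not something this commit sets up):

    # Exit non-zero if any file would be reformatted:
    black --check datasette tests

    # Show the changes black would make, without applying them:
    black --diff datasette tests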

View file

@@ -1,3 +1,3 @@
from datasette.version import __version_info__, __version__ # noqa
from .hookspecs import hookimpl # noqa
from .hookspecs import hookspec # noqa
from .hookspecs import hookimpl # noqa
from .hookspecs import hookspec # noqa

View file

@@ -1,4 +1,3 @@
# This file helps to compute a version number in source trees obtained from
# git-archive tarball (such as those provided by github's download-from-tag
# feature). Distribution tarballs (built by setup.py sdist) and build
@@ -58,17 +57,18 @@ HANDLERS = {}
def register_vcs_handler(vcs, method): # decorator
"""Decorator to mark a method as the handler for a particular VCS."""
def decorate(f):
"""Store f in HANDLERS[vcs][method]."""
if vcs not in HANDLERS:
HANDLERS[vcs] = {}
HANDLERS[vcs][method] = f
return f
return decorate
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
env=None):
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None):
"""Call the given command(s)."""
assert isinstance(commands, list)
p = None
@@ -76,10 +76,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
try:
dispcmd = str([c] + args)
# remember shell=False, so use git.cmd on windows, not just git
p = subprocess.Popen([c] + args, cwd=cwd, env=env,
stdout=subprocess.PIPE,
stderr=(subprocess.PIPE if hide_stderr
else None))
p = subprocess.Popen(
[c] + args,
cwd=cwd,
env=env,
stdout=subprocess.PIPE,
stderr=(subprocess.PIPE if hide_stderr else None),
)
break
except EnvironmentError:
e = sys.exc_info()[1]
@@ -116,16 +119,22 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
for i in range(3):
dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix):
return {"version": dirname[len(parentdir_prefix):],
"full-revisionid": None,
"dirty": False, "error": None, "date": None}
return {
"version": dirname[len(parentdir_prefix) :],
"full-revisionid": None,
"dirty": False,
"error": None,
"date": None,
}
else:
rootdirs.append(root)
root = os.path.dirname(root) # up a level
if verbose:
print("Tried directories %s but none started with prefix %s" %
(str(rootdirs), parentdir_prefix))
print(
"Tried directories %s but none started with prefix %s"
% (str(rootdirs), parentdir_prefix)
)
raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
@@ -181,7 +190,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])
tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)])
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
@@ -190,7 +199,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
tags = set([r for r in refs if re.search(r'\d', r)])
tags = set([r for r in refs if re.search(r"\d", r)])
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
@@ -198,19 +207,26 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
for ref in sorted(tags):
# sorting will prefer e.g. "2.0" over "2.0rc1"
if ref.startswith(tag_prefix):
r = ref[len(tag_prefix):]
r = ref[len(tag_prefix) :]
if verbose:
print("picking %s" % r)
return {"version": r,
"full-revisionid": keywords["full"].strip(),
"dirty": False, "error": None,
"date": date}
return {
"version": r,
"full-revisionid": keywords["full"].strip(),
"dirty": False,
"error": None,
"date": date,
}
# no suitable tags, so version is "0+unknown", but full hex is still there
if verbose:
print("no suitable tags, using unknown + full revision id")
return {"version": "0+unknown",
"full-revisionid": keywords["full"].strip(),
"dirty": False, "error": "no suitable tags", "date": None}
return {
"version": "0+unknown",
"full-revisionid": keywords["full"].strip(),
"dirty": False,
"error": "no suitable tags",
"date": None,
}
@register_vcs_handler("git", "pieces_from_vcs")
@@ -225,8 +241,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"]
out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root,
hide_stderr=True)
out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True)
if rc != 0:
if verbose:
print("Directory %s not under git control" % root)
@@ -234,10 +249,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM)
describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty",
"--always", "--long",
"--match", "%s*" % tag_prefix],
cwd=root)
describe_out, rc = run_command(
GITS,
[
"describe",
"--tags",
"--dirty",
"--always",
"--long",
"--match",
"%s*" % tag_prefix,
],
cwd=root,
)
# --long was added in git-1.5.5
if describe_out is None:
raise NotThisMethod("'git describe' failed")
@@ -260,17 +284,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
dirty = git_describe.endswith("-dirty")
pieces["dirty"] = dirty
if dirty:
git_describe = git_describe[:git_describe.rindex("-dirty")]
git_describe = git_describe[: git_describe.rindex("-dirty")]
# now we have TAG-NUM-gHEX or HEX
if "-" in git_describe:
# TAG-NUM-gHEX
mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
if not mo:
# unparseable. Maybe git-describe is misbehaving?
pieces["error"] = ("unable to parse git-describe output: '%s'"
% describe_out)
pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
return pieces
# tag
@@ -279,10 +302,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if verbose:
fmt = "tag '%s' doesn't start with prefix '%s'"
print(fmt % (full_tag, tag_prefix))
pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
% (full_tag, tag_prefix))
pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (
full_tag,
tag_prefix,
)
return pieces
pieces["closest-tag"] = full_tag[len(tag_prefix):]
pieces["closest-tag"] = full_tag[len(tag_prefix) :]
# distance: number of commits since tag
pieces["distance"] = int(mo.group(2))
@@ -293,13 +318,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
else:
# HEX: no tags
pieces["closest-tag"] = None
count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"],
cwd=root)
count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root)
pieces["distance"] = int(count_out) # total number of commits
# commit date: see ISO-8601 comment in git_versions_from_keywords()
date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"],
cwd=root)[0].strip()
date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[
0
].strip()
pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
return pieces
@@ -330,8 +355,7 @@ def render_pep440(pieces):
rendered += ".dirty"
else:
# exception #1
rendered = "0+untagged.%d.g%s" % (pieces["distance"],
pieces["short"])
rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
@@ -445,11 +469,13 @@ def render_git_describe_long(pieces):
def render(pieces, style):
"""Render the given version pieces into the requested style."""
if pieces["error"]:
return {"version": "unknown",
"full-revisionid": pieces.get("long"),
"dirty": None,
"error": pieces["error"],
"date": None}
return {
"version": "unknown",
"full-revisionid": pieces.get("long"),
"dirty": None,
"error": pieces["error"],
"date": None,
}
if not style or style == "default":
style = "pep440" # the default
@@ -469,9 +495,13 @@ def render(pieces, style):
else:
raise ValueError("unknown style '%s'" % style)
return {"version": rendered, "full-revisionid": pieces["long"],
"dirty": pieces["dirty"], "error": None,
"date": pieces.get("date")}
return {
"version": rendered,
"full-revisionid": pieces["long"],
"dirty": pieces["dirty"],
"error": None,
"date": pieces.get("date"),
}
def get_versions():
@@ -485,8 +515,7 @@ def get_versions():
verbose = cfg.verbose
try:
return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
verbose)
return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose)
except NotThisMethod:
pass
@@ -495,13 +524,16 @@ def get_versions():
# versionfile_source is the relative path from the top of the source
# tree (where the .git directory might live) to this file. Invert
# this to find the root from __file__.
for i in cfg.versionfile_source.split('/'):
for i in cfg.versionfile_source.split("/"):
root = os.path.dirname(root)
except NameError:
return {"version": "0+unknown", "full-revisionid": None,
"dirty": None,
"error": "unable to find root of source tree",
"date": None}
return {
"version": "0+unknown",
"full-revisionid": None,
"dirty": None,
"error": "unable to find root of source tree",
"date": None,
}
try:
pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
@@ -515,6 +547,10 @@ def get_versions():
except NotThisMethod:
pass
return {"version": "0+unknown", "full-revisionid": None,
"dirty": None,
"error": "unable to compute version", "date": None}
return {
"version": "0+unknown",
"full-revisionid": None,
"dirty": None,
"error": "unable to compute version",
"date": None,
}

View file

@@ -17,10 +17,7 @@ from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader
from sanic import Sanic, response
from sanic.exceptions import InvalidUsage, NotFound
from .views.base import (
DatasetteError,
ureg
)
from .views.base import DatasetteError, ureg
from .views.database import DatabaseDownload, DatabaseView
from .views.index import IndexView
from .views.special import JsonDataView
@@ -39,7 +36,7 @@ from .utils import (
sqlite3,
sqlite_timelimit,
table_columns,
to_css_class
to_css_class,
)
from .inspect import inspect_hash, inspect_views, inspect_tables
from .tracer import capture_traces, trace
@@ -51,72 +48,143 @@ app_root = Path(__file__).parent.parent
connections = threading.local()
MEMORY = object()
ConfigOption = collections.namedtuple(
"ConfigOption", ("name", "default", "help")
)
ConfigOption = collections.namedtuple("ConfigOption", ("name", "default", "help"))
CONFIG_OPTIONS = (
ConfigOption("default_page_size", 100, """
ConfigOption(
"default_page_size",
100,
"""
Default page size for the table view
""".strip()),
ConfigOption("max_returned_rows", 1000, """
""".strip(),
),
ConfigOption(
"max_returned_rows",
1000,
"""
Maximum rows that can be returned from a table or custom query
""".strip()),
ConfigOption("num_sql_threads", 3, """
""".strip(),
),
ConfigOption(
"num_sql_threads",
3,
"""
Number of threads in the thread pool for executing SQLite queries
""".strip()),
ConfigOption("sql_time_limit_ms", 1000, """
""".strip(),
),
ConfigOption(
"sql_time_limit_ms",
1000,
"""
Time limit for a SQL query in milliseconds
""".strip()),
ConfigOption("default_facet_size", 30, """
""".strip(),
),
ConfigOption(
"default_facet_size",
30,
"""
Number of values to return for requested facets
""".strip()),
ConfigOption("facet_time_limit_ms", 200, """
""".strip(),
),
ConfigOption(
"facet_time_limit_ms",
200,
"""
Time limit for calculating a requested facet
""".strip()),
ConfigOption("facet_suggest_time_limit_ms", 50, """
""".strip(),
),
ConfigOption(
"facet_suggest_time_limit_ms",
50,
"""
Time limit for calculating a suggested facet
""".strip()),
ConfigOption("hash_urls", False, """
""".strip(),
),
ConfigOption(
"hash_urls",
False,
"""
Include DB file contents hash in URLs, for far-future caching
""".strip()),
ConfigOption("allow_facet", True, """
""".strip(),
),
ConfigOption(
"allow_facet",
True,
"""
Allow users to specify columns to facet using ?_facet= parameter
""".strip()),
ConfigOption("allow_download", True, """
""".strip(),
),
ConfigOption(
"allow_download",
True,
"""
Allow users to download the original SQLite database files
""".strip()),
ConfigOption("suggest_facets", True, """
""".strip(),
),
ConfigOption(
"suggest_facets",
True,
"""
Calculate and display suggested facets
""".strip()),
ConfigOption("allow_sql", True, """
""".strip(),
),
ConfigOption(
"allow_sql",
True,
"""
Allow arbitrary SQL queries via ?sql= parameter
""".strip()),
ConfigOption("default_cache_ttl", 5, """
""".strip(),
),
ConfigOption(
"default_cache_ttl",
5,
"""
Default HTTP cache TTL (used in Cache-Control: max-age= header)
""".strip()),
ConfigOption("default_cache_ttl_hashed", 365 * 24 * 60 * 60, """
""".strip(),
),
ConfigOption(
"default_cache_ttl_hashed",
365 * 24 * 60 * 60,
"""
Default HTTP cache TTL for hashed URL pages
""".strip()),
ConfigOption("cache_size_kb", 0, """
""".strip(),
),
ConfigOption(
"cache_size_kb",
0,
"""
SQLite cache size in KB (0 == use SQLite default)
""".strip()),
ConfigOption("allow_csv_stream", True, """
""".strip(),
),
ConfigOption(
"allow_csv_stream",
True,
"""
Allow .csv?_stream=1 to download all rows (ignoring max_returned_rows)
""".strip()),
ConfigOption("max_csv_mb", 100, """
""".strip(),
),
ConfigOption(
"max_csv_mb",
100,
"""
Maximum size allowed for CSV export in MB - set 0 to disable this limit
""".strip()),
ConfigOption("truncate_cells_html", 2048, """
""".strip(),
),
ConfigOption(
"truncate_cells_html",
2048,
"""
Truncate cells longer than this in HTML table view - set 0 to disable
""".strip()),
ConfigOption("force_https_urls", False, """
""".strip(),
),
ConfigOption(
"force_https_urls",
False,
"""
Force URLs in API output to always use https:// protocol
""".strip()),
""".strip(),
),
)
DEFAULT_CONFIG = {
option.name: option.default
for option in CONFIG_OPTIONS
}
DEFAULT_CONFIG = {option.name: option.default for option in CONFIG_OPTIONS}
async def favicon(request):
@@ -151,11 +219,13 @@ class ConnectedDatabase:
counts = {}
for table in await self.table_names():
try:
table_count = (await self.ds.execute(
self.name,
"select count(*) from [{}]".format(table),
custom_time_limit=limit,
)).rows[0][0]
table_count = (
await self.ds.execute(
self.name,
"select count(*) from [{}]".format(table),
custom_time_limit=limit,
)
).rows[0][0]
counts[table] = table_count
except InterruptedError:
counts[table] = None
@@ -175,18 +245,26 @@ class ConnectedDatabase:
return Path(self.path).stem
async def table_names(self):
results = await self.ds.execute(self.name, "select name from sqlite_master where type='table'")
results = await self.ds.execute(
self.name, "select name from sqlite_master where type='table'"
)
return [r[0] for r in results.rows]
async def hidden_table_names(self):
# Mark tables 'hidden' if they relate to FTS virtual tables
hidden_tables = [r[0] for r in (
await self.ds.execute(self.name, """
hidden_tables = [
r[0]
for r in (
await self.ds.execute(
self.name,
"""
select name from sqlite_master
where rootpage = 0
and sql like '%VIRTUAL TABLE%USING FTS%'
""")
).rows]
""",
)
).rows
]
has_spatialite = await self.ds.execute_against_connection_in_thread(
self.name, detect_spatialite
)
@@ -205,18 +283,23 @@ class ConnectedDatabase:
] + [
r[0]
for r in (
await self.ds.execute(self.name, """
await self.ds.execute(
self.name,
"""
select name from sqlite_master
where name like "idx_%"
and type = "table"
""")
""",
)
).rows
]
# Add any from metadata.json
db_metadata = self.ds.metadata(database=self.name)
if "tables" in db_metadata:
hidden_tables += [
t for t in db_metadata["tables"] if db_metadata["tables"][t].get("hidden")
t
for t in db_metadata["tables"]
if db_metadata["tables"][t].get("hidden")
]
# Also mark as hidden any tables which start with the name of a hidden table
# e.g. "searchable_fts" implies "searchable_fts_content" should be hidden
@@ -229,7 +312,9 @@ class ConnectedDatabase:
return hidden_tables
async def view_names(self):
results = await self.ds.execute(self.name, "select name from sqlite_master where type='view'")
results = await self.ds.execute(
self.name, "select name from sqlite_master where type='view'"
)
return [r[0] for r in results.rows]
def __repr__(self):
@@ -245,13 +330,10 @@ class ConnectedDatabase:
tags_str = ""
if tags:
tags_str = " ({})".format(", ".join(tags))
return "<ConnectedDatabase: {}{}>".format(
self.name, tags_str
)
return "<ConnectedDatabase: {}{}>".format(self.name, tags_str)
class Datasette:
def __init__(
self,
files,
@@ -283,7 +365,9 @@ class Datasette:
path = None
is_memory = True
is_mutable = path not in self.immutables
db = ConnectedDatabase(self, path, is_mutable=is_mutable, is_memory=is_memory)
db = ConnectedDatabase(
self, path, is_mutable=is_mutable, is_memory=is_memory
)
if db.name in self.databases:
raise Exception("Multiple files with same stem: {}".format(db.name))
self.databases[db.name] = db
@@ -322,26 +406,24 @@ class Datasette:
def config_dict(self):
# Returns a fully resolved config dictionary, useful for templates
return {
option.name: self.config(option.name)
for option in CONFIG_OPTIONS
}
return {option.name: self.config(option.name) for option in CONFIG_OPTIONS}
def metadata(self, key=None, database=None, table=None, fallback=True):
"""
Looks up metadata, cascading backwards from specified level.
Returns None if metadata value is not found.
"""
assert not (database is None and table is not None), \
"Cannot call metadata() with table= specified but not database="
assert not (
database is None and table is not None
), "Cannot call metadata() with table= specified but not database="
databases = self._metadata.get("databases") or {}
search_list = []
if database is not None:
search_list.append(databases.get(database) or {})
if table is not None:
table_metadata = (
(databases.get(database) or {}).get("tables") or {}
).get(table) or {}
table_metadata = ((databases.get(database) or {}).get("tables") or {}).get(
table
) or {}
search_list.insert(0, table_metadata)
search_list.append(self._metadata)
if not fallback:
@@ -359,9 +441,7 @@ class Datasette:
m.update(item)
return m
def plugin_config(
self, plugin_name, database=None, table=None, fallback=True
):
def plugin_config(self, plugin_name, database=None, table=None, fallback=True):
"Return config for plugin, falling back from specified database/table"
plugins = self.metadata(
"plugins", database=database, table=table, fallback=fallback
@@ -373,29 +453,19 @@ class Datasette:
def app_css_hash(self):
if not hasattr(self, "_app_css_hash"):
self._app_css_hash = hashlib.sha1(
open(
os.path.join(str(app_root), "datasette/static/app.css")
).read().encode(
"utf8"
)
).hexdigest()[
:6
]
open(os.path.join(str(app_root), "datasette/static/app.css"))
.read()
.encode("utf8")
).hexdigest()[:6]
return self._app_css_hash
def get_canned_queries(self, database_name):
queries = self.metadata(
"queries", database=database_name, fallback=False
) or {}
queries = self.metadata("queries", database=database_name, fallback=False) or {}
names = queries.keys()
return [
self.get_canned_query(database_name, name) for name in names
]
return [self.get_canned_query(database_name, name) for name in names]
def get_canned_query(self, database_name, query_name):
queries = self.metadata(
"queries", database=database_name, fallback=False
) or {}
queries = self.metadata("queries", database=database_name, fallback=False) or {}
query = queries.get(query_name)
if query:
if not isinstance(query, dict):
@@ -407,7 +477,7 @@ class Datasette:
table_definition_rows = list(
await self.execute(
database_name,
'select sql from sqlite_master where name = :n and type=:t',
"select sql from sqlite_master where name = :n and type=:t",
{"n": table, "t": type_},
)
)
@@ -416,21 +486,19 @@ class Datasette:
return table_definition_rows[0][0]
def get_view_definition(self, database_name, view):
return self.get_table_definition(database_name, view, 'view')
return self.get_table_definition(database_name, view, "view")
def update_with_inherited_metadata(self, metadata):
# Fills in source/license with defaults, if available
metadata.update(
{
"source": metadata.get("source") or self.metadata("source"),
"source_url": metadata.get("source_url")
or self.metadata("source_url"),
"source_url": metadata.get("source_url") or self.metadata("source_url"),
"license": metadata.get("license") or self.metadata("license"),
"license_url": metadata.get("license_url")
or self.metadata("license_url"),
"about": metadata.get("about") or self.metadata("about"),
"about_url": metadata.get("about_url")
or self.metadata("about_url"),
"about_url": metadata.get("about_url") or self.metadata("about_url"),
}
)
@@ -444,7 +512,7 @@ class Datasette:
for extension in self.sqlite_extensions:
conn.execute("SELECT load_extension('{}')".format(extension))
if self.config("cache_size_kb"):
conn.execute('PRAGMA cache_size=-{}'.format(self.config("cache_size_kb")))
conn.execute("PRAGMA cache_size=-{}".format(self.config("cache_size_kb")))
# pylint: disable=no-member
pm.hook.prepare_connection(conn=conn)
@@ -452,7 +520,7 @@ class Datasette:
results = await self.execute(
database,
"select 1 from sqlite_master where type='table' and name=?",
params=(table,)
params=(table,),
)
return bool(results.rows)
@@ -463,32 +531,28 @@ class Datasette:
# Find the foreign_key for this column
try:
fk = [
foreign_key for foreign_key in foreign_keys
foreign_key
for foreign_key in foreign_keys
if foreign_key["column"] == column
][0]
except IndexError:
return {}
label_column = await self.label_column_for_table(database, fk["other_table"])
if not label_column:
return {
(fk["column"], value): str(value)
for value in values
}
return {(fk["column"], value): str(value) for value in values}
labeled_fks = {}
sql = '''
sql = """
select {other_column}, {label_column}
from {other_table}
where {other_column} in ({placeholders})
'''.format(
""".format(
other_column=escape_sqlite(fk["other_column"]),
label_column=escape_sqlite(label_column),
other_table=escape_sqlite(fk["other_table"]),
placeholders=", ".join(["?"] * len(set(values))),
)
try:
results = await self.execute(
database, sql, list(set(values))
)
results = await self.execute(database, sql, list(set(values)))
except InterruptedError:
pass
else:
@@ -499,7 +563,7 @@ class Datasette:
def absolute_url(self, request, path):
url = urllib.parse.urljoin(request.url, path)
if url.startswith("http://") and self.config("force_https_urls"):
url = "https://" + url[len("http://"):]
url = "https://" + url[len("http://") :]
return url
def inspect(self):
@@ -532,10 +596,12 @@ class Datasette:
"file": str(path),
"size": path.stat().st_size,
"views": inspect_views(conn),
"tables": inspect_tables(conn, (self.metadata("databases") or {}).get(name, {}))
"tables": inspect_tables(
conn, (self.metadata("databases") or {}).get(name, {})
),
}
except sqlite3.OperationalError as e:
if (e.args[0] == 'no such module: VirtualSpatialIndex'):
if e.args[0] == "no such module: VirtualSpatialIndex":
raise click.UsageError(
"It looks like you're trying to load a SpatiaLite"
" database without first loading the SpatiaLite module."
@@ -582,7 +648,8 @@ class Datasette:
datasette_version["note"] = self.version_note
return {
"python": {
"version": ".".join(map(str, sys.version_info[:3])), "full": sys.version
"version": ".".join(map(str, sys.version_info[:3])),
"full": sys.version,
},
"datasette": datasette_version,
"sqlite": {
@@ -611,10 +678,11 @@ class Datasette:
def table_metadata(self, database, table):
"Fetch table-specific metadata."
return (self.metadata("databases") or {}).get(database, {}).get(
"tables", {}
).get(
table, {}
return (
(self.metadata("databases") or {})
.get(database, {})
.get("tables", {})
.get(table, {})
)
async def table_columns(self, db_name, table):
@ -628,16 +696,12 @@ class Datasette:
)
async def label_column_for_table(self, db_name, table):
explicit_label_column = (
self.table_metadata(
db_name, table
).get("label_column")
)
explicit_label_column = self.table_metadata(db_name, table).get("label_column")
if explicit_label_column:
return explicit_label_column
# If a table has two columns, one of which is ID, then label_column is the other one
column_names = await self.table_columns(db_name, table)
if (column_names and len(column_names) == 2 and "id" in column_names):
if column_names and len(column_names) == 2 and "id" in column_names:
return [c for c in column_names if c != "id"][0]
# Couldn't find a label:
return None
@@ -664,9 +728,7 @@ class Datasette:
setattr(connections, db_name, conn)
return fn(conn)
return await asyncio.get_event_loop().run_in_executor(
self.executor, in_thread
)
return await asyncio.get_event_loop().run_in_executor(self.executor, in_thread)
async def execute(
self,
@@ -701,7 +763,7 @@ class Datasette:
rows = cursor.fetchall()
truncated = False
except sqlite3.OperationalError as e:
if e.args == ('interrupted',):
if e.args == ("interrupted",):
raise InterruptedError(e, sql, params)
if log_sql_errors:
print(
@@ -726,7 +788,7 @@ class Datasette:
def register_renderers(self):
""" Register output renderers which output data in custom formats. """
# Built-in renderers
self.renderers['json'] = json_renderer
self.renderers["json"] = json_renderer
# Hooks
hook_renderers = []
@@ -737,19 +799,22 @@ class Datasette:
hook_renderers.append(hook)
for renderer in hook_renderers:
self.renderers[renderer['extension']] = renderer['callback']
self.renderers[renderer["extension"]] = renderer["callback"]
def app(self):
class TracingSanic(Sanic):
async def handle_request(self, request, write_callback, stream_callback):
if request.args.get("_trace"):
request["traces"] = []
request["trace_start"] = time.time()
with capture_traces(request["traces"]):
res = await super().handle_request(request, write_callback, stream_callback)
res = await super().handle_request(
request, write_callback, stream_callback
)
else:
res = await super().handle_request(request, write_callback, stream_callback)
res = await super().handle_request(
request, write_callback, stream_callback
)
return res
app = TracingSanic(__name__)
@@ -822,15 +887,16 @@ class Datasette:
)
app.add_route(
DatabaseView.as_view(self),
r"/<db_name:[^/]+?><as_format:(" + renderer_regex + r"|.jsono|\.csv)?$>"
r"/<db_name:[^/]+?><as_format:(" + renderer_regex + r"|.jsono|\.csv)?$>",
)
app.add_route(
TableView.as_view(self),
r"/<db_name:[^/]+>/<table_and_format:[^/]+?$>",
TableView.as_view(self), r"/<db_name:[^/]+>/<table_and_format:[^/]+?$>"
)
app.add_route(
RowView.as_view(self),
r"/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_format:(" + renderer_regex + r")?$>",
r"/<db_name:[^/]+>/<table:[^/]+?>/<pk_path:[^/]+?><as_format:("
+ renderer_regex
+ r")?$>",
)
self.register_custom_units()
@@ -852,7 +918,7 @@ class Datasette:
"duration": time.time() - request["trace_start"],
"queries": request["traces"],
}
if "text/html" in response.content_type and b'</body>' in response.body:
if "text/html" in response.content_type and b"</body>" in response.body:
extra = json.dumps(traces, indent=2)
extra_html = "<pre>{}</pre></body>".format(extra).encode("utf8")
response.body = response.body.replace(b"</body>", extra_html)
@@ -908,6 +974,6 @@ class Datasette:
async def setup_db(app, loop):
for dbname, database in self.databases.items():
if not database.is_mutable:
await database.table_counts(limit=60*60*1000)
await database.table_counts(limit=60 * 60 * 1000)
return app

View file

@@ -20,16 +20,14 @@ class Config(click.ParamType):
def convert(self, config, param, ctx):
if ":" not in config:
self.fail(
'"{}" should be name:value'.format(config), param, ctx
)
self.fail('"{}" should be name:value'.format(config), param, ctx)
return
name, value = config.split(":")
if name not in DEFAULT_CONFIG:
self.fail(
"{} is not a valid option (--help-config to see all)".format(
name
), param, ctx
"{} is not a valid option (--help-config to see all)".format(name),
param,
ctx,
)
return
# Type checking
@@ -44,14 +42,12 @@ class Config(click.ParamType):
return
elif isinstance(default, int):
if not value.isdigit():
self.fail(
'"{}" should be an integer'.format(name), param, ctx
)
self.fail('"{}" should be an integer'.format(name), param, ctx)
return
return name, int(value)
else:
# Should never happen:
self.fail('Invalid option')
self.fail("Invalid option")
@click.group(cls=DefaultGroup, default="serve", default_if_no_args=True)
@@ -204,13 +200,9 @@ def plugins(all, plugins_dir):
multiple=True,
)
@click.option(
"--install",
help="Additional packages (e.g. plugins) to install",
multiple=True,
)
@click.option(
"--spatialite", is_flag=True, help="Enable SpatialLite extension"
"--install", help="Additional packages (e.g. plugins) to install", multiple=True
)
@click.option("--spatialite", is_flag=True, help="Enable SpatialLite extension")
@click.option("--version-note", help="Additional note to show on /-/versions")
@click.option("--title", help="Title for metadata")
@click.option("--license", help="License label for metadata")
@@ -322,9 +314,7 @@ def package(
help="mountpoint:path-to-directory for serving static files",
multiple=True,
)
@click.option(
"--memory", is_flag=True, help="Make :memory: database available"
)
@click.option("--memory", is_flag=True, help="Make :memory: database available")
@click.option(
"--config",
type=Config(),
@@ -332,11 +322,7 @@ def package(
multiple=True,
)
@click.option("--version-note", help="Additional note to show on /-/versions")
@click.option(
"--help-config",
is_flag=True,
help="Show available config options",
)
@click.option("--help-config", is_flag=True, help="Show available config options")
def serve(
files,
immutable,
@@ -360,12 +346,12 @@ def serve(
if help_config:
formatter = formatting.HelpFormatter()
with formatter.section("Config options"):
formatter.write_dl([
(option.name, '{} (default={})'.format(
option.help, option.default
))
for option in CONFIG_OPTIONS
])
formatter.write_dl(
[
(option.name, "{} (default={})".format(option.help, option.default))
for option in CONFIG_OPTIONS
]
)
click.echo(formatter.getvalue())
sys.exit(0)
if reload:
@@ -384,7 +370,9 @@ def serve(
if metadata:
metadata_data = json.loads(metadata.read())
click.echo("Serve! files={} (immutables={}) on port {}".format(files, immutable, port))
click.echo(
"Serve! files={} (immutables={}) on port {}".format(files, immutable, port)
)
ds = Datasette(
files,
immutables=immutable,

View file

@@ -31,14 +31,15 @@ def load_facet_configs(request, table_metadata):
metadata_config = {"simple": metadata_config}
else:
# This should have a single key and a single value
assert len(metadata_config.values()) == 1, "Metadata config dicts should be {type: config}"
assert (
len(metadata_config.values()) == 1
), "Metadata config dicts should be {type: config}"
type, metadata_config = metadata_config.items()[0]
if isinstance(metadata_config, str):
metadata_config = {"simple": metadata_config}
facet_configs.setdefault(type, []).append({
"source": "metadata",
"config": metadata_config
})
facet_configs.setdefault(type, []).append(
{"source": "metadata", "config": metadata_config}
)
qs_pairs = urllib.parse.parse_qs(request.query_string, keep_blank_values=True)
for key, values in qs_pairs.items():
if key.startswith("_facet"):
@@ -53,10 +54,9 @@ def load_facet_configs(request, table_metadata):
config = json.loads(value)
else:
config = {"simple": value}
facet_configs.setdefault(type, []).append({
"source": "request",
"config": config
})
facet_configs.setdefault(type, []).append(
{"source": "request", "config": config}
)
return facet_configs
@@ -214,7 +214,9 @@ class ColumnFacet(Facet):
"name": column,
"type": self.type,
"hideable": source != "metadata",
"toggle_url": path_with_removed_args(self.request, {"_facet": column}),
"toggle_url": path_with_removed_args(
self.request, {"_facet": column}
),
"results": facet_results_values,
"truncated": len(facet_rows_results) > facet_size,
}
@@ -269,30 +271,31 @@ class ArrayFacet(Facet):
select distinct json_type({column})
from ({sql})
""".format(
column=escape_sqlite(column),
sql=self.sql,
column=escape_sqlite(column), sql=self.sql
)
try:
results = await self.ds.execute(
self.database, suggested_facet_sql, self.params,
self.database,
suggested_facet_sql,
self.params,
truncate=False,
custom_time_limit=self.ds.config("facet_suggest_time_limit_ms"),
log_sql_errors=False,
)
types = tuple(r[0] for r in results.rows)
if types in (
("array",),
("array", None)
):
suggested_facets.append({
"name": column,
"type": "array",
"toggle_url": self.ds.absolute_url(
self.request, path_with_added_args(
self.request, {"_facet_array": column}
)
),
})
if types in (("array",), ("array", None)):
suggested_facets.append(
{
"name": column,
"type": "array",
"toggle_url": self.ds.absolute_url(
self.request,
path_with_added_args(
self.request, {"_facet_array": column}
),
),
}
)
except (InterruptedError, sqlite3.OperationalError):
continue
return suggested_facets
@@ -314,13 +317,13 @@ class ArrayFacet(Facet):
) join json_each({col}) j
group by j.value order by count desc limit {limit}
""".format(
col=escape_sqlite(column),
sql=self.sql,
limit=facet_size+1,
col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1
)
try:
facet_rows_results = await self.ds.execute(
self.database, facet_sql, self.params,
self.database,
facet_sql,
self.params,
truncate=False,
custom_time_limit=self.ds.config("facet_time_limit_ms"),
)
@ -330,7 +333,9 @@ class ArrayFacet(Facet):
"type": self.type,
"results": facet_results_values,
"hideable": source != "metadata",
"toggle_url": path_with_removed_args(self.request, {"_facet_array": column}),
"toggle_url": path_with_removed_args(
self.request, {"_facet_array": column}
),
"truncated": len(facet_rows_results) > facet_size,
}
facet_rows = facet_rows_results.rows[:facet_size]
@@ -346,13 +351,17 @@ class ArrayFacet(Facet):
toggle_path = path_with_added_args(
self.request, {"{}__arraycontains".format(column): value}
)
facet_results_values.append({
"value": value,
"label": value,
"count": row["count"],
"toggle_url": self.ds.absolute_url(self.request, toggle_path),
"selected": selected,
})
facet_results_values.append(
{
"value": value,
"label": value,
"count": row["count"],
"toggle_url": self.ds.absolute_url(
self.request, toggle_path
),
"selected": selected,
}
)
except InterruptedError:
facets_timed_out.append(column)

View file

@@ -1,10 +1,7 @@
import json
import numbers
from .utils import (
detect_json1,
escape_sqlite,
)
from .utils import detect_json1, escape_sqlite
class Filter:
@@ -20,7 +17,16 @@ class Filter:
class TemplatedFilter(Filter):
def __init__(self, key, display, sql_template, human_template, format='{}', numeric=False, no_argument=False):
def __init__(
self,
key,
display,
sql_template,
human_template,
format="{}",
numeric=False,
no_argument=False,
):
self.key = key
self.display = display
self.sql_template = sql_template
@@ -34,16 +40,10 @@ class TemplatedFilter(Filter):
if self.numeric and converted.isdigit():
converted = int(converted)
if self.no_argument:
kwargs = {
'c': column,
}
kwargs = {"c": column}
converted = None
else:
kwargs = {
'c': column,
'p': 'p{}'.format(param_counter),
't': table,
}
kwargs = {"c": column, "p": "p{}".format(param_counter), "t": table}
return self.sql_template.format(**kwargs), converted
def human_clause(self, column, value):
@@ -58,8 +58,8 @@ class TemplatedFilter(Filter):
class InFilter(Filter):
key = 'in'
display = 'in'
key = "in"
display = "in"
def __init__(self):
pass
@@ -81,34 +81,98 @@ class InFilter(Filter):
class Filters:
_filters = [
# key, display, sql_template, human_template, format=, numeric=, no_argument=
TemplatedFilter('exact', '=', '"{c}" = :{p}', lambda c, v: '{c} = {v}' if v.isdigit() else '{c} = "{v}"'),
TemplatedFilter('not', '!=', '"{c}" != :{p}', lambda c, v: '{c} != {v}' if v.isdigit() else '{c} != "{v}"'),
TemplatedFilter('contains', 'contains', '"{c}" like :{p}', '{c} contains "{v}"', format='%{}%'),
TemplatedFilter('endswith', 'ends with', '"{c}" like :{p}', '{c} ends with "{v}"', format='%{}'),
TemplatedFilter('startswith', 'starts with', '"{c}" like :{p}', '{c} starts with "{v}"', format='{}%'),
TemplatedFilter('gt', '>', '"{c}" > :{p}', '{c} > {v}', numeric=True),
TemplatedFilter('gte', '\u2265', '"{c}" >= :{p}', '{c} \u2265 {v}', numeric=True),
TemplatedFilter('lt', '<', '"{c}" < :{p}', '{c} < {v}', numeric=True),
TemplatedFilter('lte', '\u2264', '"{c}" <= :{p}', '{c} \u2264 {v}', numeric=True),
TemplatedFilter('like', 'like', '"{c}" like :{p}', '{c} like "{v}"'),
TemplatedFilter('glob', 'glob', '"{c}" glob :{p}', '{c} glob "{v}"'),
InFilter(),
] + ([TemplatedFilter('arraycontains', 'array contains', """rowid in (
_filters = (
[
# key, display, sql_template, human_template, format=, numeric=, no_argument=
TemplatedFilter(
"exact",
"=",
'"{c}" = :{p}',
lambda c, v: "{c} = {v}" if v.isdigit() else '{c} = "{v}"',
),
TemplatedFilter(
"not",
"!=",
'"{c}" != :{p}',
lambda c, v: "{c} != {v}" if v.isdigit() else '{c} != "{v}"',
),
TemplatedFilter(
"contains",
"contains",
'"{c}" like :{p}',
'{c} contains "{v}"',
format="%{}%",
),
TemplatedFilter(
"endswith",
"ends with",
'"{c}" like :{p}',
'{c} ends with "{v}"',
format="%{}",
),
TemplatedFilter(
"startswith",
"starts with",
'"{c}" like :{p}',
'{c} starts with "{v}"',
format="{}%",
),
TemplatedFilter("gt", ">", '"{c}" > :{p}', "{c} > {v}", numeric=True),
TemplatedFilter(
"gte", "\u2265", '"{c}" >= :{p}', "{c} \u2265 {v}", numeric=True
),
TemplatedFilter("lt", "<", '"{c}" < :{p}', "{c} < {v}", numeric=True),
TemplatedFilter(
"lte", "\u2264", '"{c}" <= :{p}', "{c} \u2264 {v}", numeric=True
),
TemplatedFilter("like", "like", '"{c}" like :{p}', '{c} like "{v}"'),
TemplatedFilter("glob", "glob", '"{c}" glob :{p}', '{c} glob "{v}"'),
InFilter(),
]
+ (
[
TemplatedFilter(
"arraycontains",
"array contains",
"""rowid in (
select {t}.rowid from {t}, json_each({t}.{c}) j
where j.value = :{p}
)""", '{c} contains "{v}"')
] if detect_json1() else []) + [
TemplatedFilter('date', 'date', 'date({c}) = :{p}', '"{c}" is on date {v}'),
TemplatedFilter('isnull', 'is null', '"{c}" is null', '{c} is null', no_argument=True),
TemplatedFilter('notnull', 'is not null', '"{c}" is not null', '{c} is not null', no_argument=True),
TemplatedFilter('isblank', 'is blank', '("{c}" is null or "{c}" = "")', '{c} is blank', no_argument=True),
TemplatedFilter('notblank', 'is not blank', '("{c}" is not null and "{c}" != "")', '{c} is not blank', no_argument=True),
]
_filters_by_key = {
f.key: f for f in _filters
}
)""",
'{c} contains "{v}"',
)
]
if detect_json1()
else []
)
+ [
TemplatedFilter("date", "date", "date({c}) = :{p}", '"{c}" is on date {v}'),
TemplatedFilter(
"isnull", "is null", '"{c}" is null', "{c} is null", no_argument=True
),
TemplatedFilter(
"notnull",
"is not null",
'"{c}" is not null',
"{c} is not null",
no_argument=True,
),
TemplatedFilter(
"isblank",
"is blank",
'("{c}" is null or "{c}" = "")',
"{c} is blank",
no_argument=True,
),
TemplatedFilter(
"notblank",
"is not blank",
'("{c}" is not null and "{c}" != "")',
"{c} is not blank",
no_argument=True,
),
]
)
_filters_by_key = {f.key: f for f in _filters}
def __init__(self, pairs, units={}, ureg=None):
self.pairs = pairs
@@ -132,22 +196,22 @@ class Filters:
and_bits = []
commas, tail = bits[:-1], bits[-1:]
if commas:
and_bits.append(', '.join(commas))
and_bits.append(", ".join(commas))
if tail:
and_bits.append(tail[0])
s = ' and '.join(and_bits)
s = " and ".join(and_bits)
if not s:
return ''
return 'where {}'.format(s)
return ""
return "where {}".format(s)
def selections(self):
"Yields (column, lookup, value) tuples"
for key, value in self.pairs:
if '__' in key:
column, lookup = key.rsplit('__', 1)
if "__" in key:
column, lookup = key.rsplit("__", 1)
else:
column = key
lookup = 'exact'
lookup = "exact"
yield column, lookup, value
def has_selections(self):
@@ -174,13 +238,15 @@ class Filters:
for column, lookup, value in self.selections():
filter = self._filters_by_key.get(lookup, None)
if filter:
sql_bit, param = filter.where_clause(table, column, self.convert_unit(column, value), i)
sql_bit, param = filter.where_clause(
table, column, self.convert_unit(column, value), i
)
sql_bits.append(sql_bit)
if param is not None:
if not isinstance(param, list):
param = [param]
for individual_param in param:
param_id = 'p{}'.format(i)
param_id = "p{}".format(i)
params[param_id] = individual_param
i += 1
return sql_bits, params

View file

@@ -7,7 +7,7 @@ from .utils import (
escape_sqlite,
get_all_foreign_keys,
table_columns,
sqlite3
sqlite3,
)
@@ -29,7 +29,9 @@ def inspect_hash(path):
def inspect_views(conn):
" List views in a database. "
return [v[0] for v in conn.execute('select name from sqlite_master where type = "view"')]
return [
v[0] for v in conn.execute('select name from sqlite_master where type = "view"')
]
def inspect_tables(conn, database_metadata):
@@ -37,15 +39,11 @@ def inspect_tables(conn, database_metadata):
tables = {}
table_names = [
r["name"]
for r in conn.execute(
'select * from sqlite_master where type="table"'
)
for r in conn.execute('select * from sqlite_master where type="table"')
]
for table in table_names:
table_metadata = database_metadata.get("tables", {}).get(
table, {}
)
table_metadata = database_metadata.get("tables", {}).get(table, {})
try:
count = conn.execute(

View file

@@ -41,8 +41,12 @@ def publish_subcommand(publish):
name,
spatialite,
):
fail_if_publish_binary_not_installed("gcloud", "Google Cloud", "https://cloud.google.com/sdk/")
project = check_output("gcloud config get-value project", shell=True, universal_newlines=True).strip()
fail_if_publish_binary_not_installed(
"gcloud", "Google Cloud", "https://cloud.google.com/sdk/"
)
project = check_output(
"gcloud config get-value project", shell=True, universal_newlines=True
).strip()
with temporary_docker_directory(
files,
@@ -68,4 +72,9 @@ def publish_subcommand(publish):
):
image_id = "gcr.io/{project}/{name}".format(project=project, name=name)
check_call("gcloud builds submit --tag {}".format(image_id), shell=True)
check_call("gcloud beta run deploy --allow-unauthenticated --image {}".format(image_id), shell=True)
check_call(
"gcloud beta run deploy --allow-unauthenticated --image {}".format(
image_id
),
shell=True,
)

View file

@@ -5,46 +5,54 @@ import sys
def add_common_publish_arguments_and_options(subcommand):
for decorator in reversed((
click.argument("files", type=click.Path(exists=True), nargs=-1),
click.option(
"-m",
"--metadata",
type=click.File(mode="r"),
help="Path to JSON file containing metadata to publish",
),
click.option("--extra-options", help="Extra options to pass to datasette serve"),
click.option("--branch", help="Install datasette from a GitHub branch e.g. master"),
click.option(
"--template-dir",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
help="Path to directory containing custom templates",
),
click.option(
"--plugins-dir",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
help="Path to directory containing custom plugins",
),
click.option(
"--static",
type=StaticMount(),
help="mountpoint:path-to-directory for serving static files",
multiple=True,
),
click.option(
"--install",
help="Additional packages (e.g. plugins) to install",
multiple=True,
),
click.option("--version-note", help="Additional note to show on /-/versions"),
click.option("--title", help="Title for metadata"),
click.option("--license", help="License label for metadata"),
click.option("--license_url", help="License URL for metadata"),
click.option("--source", help="Source label for metadata"),
click.option("--source_url", help="Source URL for metadata"),
click.option("--about", help="About label for metadata"),
click.option("--about_url", help="About URL for metadata"),
)):
for decorator in reversed(
(
click.argument("files", type=click.Path(exists=True), nargs=-1),
click.option(
"-m",
"--metadata",
type=click.File(mode="r"),
help="Path to JSON file containing metadata to publish",
),
click.option(
"--extra-options", help="Extra options to pass to datasette serve"
),
click.option(
"--branch", help="Install datasette from a GitHub branch e.g. master"
),
click.option(
"--template-dir",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
help="Path to directory containing custom templates",
),
click.option(
"--plugins-dir",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
help="Path to directory containing custom plugins",
),
click.option(
"--static",
type=StaticMount(),
help="mountpoint:path-to-directory for serving static files",
multiple=True,
),
click.option(
"--install",
help="Additional packages (e.g. plugins) to install",
multiple=True,
),
click.option(
"--version-note", help="Additional note to show on /-/versions"
),
click.option("--title", help="Title for metadata"),
click.option("--license", help="License label for metadata"),
click.option("--license_url", help="License URL for metadata"),
click.option("--source", help="Source label for metadata"),
click.option("--source_url", help="Source URL for metadata"),
click.option("--about", help="About label for metadata"),
click.option("--about_url", help="About URL for metadata"),
)
):
subcommand = decorator(subcommand)
return subcommand

View file

@@ -76,9 +76,7 @@ def publish_subcommand(publish):
"about_url": about_url,
},
):
now_json = {
"version": 1
}
now_json = {"version": 1}
if alias:
now_json["alias"] = alias
open("now.json", "w").write(json.dumps(now_json))

View file

@@ -89,8 +89,4 @@ def json_renderer(args, data, view_name):
else:
body = json.dumps(data, cls=CustomJSONEncoder)
content_type = "application/json"
return {
"body": body,
"status_code": status_code,
"content_type": content_type
}
return {"body": body, "status_code": status_code, "content_type": content_type}

View file

@@ -21,27 +21,29 @@ except ImportError:
import sqlite3
# From https://www.sqlite.org/lang_keywords.html
reserved_words = set((
'abort action add after all alter analyze and as asc attach autoincrement '
'before begin between by cascade case cast check collate column commit '
'conflict constraint create cross current_date current_time '
'current_timestamp database default deferrable deferred delete desc detach '
'distinct drop each else end escape except exclusive exists explain fail '
'for foreign from full glob group having if ignore immediate in index '
'indexed initially inner insert instead intersect into is isnull join key '
'left like limit match natural no not notnull null of offset on or order '
'outer plan pragma primary query raise recursive references regexp reindex '
'release rename replace restrict right rollback row savepoint select set '
'table temp temporary then to transaction trigger union unique update using '
'vacuum values view virtual when where with without'
).split())
reserved_words = set(
(
"abort action add after all alter analyze and as asc attach autoincrement "
"before begin between by cascade case cast check collate column commit "
"conflict constraint create cross current_date current_time "
"current_timestamp database default deferrable deferred delete desc detach "
"distinct drop each else end escape except exclusive exists explain fail "
"for foreign from full glob group having if ignore immediate in index "
"indexed initially inner insert instead intersect into is isnull join key "
"left like limit match natural no not notnull null of offset on or order "
"outer plan pragma primary query raise recursive references regexp reindex "
"release rename replace restrict right rollback row savepoint select set "
"table temp temporary then to transaction trigger union unique update using "
"vacuum values view virtual when where with without"
).split()
)
SPATIALITE_DOCKERFILE_EXTRAS = r'''
SPATIALITE_DOCKERFILE_EXTRAS = r"""
RUN apt-get update && \
apt-get install -y python3-dev gcc libsqlite3-mod-spatialite && \
rm -rf /var/lib/apt/lists/*
ENV SQLITE_EXTENSIONS /usr/lib/x86_64-linux-gnu/mod_spatialite.so
'''
"""
class InterruptedError(Exception):
@@ -67,27 +69,24 @@ class Results:
def urlsafe_components(token):
"Splits token on commas and URL decodes each component"
return [
urllib.parse.unquote_plus(b) for b in token.split(',')
]
return [urllib.parse.unquote_plus(b) for b in token.split(",")]
def path_from_row_pks(row, pks, use_rowid, quote=True):
""" Generate an optionally URL-quoted unique identifier
for a row from its primary keys."""
if use_rowid:
bits = [row['rowid']]
bits = [row["rowid"]]
else:
bits = [
row[pk]["value"] if isinstance(row[pk], dict) else row[pk]
for pk in pks
row[pk]["value"] if isinstance(row[pk], dict) else row[pk] for pk in pks
]
if quote:
bits = [urllib.parse.quote_plus(str(bit)) for bit in bits]
else:
bits = [str(bit) for bit in bits]
return ','.join(bits)
return ",".join(bits)
def compound_keys_after_sql(pks, start_index=0):
@@ -106,16 +105,17 @@ def compound_keys_after_sql(pks, start_index=0):
and_clauses = []
last = pks_left[-1]
rest = pks_left[:-1]
and_clauses = ['{} = :p{}'.format(
escape_sqlite(pk), (i + start_index)
) for i, pk in enumerate(rest)]
and_clauses.append('{} > :p{}'.format(
escape_sqlite(last), (len(rest) + start_index)
))
or_clauses.append('({})'.format(' and '.join(and_clauses)))
and_clauses = [
"{} = :p{}".format(escape_sqlite(pk), (i + start_index))
for i, pk in enumerate(rest)
]
and_clauses.append(
"{} > :p{}".format(escape_sqlite(last), (len(rest) + start_index))
)
or_clauses.append("({})".format(" and ".join(and_clauses)))
pks_left.pop()
or_clauses.reverse()
return '({})'.format('\n or\n'.join(or_clauses))
return "({})".format("\n or\n".join(or_clauses))
class CustomJSONEncoder(json.JSONEncoder):
@@ -127,11 +127,11 @@ class CustomJSONEncoder(json.JSONEncoder):
if isinstance(obj, bytes):
# Does it encode to utf8?
try:
return obj.decode('utf8')
return obj.decode("utf8")
except UnicodeDecodeError:
return {
'$base64': True,
'encoded': base64.b64encode(obj).decode('latin1'),
"$base64": True,
"encoded": base64.b64encode(obj).decode("latin1"),
}
return json.JSONEncoder.default(self, obj)
@@ -163,20 +163,18 @@ class InvalidSql(Exception):
allowed_sql_res = [
re.compile(r'^select\b'),
re.compile(r'^explain select\b'),
re.compile(r'^explain query plan select\b'),
re.compile(r'^with\b'),
]
disallawed_sql_res = [
(re.compile('pragma'), 'Statement may not contain PRAGMA'),
re.compile(r"^select\b"),
re.compile(r"^explain select\b"),
re.compile(r"^explain query plan select\b"),
re.compile(r"^with\b"),
]
disallawed_sql_res = [(re.compile("pragma"), "Statement may not contain PRAGMA")]
def validate_sql_select(sql):
sql = sql.strip().lower()
if not any(r.match(sql) for r in allowed_sql_res):
raise InvalidSql('Statement must be a SELECT')
raise InvalidSql("Statement must be a SELECT")
for r, msg in disallawed_sql_res:
if r.search(sql):
raise InvalidSql(msg)
@@ -184,9 +182,7 @@ def validate_sql_select(sql):
def append_querystring(url, querystring):
op = "&" if ("?" in url) else "?"
return "{}{}{}".format(
url, op, querystring
)
return "{}{}{}".format(url, op, querystring)
def path_with_added_args(request, args, path=None):
@@ -198,14 +194,10 @@ def path_with_added_args(request, args, path=None):
for key, value in urllib.parse.parse_qsl(request.query_string):
if key not in args_to_remove:
current.append((key, value))
current.extend([
(key, value)
for key, value in args
if value is not None
])
current.extend([(key, value) for key, value in args if value is not None])
query_string = urllib.parse.urlencode(current)
if query_string:
query_string = '?{}'.format(query_string)
query_string = "?{}".format(query_string)
return path + query_string
@@ -220,18 +212,21 @@ def path_with_removed_args(request, args, path=None):
# args can be a dict or a set
current = []
if isinstance(args, set):
def should_remove(key, value):
return key in args
elif isinstance(args, dict):
# Must match key AND value
def should_remove(key, value):
return args.get(key) == value
for key, value in urllib.parse.parse_qsl(query_string):
if not should_remove(key, value):
current.append((key, value))
query_string = urllib.parse.urlencode(current)
if query_string:
query_string = '?{}'.format(query_string)
query_string = "?{}".format(query_string)
return path + query_string
@@ -247,54 +242,66 @@ def path_with_replaced_args(request, args, path=None):
current.extend([p for p in args if p[1] is not None])
query_string = urllib.parse.urlencode(current)
if query_string:
query_string = '?{}'.format(query_string)
query_string = "?{}".format(query_string)
return path + query_string
_css_re = re.compile(r'''['"\n\\]''')
_boring_keyword_re = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
_css_re = re.compile(r"""['"\n\\]""")
_boring_keyword_re = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
def escape_css_string(s):
return _css_re.sub(lambda m: '\\{:X}'.format(ord(m.group())), s)
return _css_re.sub(lambda m: "\\{:X}".format(ord(m.group())), s)
def escape_sqlite(s):
if _boring_keyword_re.match(s) and (s.lower() not in reserved_words):
return s
else:
return '[{}]'.format(s)
return "[{}]".format(s)
def make_dockerfile(files, metadata_file, extra_options, branch, template_dir, plugins_dir, static, install, spatialite, version_note):
cmd = ['datasette', 'serve', '--host', '0.0.0.0']
def make_dockerfile(
files,
metadata_file,
extra_options,
branch,
template_dir,
plugins_dir,
static,
install,
spatialite,
version_note,
):
cmd = ["datasette", "serve", "--host", "0.0.0.0"]
cmd.append('", "'.join(files))
cmd.extend(['--cors', '--inspect-file', 'inspect-data.json'])
cmd.extend(["--cors", "--inspect-file", "inspect-data.json"])
if metadata_file:
cmd.extend(['--metadata', '{}'.format(metadata_file)])
cmd.extend(["--metadata", "{}".format(metadata_file)])
if template_dir:
cmd.extend(['--template-dir', 'templates/'])
cmd.extend(["--template-dir", "templates/"])
if plugins_dir:
cmd.extend(['--plugins-dir', 'plugins/'])
cmd.extend(["--plugins-dir", "plugins/"])
if version_note:
cmd.extend(['--version-note', '{}'.format(version_note)])
cmd.extend(["--version-note", "{}".format(version_note)])
if static:
for mount_point, _ in static:
cmd.extend(['--static', '{}:{}'.format(mount_point, mount_point)])
cmd.extend(["--static", "{}:{}".format(mount_point, mount_point)])
if extra_options:
for opt in extra_options.split():
cmd.append('{}'.format(opt))
cmd.append("{}".format(opt))
cmd = [shlex.quote(part) for part in cmd]
# port attribute is a (fixed) env variable and should not be quoted
cmd.extend(['--port', '$PORT'])
cmd = ' '.join(cmd)
cmd.extend(["--port", "$PORT"])
cmd = " ".join(cmd)
if branch:
install = ['https://github.com/simonw/datasette/archive/{}.zip'.format(
branch
)] + list(install)
install = [
"https://github.com/simonw/datasette/archive/{}.zip".format(branch)
] + list(install)
else:
install = ['datasette'] + list(install)
install = ["datasette"] + list(install)
return '''
return """
FROM python:3.6
COPY . /app
WORKDIR /app
@@ -303,11 +310,11 @@ RUN pip install -U {install_from}
RUN datasette inspect {files} --inspect-file inspect-data.json
ENV PORT 8001
EXPOSE 8001
CMD {cmd}'''.format(
files=' '.join(files),
CMD {cmd}""".format(
files=" ".join(files),
cmd=cmd,
install_from=' '.join(install),
spatialite_extras=SPATIALITE_DOCKERFILE_EXTRAS if spatialite else '',
install_from=" ".join(install),
spatialite_extras=SPATIALITE_DOCKERFILE_EXTRAS if spatialite else "",
).strip()
@@ -324,7 +331,7 @@ def temporary_docker_directory(
install,
spatialite,
version_note,
extra_metadata=None
extra_metadata=None,
):
extra_metadata = extra_metadata or {}
tmp = tempfile.TemporaryDirectory()
@@ -332,10 +339,7 @@ def temporary_docker_directory(
datasette_dir = os.path.join(tmp.name, name)
os.mkdir(datasette_dir)
saved_cwd = os.getcwd()
file_paths = [
os.path.join(saved_cwd, file_path)
for file_path in files
]
file_paths = [os.path.join(saved_cwd, file_path) for file_path in files]
file_names = [os.path.split(f)[-1] for f in files]
if metadata:
metadata_content = json.load(metadata)
@@ -347,7 +351,7 @@ def temporary_docker_directory(
try:
dockerfile = make_dockerfile(
file_names,
metadata_content and 'metadata.json',
metadata_content and "metadata.json",
extra_options,
branch,
template_dir,
@@ -359,24 +363,23 @@ def temporary_docker_directory(
)
os.chdir(datasette_dir)
if metadata_content:
open('metadata.json', 'w').write(json.dumps(metadata_content, indent=2))
open('Dockerfile', 'w').write(dockerfile)
open("metadata.json", "w").write(json.dumps(metadata_content, indent=2))
open("Dockerfile", "w").write(dockerfile)
for path, filename in zip(file_paths, file_names):
link_or_copy(path, os.path.join(datasette_dir, filename))
if template_dir:
link_or_copy_directory(
os.path.join(saved_cwd, template_dir),
os.path.join(datasette_dir, 'templates')
os.path.join(datasette_dir, "templates"),
)
if plugins_dir:
link_or_copy_directory(
os.path.join(saved_cwd, plugins_dir),
os.path.join(datasette_dir, 'plugins')
os.path.join(datasette_dir, "plugins"),
)
for mount_point, path in static:
link_or_copy_directory(
os.path.join(saved_cwd, path),
os.path.join(datasette_dir, mount_point)
os.path.join(saved_cwd, path), os.path.join(datasette_dir, mount_point)
)
yield datasette_dir
finally:
@@ -396,7 +399,7 @@ def temporary_heroku_directory(
static,
install,
version_note,
extra_metadata=None
extra_metadata=None,
):
# FIXME: lots of duplicated code from above
@@ -404,10 +407,7 @@ def temporary_heroku_directory(
tmp = tempfile.TemporaryDirectory()
saved_cwd = os.getcwd()
file_paths = [
os.path.join(saved_cwd, file_path)
for file_path in files
]
file_paths = [os.path.join(saved_cwd, file_path) for file_path in files]
file_names = [os.path.split(f)[-1] for f in files]
if metadata:
@@ -422,53 +422,54 @@ def temporary_heroku_directory(
os.chdir(tmp.name)
if metadata_content:
open('metadata.json', 'w').write(json.dumps(metadata_content, indent=2))
open("metadata.json", "w").write(json.dumps(metadata_content, indent=2))
open('runtime.txt', 'w').write('python-3.6.7')
open("runtime.txt", "w").write("python-3.6.7")
if branch:
install = ['https://github.com/simonw/datasette/archive/{branch}.zip'.format(
branch=branch
)] + list(install)
install = [
"https://github.com/simonw/datasette/archive/{branch}.zip".format(
branch=branch
)
] + list(install)
else:
install = ['datasette'] + list(install)
install = ["datasette"] + list(install)
open('requirements.txt', 'w').write('\n'.join(install))
os.mkdir('bin')
open('bin/post_compile', 'w').write('datasette inspect --inspect-file inspect-data.json')
open("requirements.txt", "w").write("\n".join(install))
os.mkdir("bin")
open("bin/post_compile", "w").write(
"datasette inspect --inspect-file inspect-data.json"
)
extras = []
if template_dir:
link_or_copy_directory(
os.path.join(saved_cwd, template_dir),
os.path.join(tmp.name, 'templates')
os.path.join(tmp.name, "templates"),
)
extras.extend(['--template-dir', 'templates/'])
extras.extend(["--template-dir", "templates/"])
if plugins_dir:
link_or_copy_directory(
os.path.join(saved_cwd, plugins_dir),
os.path.join(tmp.name, 'plugins')
os.path.join(saved_cwd, plugins_dir), os.path.join(tmp.name, "plugins")
)
extras.extend(['--plugins-dir', 'plugins/'])
extras.extend(["--plugins-dir", "plugins/"])
if version_note:
extras.extend(['--version-note', version_note])
extras.extend(["--version-note", version_note])
if metadata_content:
extras.extend(['--metadata', 'metadata.json'])
extras.extend(["--metadata", "metadata.json"])
if extra_options:
extras.extend(extra_options.split())
for mount_point, path in static:
link_or_copy_directory(
os.path.join(saved_cwd, path),
os.path.join(tmp.name, mount_point)
os.path.join(saved_cwd, path), os.path.join(tmp.name, mount_point)
)
extras.extend(['--static', '{}:{}'.format(mount_point, mount_point)])
extras.extend(["--static", "{}:{}".format(mount_point, mount_point)])
quoted_files = " ".join(map(shlex.quote, file_names))
procfile_cmd = 'web: datasette serve --host 0.0.0.0 {quoted_files} --cors --port $PORT --inspect-file inspect-data.json {extras}'.format(
quoted_files=quoted_files,
extras=' '.join(extras),
procfile_cmd = "web: datasette serve --host 0.0.0.0 {quoted_files} --cors --port $PORT --inspect-file inspect-data.json {extras}".format(
quoted_files=quoted_files, extras=" ".join(extras)
)
open('Procfile', 'w').write(procfile_cmd)
open("Procfile", "w").write(procfile_cmd)
for path, filename in zip(file_paths, file_names):
link_or_copy(path, os.path.join(tmp.name, filename))
@@ -484,9 +485,7 @@ def detect_primary_keys(conn, table):
" Figure out primary keys for a table. "
table_info_rows = [
row
for row in conn.execute(
'PRAGMA table_info("{}")'.format(table)
).fetchall()
for row in conn.execute('PRAGMA table_info("{}")'.format(table)).fetchall()
if row[-1]
]
table_info_rows.sort(key=lambda row: row[-1])
@@ -494,33 +493,26 @@ def detect_primary_keys(conn, table):
def get_outbound_foreign_keys(conn, table):
infos = conn.execute(
'PRAGMA foreign_key_list([{}])'.format(table)
).fetchall()
infos = conn.execute("PRAGMA foreign_key_list([{}])".format(table)).fetchall()
fks = []
for info in infos:
if info is not None:
id, seq, table_name, from_, to_, on_update, on_delete, match = info
fks.append({
'other_table': table_name,
'column': from_,
'other_column': to_
})
fks.append(
{"other_table": table_name, "column": from_, "other_column": to_}
)
return fks
def get_all_foreign_keys(conn):
tables = [r[0] for r in conn.execute('select name from sqlite_master where type="table"')]
tables = [
r[0] for r in conn.execute('select name from sqlite_master where type="table"')
]
table_to_foreign_keys = {}
for table in tables:
table_to_foreign_keys[table] = {
'incoming': [],
'outgoing': [],
}
table_to_foreign_keys[table] = {"incoming": [], "outgoing": []}
for table in tables:
infos = conn.execute(
'PRAGMA foreign_key_list([{}])'.format(table)
).fetchall()
infos = conn.execute("PRAGMA foreign_key_list([{}])".format(table)).fetchall()
for info in infos:
if info is not None:
id, seq, table_name, from_, to_, on_update, on_delete, match = info
@@ -528,22 +520,20 @@ def get_all_foreign_keys(conn):
# Weird edge case where something refers to a table that does
# not actually exist
continue
table_to_foreign_keys[table_name]['incoming'].append({
'other_table': table,
'column': to_,
'other_column': from_
})
table_to_foreign_keys[table]['outgoing'].append({
'other_table': table_name,
'column': from_,
'other_column': to_
})
table_to_foreign_keys[table_name]["incoming"].append(
{"other_table": table, "column": to_, "other_column": from_}
)
table_to_foreign_keys[table]["outgoing"].append(
{"other_table": table_name, "column": from_, "other_column": to_}
)
return table_to_foreign_keys
def detect_spatialite(conn):
rows = conn.execute('select 1 from sqlite_master where tbl_name = "geometry_columns"').fetchall()
rows = conn.execute(
'select 1 from sqlite_master where tbl_name = "geometry_columns"'
).fetchall()
return len(rows) > 0
@@ -557,7 +547,7 @@ def detect_fts(conn, table):
def detect_fts_sql(table):
return r'''
return r"""
select name from sqlite_master
where rootpage = 0
and (
@@ -567,7 +557,9 @@ def detect_fts_sql(table):
and sql like '%VIRTUAL TABLE%USING FTS%'
)
)
'''.format(table=table)
""".format(
table=table
)
def detect_json1(conn=None):
@@ -589,51 +581,53 @@ def table_columns(conn, table):
]
filter_column_re = re.compile(r'^_filter_column_\d+$')
filter_column_re = re.compile(r"^_filter_column_\d+$")
def filters_should_redirect(special_args):
redirect_params = []
# Handle _filter_column=foo&_filter_op=exact&_filter_value=...
filter_column = special_args.get('_filter_column')
filter_op = special_args.get('_filter_op') or ''
filter_value = special_args.get('_filter_value') or ''
if '__' in filter_op:
filter_op, filter_value = filter_op.split('__', 1)
filter_column = special_args.get("_filter_column")
filter_op = special_args.get("_filter_op") or ""
filter_value = special_args.get("_filter_value") or ""
if "__" in filter_op:
filter_op, filter_value = filter_op.split("__", 1)
if filter_column:
redirect_params.append(
('{}__{}'.format(filter_column, filter_op), filter_value)
("{}__{}".format(filter_column, filter_op), filter_value)
)
for key in ('_filter_column', '_filter_op', '_filter_value'):
for key in ("_filter_column", "_filter_op", "_filter_value"):
if key in special_args:
redirect_params.append((key, None))
# Now handle _filter_column_1=name&_filter_op_1=contains&_filter_value_1=hello
column_keys = [k for k in special_args if filter_column_re.match(k)]
for column_key in column_keys:
number = column_key.split('_')[-1]
number = column_key.split("_")[-1]
column = special_args[column_key]
op = special_args.get('_filter_op_{}'.format(number)) or 'exact'
value = special_args.get('_filter_value_{}'.format(number)) or ''
if '__' in op:
op, value = op.split('__', 1)
op = special_args.get("_filter_op_{}".format(number)) or "exact"
value = special_args.get("_filter_value_{}".format(number)) or ""
if "__" in op:
op, value = op.split("__", 1)
if column:
redirect_params.append(('{}__{}'.format(column, op), value))
redirect_params.extend([
('_filter_column_{}'.format(number), None),
('_filter_op_{}'.format(number), None),
('_filter_value_{}'.format(number), None),
])
redirect_params.append(("{}__{}".format(column, op), value))
redirect_params.extend(
[
("_filter_column_{}".format(number), None),
("_filter_op_{}".format(number), None),
("_filter_value_{}".format(number), None),
]
)
return redirect_params
whitespace_re = re.compile(r'\s')
whitespace_re = re.compile(r"\s")
def is_url(value):
"Must start with http:// or https:// and contain JUST a URL"
if not isinstance(value, str):
return False
if not value.startswith('http://') and not value.startswith('https://'):
if not value.startswith("http://") and not value.startswith("https://"):
return False
# Any whitespace at all is invalid
if whitespace_re.search(value):
@@ -641,8 +635,8 @@ def is_url(value):
return True
css_class_re = re.compile(r'^[a-zA-Z]+[_a-zA-Z0-9-]*$')
css_invalid_chars_re = re.compile(r'[^a-zA-Z0-9_\-]')
css_class_re = re.compile(r"^[a-zA-Z]+[_a-zA-Z0-9-]*$")
css_invalid_chars_re = re.compile(r"[^a-zA-Z0-9_\-]")
def to_css_class(s):
@@ -656,16 +650,16 @@ def to_css_class(s):
"""
if css_class_re.match(s):
return s
md5_suffix = hashlib.md5(s.encode('utf8')).hexdigest()[:6]
md5_suffix = hashlib.md5(s.encode("utf8")).hexdigest()[:6]
# Strip leading _, -
s = s.lstrip('_').lstrip('-')
s = s.lstrip("_").lstrip("-")
# Replace any whitespace with hyphens
s = '-'.join(s.split())
s = "-".join(s.split())
# Remove any remaining invalid characters
s = css_invalid_chars_re.sub('', s)
s = css_invalid_chars_re.sub("", s)
# Attach the md5 suffix
bits = [b for b in (s, md5_suffix) if b]
return '-'.join(bits)
return "-".join(bits)
def link_or_copy(src, dst):
@@ -689,8 +683,8 @@ def module_from_path(path, name):
# Adapted from http://sayspy.blogspot.com/2011/07/how-to-import-module-from-just-file.html
mod = imp.new_module(name)
mod.__file__ = path
with open(path, 'r') as file:
code = compile(file.read(), path, 'exec', dont_inherit=True)
with open(path, "r") as file:
code = compile(file.read(), path, "exec", dont_inherit=True)
exec(code, mod.__dict__)
return mod
@@ -702,37 +696,39 @@ def get_plugins(pm):
static_path = None
templates_path = None
try:
if pkg_resources.resource_isdir(plugin.__name__, 'static'):
static_path = pkg_resources.resource_filename(plugin.__name__, 'static')
if pkg_resources.resource_isdir(plugin.__name__, 'templates'):
templates_path = pkg_resources.resource_filename(plugin.__name__, 'templates')
if pkg_resources.resource_isdir(plugin.__name__, "static"):
static_path = pkg_resources.resource_filename(plugin.__name__, "static")
if pkg_resources.resource_isdir(plugin.__name__, "templates"):
templates_path = pkg_resources.resource_filename(
plugin.__name__, "templates"
)
except (KeyError, ImportError):
# Caused by --plugins_dir= plugins - KeyError/ImportError thrown in Py3.5
pass
plugin_info = {
'name': plugin.__name__,
'static_path': static_path,
'templates_path': templates_path,
"name": plugin.__name__,
"static_path": static_path,
"templates_path": templates_path,
}
distinfo = plugin_to_distinfo.get(plugin)
if distinfo:
plugin_info['version'] = distinfo.version
plugin_info["version"] = distinfo.version
plugins.append(plugin_info)
return plugins
async def resolve_table_and_format(table_and_format, table_exists, allowed_formats=[]):
if '.' in table_and_format:
if "." in table_and_format:
# Check if a table exists with this exact name
it_exists = await table_exists(table_and_format)
if it_exists:
return table_and_format, None
# Check if table ends with a known format
formats = list(allowed_formats) + ['csv', 'jsono']
formats = list(allowed_formats) + ["csv", "jsono"]
for _format in formats:
if table_and_format.endswith(".{}".format(_format)):
table = table_and_format[:-(len(_format) + 1)]
table = table_and_format[: -(len(_format) + 1)]
return table, _format
return table_and_format, None
@@ -747,9 +743,7 @@ def path_with_format(request, format, extra_qs=None):
if qs:
extra = urllib.parse.urlencode(sorted(qs.items()))
if request.query_string:
path = "{}?{}&{}".format(
path, request.query_string, extra
)
path = "{}?{}&{}".format(path, request.query_string, extra)
else:
path = "{}?{}".format(path, extra)
elif request.query_string:
@@ -777,9 +771,9 @@ class CustomRow(OrderedDict):
def value_as_boolean(value):
if value.lower() not in ('on', 'off', 'true', 'false', '1', '0'):
if value.lower() not in ("on", "off", "true", "false", "1", "0"):
raise ValueAsBooleanError
return value.lower() in ('on', 'true', '1')
return value.lower() in ("on", "true", "1")
class ValueAsBooleanError(ValueError):
@@ -799,9 +793,9 @@ class LimitedWriter:
def write(self, bytes):
self.bytes_count += len(bytes)
if self.limit_bytes and (self.bytes_count > self.limit_bytes):
raise WriteLimitExceeded("CSV contains more than {} bytes".format(
self.limit_bytes
))
raise WriteLimitExceeded(
"CSV contains more than {} bytes".format(self.limit_bytes)
)
self.writer.write(bytes)
@@ -810,10 +804,7 @@ _infinities = {float("inf"), float("-inf")}
def remove_infinites(row):
if any((c in _infinities) if isinstance(c, float) else 0 for c in row):
return [
None if (isinstance(c, float) and c in _infinities) else c
for c in row
]
return [None if (isinstance(c, float) and c in _infinities) else c for c in row]
return row
@@ -824,7 +815,8 @@ class StaticMount(click.ParamType):
if ":" not in value:
self.fail(
'"{}" should be of format mountpoint:directory'.format(value),
param, ctx
param,
ctx,
)
path, dirpath = value.split(":")
if not os.path.exists(dirpath) or not os.path.isdir(dirpath):

View file

@@ -1,6 +1,6 @@
from ._version import get_versions
__version__ = get_versions()['version']
__version__ = get_versions()["version"]
del get_versions
__version_info__ = tuple(__version__.split("."))

View file

@@ -33,8 +33,15 @@ HASH_LENGTH = 7
class DatasetteError(Exception):
def __init__(self, message, title=None, error_dict=None, status=500, template=None, messagge_is_html=False):
def __init__(
self,
message,
title=None,
error_dict=None,
status=500,
template=None,
messagge_is_html=False,
):
self.message = message
self.title = title
self.error_dict = error_dict or {}
@@ -43,18 +50,19 @@ class DatasetteError(Exception):
class RenderMixin(HTTPMethodView):
def _asset_urls(self, key, template, context):
# Flatten list-of-lists from plugins:
seen_urls = set()
for url_or_dict in itertools.chain(
itertools.chain.from_iterable(getattr(pm.hook, key)(
template=template.name,
database=context.get("database"),
table=context.get("table"),
datasette=self.ds
)),
(self.ds.metadata(key) or [])
itertools.chain.from_iterable(
getattr(pm.hook, key)(
template=template.name,
database=context.get("database"),
table=context.get("table"),
datasette=self.ds,
)
),
(self.ds.metadata(key) or []),
):
if isinstance(url_or_dict, dict):
url = url_or_dict["url"]
@@ -73,14 +81,12 @@ class RenderMixin(HTTPMethodView):
def database_url(self, database):
db = self.ds.databases[database]
if self.ds.config("hash_urls") and db.hash:
return "/{}-{}".format(
database, db.hash[:HASH_LENGTH]
)
return "/{}-{}".format(database, db.hash[:HASH_LENGTH])
else:
return "/{}".format(database)
def database_color(self, database):
return 'ff0000'
return "ff0000"
def render(self, templates, **context):
template = self.ds.jinja_env.select_template(templates)
@@ -95,7 +101,7 @@ class RenderMixin(HTTPMethodView):
database=context.get("database"),
table=context.get("table"),
view_name=self.name,
datasette=self.ds
datasette=self.ds,
):
body_scripts.append(jinja2.Markup(script))
return response.html(
@@ -116,14 +122,14 @@ class RenderMixin(HTTPMethodView):
"format_bytes": format_bytes,
"database_url": self.database_url,
"database_color": self.database_color,
}
},
}
)
)
class BaseView(RenderMixin):
name = ''
name = ""
re_named_parameter = re.compile(":([a-zA-Z0-9_]+)")
def __init__(self, datasette):
@@ -171,32 +177,30 @@ class BaseView(RenderMixin):
expected = "000"
if db.hash is not None:
expected = db.hash[:HASH_LENGTH]
correct_hash_provided = (expected == hash)
correct_hash_provided = expected == hash
if not correct_hash_provided:
if "table_and_format" in kwargs:
async def async_table_exists(t):
return await self.ds.table_exists(name, t)
table, _format = await resolve_table_and_format(
table_and_format=urllib.parse.unquote_plus(
kwargs["table_and_format"]
),
table_exists=async_table_exists,
allowed_formats=self.ds.renderers.keys()
allowed_formats=self.ds.renderers.keys(),
)
kwargs["table"] = table
if _format:
kwargs["as_format"] = ".{}".format(_format)
elif "table" in kwargs:
kwargs["table"] = urllib.parse.unquote_plus(
kwargs["table"]
)
kwargs["table"] = urllib.parse.unquote_plus(kwargs["table"])
should_redirect = "/{}-{}".format(name, expected)
if "table" in kwargs:
should_redirect += "/" + urllib.parse.quote_plus(
kwargs["table"]
)
should_redirect += "/" + urllib.parse.quote_plus(kwargs["table"])
if "pk_path" in kwargs:
should_redirect += "/" + kwargs["pk_path"]
if "as_format" in kwargs:
@@ -219,7 +223,9 @@ class BaseView(RenderMixin):
if should_redirect:
return self.redirect(request, should_redirect, remove_args={"_hash"})
return await self.view_get(request, database, hash, correct_hash_provided, **kwargs)
return await self.view_get(
request, database, hash, correct_hash_provided, **kwargs
)
async def as_csv(self, request, database, hash, **kwargs):
stream = request.args.get("_stream")
@@ -228,9 +234,7 @@ class BaseView(RenderMixin):
if not self.ds.config("allow_csv_stream"):
raise DatasetteError("CSV streaming is disabled", status=400)
if request.args.get("_next"):
raise DatasetteError(
"_next not allowed for CSV streaming", status=400
)
raise DatasetteError("_next not allowed for CSV streaming", status=400)
kwargs["_size"] = "max"
# Fetch the first page
try:
@@ -271,9 +275,7 @@ class BaseView(RenderMixin):
if next:
kwargs["_next"] = next
if not first:
data, _, _ = await self.data(
request, database, hash, **kwargs
)
data, _, _ = await self.data(request, database, hash, **kwargs)
if first:
writer.writerow(headings)
first = False
@@ -293,7 +295,7 @@ class BaseView(RenderMixin):
new_row.append(cell)
writer.writerow(new_row)
except Exception as e:
print('caught this', e)
print("caught this", e)
r.write(str(e))
return
@@ -304,15 +306,11 @@ class BaseView(RenderMixin):
if request.args.get("_dl", None):
content_type = "text/csv; charset=utf-8"
disposition = 'attachment; filename="{}.csv"'.format(
kwargs.get('table', database)
kwargs.get("table", database)
)
headers["Content-Disposition"] = disposition
return response.stream(
stream_fn,
headers=headers,
content_type=content_type
)
return response.stream(stream_fn, headers=headers, content_type=content_type)
async def get_format(self, request, database, args):
""" Determine the format of the response from the request, from URL
@@ -325,22 +323,20 @@ class BaseView(RenderMixin):
if not _format:
_format = (args.pop("as_format", None) or "").lstrip(".")
if "table_and_format" in args:
async def async_table_exists(t):
return await self.ds.table_exists(database, t)
table, _ext_format = await resolve_table_and_format(
table_and_format=urllib.parse.unquote_plus(
args["table_and_format"]
),
table_and_format=urllib.parse.unquote_plus(args["table_and_format"]),
table_exists=async_table_exists,
allowed_formats=self.ds.renderers.keys()
allowed_formats=self.ds.renderers.keys(),
)
_format = _format or _ext_format
args["table"] = table
del args["table_and_format"]
elif "table" in args:
args["table"] = urllib.parse.unquote_plus(
args["table"]
)
args["table"] = urllib.parse.unquote_plus(args["table"])
return _format, args
async def view_get(self, request, database, hash, correct_hash_provided, **kwargs):
@@ -351,7 +347,7 @@ class BaseView(RenderMixin):
if _format is None:
# HTML views default to expanding all foreign key labels
kwargs['default_labels'] = True
kwargs["default_labels"] = True
extra_template_data = {}
start = time.time()
@@ -367,11 +363,16 @@ class BaseView(RenderMixin):
else:
data, extra_template_data, templates = response_or_template_contexts
except InterruptedError:
raise DatasetteError("""
raise DatasetteError(
"""
SQL query took too long. The time limit is controlled by the
<a href="https://datasette.readthedocs.io/en/stable/config.html#sql-time-limit-ms">sql_time_limit_ms</a>
configuration option.
""", title="SQL Interrupted", status=400, messagge_is_html=True)
""",
title="SQL Interrupted",
status=400,
messagge_is_html=True,
)
except (sqlite3.OperationalError, InvalidSql) as e:
raise DatasetteError(str(e), title="Invalid SQL", status=400)
@@ -408,14 +409,14 @@ class BaseView(RenderMixin):
raise NotFound("No data")
response_args = {
'content_type': result.get('content_type', 'text/plain'),
'status': result.get('status_code', 200)
"content_type": result.get("content_type", "text/plain"),
"status": result.get("status_code", 200),
}
if type(result.get('body')) == bytes:
response_args['body_bytes'] = result.get('body')
if type(result.get("body")) == bytes:
response_args["body_bytes"] = result.get("body")
else:
response_args['body'] = result.get('body')
response_args["body"] = result.get("body")
r = response.HTTPResponse(**response_args)
else:
@@ -431,14 +432,12 @@ class BaseView(RenderMixin):
url_labels_extra = {"_labels": "on"}
renderers = {
key: path_with_format(request, key, {**url_labels_extra}) for key in self.ds.renderers.keys()
}
url_csv_args = {
"_size": "max",
**url_labels_extra
key: path_with_format(request, key, {**url_labels_extra})
for key in self.ds.renderers.keys()
}
url_csv_args = {"_size": "max", **url_labels_extra}
url_csv = path_with_format(request, "csv", url_csv_args)
url_csv_path = url_csv.split('?')[0]
url_csv_path = url_csv.split("?")[0]
context = {
**data,
**extras,
@ -450,10 +449,11 @@ class BaseView(RenderMixin):
(key, value)
for key, value in urllib.parse.parse_qsl(request.query_string)
if key not in ("_labels", "_facet", "_size")
] + [("_size", "max")],
]
+ [("_size", "max")],
"datasette_version": __version__,
"config": self.ds.config_dict(),
}
},
}
if "metadata" not in context:
context["metadata"] = self.ds.metadata
@@ -474,9 +474,9 @@ class BaseView(RenderMixin):
if self.ds.cache_headers and response.status == 200:
ttl = int(ttl)
if ttl == 0:
ttl_header = 'no-cache'
ttl_header = "no-cache"
else:
ttl_header = 'max-age={}'.format(ttl)
ttl_header = "max-age={}".format(ttl)
response.headers["Cache-Control"] = ttl_header
response.headers["Referrer-Policy"] = "no-referrer"
if self.ds.cors:
@@ -484,8 +484,15 @@ class BaseView(RenderMixin):
return response
async def custom_sql(
self, request, database, hash, sql, editable=True, canned_query=None,
metadata=None, _size=None
self,
request,
database,
hash,
sql,
editable=True,
canned_query=None,
metadata=None,
_size=None,
):
params = request.raw_args
if "sql" in params:
@@ -565,10 +572,14 @@ class BaseView(RenderMixin):
"hide_sql": "_hide_sql" in params,
}
return {
"database": database,
"rows": results.rows,
"truncated": results.truncated,
"columns": columns,
"query": {"sql": sql, "params": params},
}, extra_template, templates
return (
{
"database": database,
"rows": results.rows,
"truncated": results.truncated,
"columns": columns,
"query": {"sql": sql, "params": params},
},
extra_template,
templates,
)

View file

@@ -15,7 +15,7 @@ from .base import HASH_LENGTH, RenderMixin
class IndexView(RenderMixin):
name = 'index'
name = "index"
def __init__(self, datasette):
self.ds = datasette
@@ -43,23 +43,25 @@ class IndexView(RenderMixin):
}
hidden_tables = [t for t in tables.values() if t["hidden"]]
databases.append({
"name": name,
"hash": db.hash,
"color": db.hash[:6] if db.hash else hashlib.md5(name.encode("utf8")).hexdigest()[:6],
"path": self.database_url(name),
"tables_truncated": sorted(
tables.values(), key=lambda t: t["count"] or 0, reverse=True
)[
:5
],
"tables_count": len(tables),
"tables_more": len(tables) > 5,
"table_rows_sum": sum((t["count"] or 0) for t in tables.values()),
"hidden_table_rows_sum": sum(t["count"] for t in hidden_tables),
"hidden_tables_count": len(hidden_tables),
"views_count": len(views),
})
databases.append(
{
"name": name,
"hash": db.hash,
"color": db.hash[:6]
if db.hash
else hashlib.md5(name.encode("utf8")).hexdigest()[:6],
"path": self.database_url(name),
"tables_truncated": sorted(
tables.values(), key=lambda t: t["count"] or 0, reverse=True
)[:5],
"tables_count": len(tables),
"tables_more": len(tables) > 5,
"table_rows_sum": sum((t["count"] or 0) for t in tables.values()),
"hidden_table_rows_sum": sum(t["count"] for t in hidden_tables),
"hidden_tables_count": len(hidden_tables),
"views_count": len(views),
}
)
if as_format:
headers = {}
if self.ds.cors:

View file

@@ -18,14 +18,8 @@ class JsonDataView(RenderMixin):
if self.ds.cors:
headers["Access-Control-Allow-Origin"] = "*"
return response.HTTPResponse(
json.dumps(data),
content_type="application/json",
headers=headers
json.dumps(data), content_type="application/json", headers=headers
)
else:
return self.render(
["show_json.html"],
filename=self.filename,
data=data
)
return self.render(["show_json.html"], filename=self.filename, data=data)

View file

@@ -31,12 +31,13 @@ from datasette.utils import (
from datasette.filters import Filters
from .base import BaseView, DatasetteError, ureg
LINK_WITH_LABEL = '<a href="/{database}/{table}/{link_id}">{label}</a>&nbsp;<em>{id}</em>'
LINK_WITH_LABEL = (
'<a href="/{database}/{table}/{link_id}">{label}</a>&nbsp;<em>{id}</em>'
)
LINK_WITH_VALUE = '<a href="/{database}/{table}/{link_id}">{id}</a>'
class RowTableShared(BaseView):
async def sortable_columns_for_table(self, database, table, use_rowid):
table_metadata = self.ds.table_metadata(database, table)
if "sortable_columns" in table_metadata:
@@ -51,18 +52,14 @@ class RowTableShared(BaseView):
# Returns list of (fk_dict, label_column-or-None) pairs for that table
expandables = []
for fk in await self.ds.foreign_keys_for_table(database, table):
label_column = await self.ds.label_column_for_table(database, fk["other_table"])
label_column = await self.ds.label_column_for_table(
database, fk["other_table"]
)
expandables.append((fk, label_column))
return expandables
async def display_columns_and_rows(
self,
database,
table,
description,
rows,
link_column=False,
truncate_cells=0,
self, database, table, description, rows, link_column=False, truncate_cells=0
):
"Returns columns, rows for specified table - including fancy foreign key treatment"
table_metadata = self.ds.table_metadata(database, table)
@@ -121,8 +118,10 @@ class RowTableShared(BaseView):
if plugin_display_value is not None:
display_value = plugin_display_value
elif isinstance(value, bytes):
display_value = jinja2.Markup("&lt;Binary&nbsp;data:&nbsp;{}&nbsp;byte{}&gt;".format(
len(value), "" if len(value) == 1 else "s")
display_value = jinja2.Markup(
"&lt;Binary&nbsp;data:&nbsp;{}&nbsp;byte{}&gt;".format(
len(value), "" if len(value) == 1 else "s"
)
)
elif isinstance(value, dict):
# It's an expanded foreign key - display link to other row
@@ -133,13 +132,15 @@ class RowTableShared(BaseView):
link_template = (
LINK_WITH_LABEL if (label != value) else LINK_WITH_VALUE
)
display_value = jinja2.Markup(link_template.format(
database=database,
table=urllib.parse.quote_plus(other_table),
link_id=urllib.parse.quote_plus(str(value)),
id=str(jinja2.escape(value)),
label=str(jinja2.escape(label)),
))
display_value = jinja2.Markup(
link_template.format(
database=database,
table=urllib.parse.quote_plus(other_table),
link_id=urllib.parse.quote_plus(str(value)),
id=str(jinja2.escape(value)),
label=str(jinja2.escape(label)),
)
)
elif value in ("", None):
display_value = jinja2.Markup("&nbsp;")
elif is_url(str(value).strip()):
@@ -180,9 +181,18 @@ class RowTableShared(BaseView):
class TableView(RowTableShared):
name = 'table'
name = "table"
async def data(self, request, database, hash, table, default_labels=False, _next=None, _size=None):
async def data(
self,
request,
database,
hash,
table,
default_labels=False,
_next=None,
_size=None,
):
canned_query = self.ds.get_canned_query(database, table)
if canned_query is not None:
return await self.custom_sql(
@@ -271,12 +281,13 @@ class TableView(RowTableShared):
raise DatasetteError("_where= is not allowed", status=400)
else:
where_clauses.extend(request.args["_where"])
extra_wheres_for_ui = [{
"text": text,
"remove_url": path_with_removed_args(
request, {"_where": text}
)
} for text in request.args["_where"]]
extra_wheres_for_ui = [
{
"text": text,
"remove_url": path_with_removed_args(request, {"_where": text}),
}
for text in request.args["_where"]
]
# _search support:
fts_table = special_args.get("_fts_table")
@@ -296,8 +307,7 @@ class TableView(RowTableShared):
search = search_args["_search"]
where_clauses.append(
"{fts_pk} in (select rowid from {fts_table} where {fts_table} match :search)".format(
fts_table=escape_sqlite(fts_table),
fts_pk=escape_sqlite(fts_pk)
fts_table=escape_sqlite(fts_table), fts_pk=escape_sqlite(fts_pk)
)
)
search_descriptions.append('search matches "{}"'.format(search))
@@ -306,14 +316,16 @@ class TableView(RowTableShared):
# More complex: search against specific columns
for i, (key, search_text) in enumerate(search_args.items()):
search_col = key.split("_search_", 1)[1]
if search_col not in await self.ds.table_columns(database, fts_table):
if search_col not in await self.ds.table_columns(
database, fts_table
):
raise DatasetteError("Cannot search by that column", status=400)
where_clauses.append(
"rowid in (select rowid from {fts_table} where {search_col} match :search_{i})".format(
fts_table=escape_sqlite(fts_table),
search_col=escape_sqlite(search_col),
i=i
i=i,
)
)
search_descriptions.append(
@@ -325,7 +337,9 @@ class TableView(RowTableShared):
sortable_columns = set()
sortable_columns = await self.sortable_columns_for_table(database, table, use_rowid)
sortable_columns = await self.sortable_columns_for_table(
database, table, use_rowid
)
# Allow for custom sort order
sort = special_args.get("_sort")
@@ -346,9 +360,9 @@ class TableView(RowTableShared):
from_sql = "from {table_name} {where}".format(
table_name=escape_sqlite(table),
where=(
"where {} ".format(" and ".join(where_clauses))
) if where_clauses else "",
where=("where {} ".format(" and ".join(where_clauses)))
if where_clauses
else "",
)
# Copy of params so we can mutate them later:
from_sql_params = dict(**params)
@@ -410,7 +424,9 @@ class TableView(RowTableShared):
column=escape_sqlite(sort or sort_desc),
op=">" if sort else "<",
p=len(params),
extra_desc_only="" if sort else " or {column2} is null".format(
extra_desc_only=""
if sort
else " or {column2} is null".format(
column2=escape_sqlite(sort or sort_desc)
),
next_clauses=" and ".join(next_by_pk_clauses),
@@ -470,9 +486,7 @@ class TableView(RowTableShared):
order_by=order_by,
)
sql = "{sql_no_limit} limit {limit}{offset}".format(
sql_no_limit=sql_no_limit.rstrip(),
limit=page_size + 1,
offset=offset,
sql_no_limit=sql_no_limit.rstrip(), limit=page_size + 1, offset=offset
)
if request.raw_args.get("_timelimit"):
@@ -486,15 +500,17 @@ class TableView(RowTableShared):
filtered_table_rows_count = None
if count_sql:
try:
count_rows = list(await self.ds.execute(
database, count_sql, from_sql_params
))
count_rows = list(
await self.ds.execute(database, count_sql, from_sql_params)
)
filtered_table_rows_count = count_rows[0][0]
except InterruptedError:
pass
# facets support
if not self.ds.config("allow_facet") and any(arg.startswith("_facet") for arg in request.args):
if not self.ds.config("allow_facet") and any(
arg.startswith("_facet") for arg in request.args
):
raise DatasetteError("_facet= is not allowed", status=400)
# pylint: disable=no-member
@@ -505,19 +521,23 @@ class TableView(RowTableShared):
facets_timed_out = []
facet_instances = []
for klass in facet_classes:
facet_instances.append(klass(
self.ds,
request,
database,
sql=sql_no_limit,
params=params,
table=table,
metadata=table_metadata,
row_count=filtered_table_rows_count,
))
facet_instances.append(
klass(
self.ds,
request,
database,
sql=sql_no_limit,
params=params,
table=table,
metadata=table_metadata,
row_count=filtered_table_rows_count,
)
)
for facet in facet_instances:
instance_facet_results, instance_facets_timed_out = await facet.facet_results()
instance_facet_results, instance_facets_timed_out = (
await facet.facet_results()
)
facet_results.update(instance_facet_results)
facets_timed_out.extend(instance_facets_timed_out)
@@ -542,9 +562,7 @@ class TableView(RowTableShared):
columns_to_expand = request.args["_label"]
if columns_to_expand is None and all_labels:
# expand all columns with foreign keys
columns_to_expand = [
fk["column"] for fk, _ in expandable_columns
]
columns_to_expand = [fk["column"] for fk, _ in expandable_columns]
if columns_to_expand:
expanded_labels = {}
@@ -557,9 +575,9 @@ class TableView(RowTableShared):
column_index = columns.index(column)
values = [row[column_index] for row in rows]
# Expand them
expanded_labels.update(await self.ds.expand_foreign_keys(
database, table, column, values
))
expanded_labels.update(
await self.ds.expand_foreign_keys(database, table, column, values)
)
if expanded_labels:
# Rewrite the rows
new_rows = []
@@ -569,8 +587,8 @@ class TableView(RowTableShared):
value = row[column]
if (column, value) in expanded_labels:
new_row[column] = {
'value': value,
'label': expanded_labels[(column, value)]
"value": value,
"label": expanded_labels[(column, value)],
}
else:
new_row[column] = value
@@ -608,7 +626,11 @@ class TableView(RowTableShared):
# Detect suggested facets
suggested_facets = []
if self.ds.config("suggest_facets") and self.ds.config("allow_facet") and not _next:
if (
self.ds.config("suggest_facets")
and self.ds.config("allow_facet")
and not _next
):
for facet in facet_instances:
# TODO: ensure facet is not suggested if it is already active
# used to use 'if facet_column in facets' for this
@@ -634,10 +656,11 @@ class TableView(RowTableShared):
link_column=not is_view,
truncate_cells=self.ds.config("truncate_cells_html"),
)
metadata = (self.ds.metadata("databases") or {}).get(database, {}).get(
"tables", {}
).get(
table, {}
metadata = (
(self.ds.metadata("databases") or {})
.get(database, {})
.get("tables", {})
.get(table, {})
)
self.ds.update_with_inherited_metadata(metadata)
form_hidden_args = []
@@ -656,7 +679,7 @@ class TableView(RowTableShared):
"sorted_facet_results": sorted(
facet_results.values(),
key=lambda f: (len(f["results"]), f["name"]),
reverse=True
reverse=True,
),
"extra_wheres_for_ui": extra_wheres_for_ui,
"form_hidden_args": form_hidden_args,
@@ -682,32 +705,36 @@ class TableView(RowTableShared):
"table_definition": await self.ds.get_table_definition(database, table),
}
return {
"database": database,
"table": table,
"is_view": is_view,
"human_description_en": human_description_en,
"rows": rows[:page_size],
"truncated": results.truncated,
"filtered_table_rows_count": filtered_table_rows_count,
"expanded_columns": expanded_columns,
"expandable_columns": expandable_columns,
"columns": columns,
"primary_keys": pks,
"units": units,
"query": {"sql": sql, "params": params},
"facet_results": facet_results,
"suggested_facets": suggested_facets,
"next": next_value and str(next_value) or None,
"next_url": next_url,
}, extra_template, (
"table-{}-{}.html".format(to_css_class(database), to_css_class(table)),
"table.html",
return (
{
"database": database,
"table": table,
"is_view": is_view,
"human_description_en": human_description_en,
"rows": rows[:page_size],
"truncated": results.truncated,
"filtered_table_rows_count": filtered_table_rows_count,
"expanded_columns": expanded_columns,
"expandable_columns": expandable_columns,
"columns": columns,
"primary_keys": pks,
"units": units,
"query": {"sql": sql, "params": params},
"facet_results": facet_results,
"suggested_facets": suggested_facets,
"next": next_value and str(next_value) or None,
"next_url": next_url,
},
extra_template,
(
"table-{}-{}.html".format(to_css_class(database), to_css_class(table)),
"table.html",
),
)
class RowView(RowTableShared):
name = 'row'
name = "row"
async def data(self, request, database, hash, table, pk_path, default_labels=False):
pk_values = urlsafe_components(pk_path)
@@ -720,15 +747,13 @@ class RowView(RowTableShared):
select = "rowid, *"
pks = ["rowid"]
wheres = ['"{}"=:p{}'.format(pk, i) for i, pk in enumerate(pks)]
sql = 'select {} from {} where {}'.format(
sql = "select {} from {} where {}".format(
select, escape_sqlite(table), " AND ".join(wheres)
)
params = {}
for i, pk_value in enumerate(pk_values):
params["p{}".format(i)] = pk_value
results = await self.ds.execute(
database, sql, params, truncate=True
)
results = await self.ds.execute(database, sql, params, truncate=True)
columns = [r[0] for r in results.description]
rows = list(results.rows)
if not rows:
@@ -760,13 +785,10 @@ class RowView(RowTableShared):
),
"_rows_and_columns.html",
],
"metadata": (
self.ds.metadata("databases") or {}
).get(database, {}).get(
"tables", {}
).get(
table, {}
),
"metadata": (self.ds.metadata("databases") or {})
.get(database, {})
.get("tables", {})
.get(table, {}),
}
data = {
@@ -784,8 +806,13 @@ class RowView(RowTableShared):
database, table, pk_values
)
return data, template_data, (
"row-{}-{}.html".format(to_css_class(database), to_css_class(table)), "row.html"
return (
data,
template_data,
(
"row-{}-{}.html".format(to_css_class(database), to_css_class(table)),
"row.html",
),
)
async def foreign_key_tables(self, database, table, pk_values):
@@ -801,7 +828,7 @@ class RowView(RowTableShared):
sql = "select " + ", ".join(
[
'(select count(*) from {table} where {column}=:id)'.format(
"(select count(*) from {table} where {column}=:id)".format(
table=escape_sqlite(fk["other_table"]),
column=escape_sqlite(fk["other_column"]),
)
@@ -822,8 +849,8 @@ class RowView(RowTableShared):
)
foreign_key_tables = []
for fk in foreign_keys:
count = foreign_table_counts.get(
(fk["other_table"], fk["other_column"])
) or 0
count = (
foreign_table_counts.get((fk["other_table"], fk["other_column"])) or 0
)
foreign_key_tables.append({**fk, **{"count": count}})
return foreign_key_tables

View file

@@ -17,9 +17,7 @@ class TestClient:
def get(self, path, allow_redirects=True):
return self.sanic_test_client.get(
path,
allow_redirects=allow_redirects,
gather_request=False
path, allow_redirects=allow_redirects, gather_request=False
)
@@ -79,39 +77,35 @@ def app_client_no_files():
client.ds = ds
yield client
@pytest.fixture(scope="session")
def app_client_with_memory():
yield from make_app_client(memory=True)
@pytest.fixture(scope="session")
def app_client_with_hash():
yield from make_app_client(config={
'hash_urls': True,
}, is_immutable=True)
yield from make_app_client(config={"hash_urls": True}, is_immutable=True)
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def app_client_shorter_time_limit():
yield from make_app_client(20)
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def app_client_returned_rows_matches_page_size():
yield from make_app_client(max_returned_rows=50)
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def app_client_larger_cache_size():
yield from make_app_client(config={
'cache_size_kb': 2500,
})
yield from make_app_client(config={"cache_size_kb": 2500})
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def app_client_csv_max_mb_one():
yield from make_app_client(config={
'max_csv_mb': 1,
})
yield from make_app_client(config={"max_csv_mb": 1})
@pytest.fixture(scope="session")
@@ -119,7 +113,7 @@ def app_client_with_dot():
yield from make_app_client(filename="fixtures.dot.db")
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def app_client_with_cors():
yield from make_app_client(cors=True)
@@ -128,7 +122,7 @@ def generate_compound_rows(num):
for a, b, c in itertools.islice(
itertools.product(string.ascii_lowercase, repeat=3), num
):
yield a, b, c, '{}-{}-{}'.format(a, b, c)
yield a, b, c, "{}-{}-{}".format(a, b, c)
def generate_sortable_rows(num):
@@ -137,107 +131,81 @@ def generate_sortable_rows(num):
itertools.product(string.ascii_lowercase, repeat=2), num
):
yield {
'pk1': a,
'pk2': b,
'content': '{}-{}'.format(a, b),
'sortable': rand.randint(-100, 100),
'sortable_with_nulls': rand.choice([
None, rand.random(), rand.random()
]),
'sortable_with_nulls_2': rand.choice([
None, rand.random(), rand.random()
]),
'text': rand.choice(['$null', '$blah']),
"pk1": a,
"pk2": b,
"content": "{}-{}".format(a, b),
"sortable": rand.randint(-100, 100),
"sortable_with_nulls": rand.choice([None, rand.random(), rand.random()]),
"sortable_with_nulls_2": rand.choice([None, rand.random(), rand.random()]),
"text": rand.choice(["$null", "$blah"]),
}
METADATA = {
'title': 'Datasette Fixtures',
'description': 'An example SQLite database demonstrating Datasette',
'license': 'Apache License 2.0',
'license_url': 'https://github.com/simonw/datasette/blob/master/LICENSE',
'source': 'tests/fixtures.py',
'source_url': 'https://github.com/simonw/datasette/blob/master/tests/fixtures.py',
'about': 'About Datasette',
'about_url': 'https://github.com/simonw/datasette',
"plugins": {
"name-of-plugin": {
"depth": "root"
}
},
'databases': {
'fixtures': {
'description': 'Test tables description',
"plugins": {
"name-of-plugin": {
"depth": "database"
}
},
'tables': {
'simple_primary_key': {
'description_html': 'Simple <em>primary</em> key',
'title': 'This <em>HTML</em> is escaped',
"title": "Datasette Fixtures",
"description": "An example SQLite database demonstrating Datasette",
"license": "Apache License 2.0",
"license_url": "https://github.com/simonw/datasette/blob/master/LICENSE",
"source": "tests/fixtures.py",
"source_url": "https://github.com/simonw/datasette/blob/master/tests/fixtures.py",
"about": "About Datasette",
"about_url": "https://github.com/simonw/datasette",
"plugins": {"name-of-plugin": {"depth": "root"}},
"databases": {
"fixtures": {
"description": "Test tables description",
"plugins": {"name-of-plugin": {"depth": "database"}},
"tables": {
"simple_primary_key": {
"description_html": "Simple <em>primary</em> key",
"title": "This <em>HTML</em> is escaped",
"plugins": {
"name-of-plugin": {
"depth": "table",
"special": "this-is-simple_primary_key"
"special": "this-is-simple_primary_key",
}
}
},
},
'sortable': {
'sortable_columns': [
'sortable',
'sortable_with_nulls',
'sortable_with_nulls_2',
'text',
"sortable": {
"sortable_columns": [
"sortable",
"sortable_with_nulls",
"sortable_with_nulls_2",
"text",
],
"plugins": {
"name-of-plugin": {
"depth": "table"
}
}
"plugins": {"name-of-plugin": {"depth": "table"}},
},
'no_primary_key': {
'sortable_columns': [],
'hidden': True,
"no_primary_key": {"sortable_columns": [], "hidden": True},
"units": {"units": {"distance": "m", "frequency": "Hz"}},
"primary_key_multiple_columns_explicit_label": {
"label_column": "content2"
},
'units': {
'units': {
'distance': 'm',
'frequency': 'Hz'
}
"simple_view": {"sortable_columns": ["content"]},
"searchable_view_configured_by_metadata": {
"fts_table": "searchable_fts",
"fts_pk": "pk",
},
'primary_key_multiple_columns_explicit_label': {
'label_column': 'content2',
},
'simple_view': {
'sortable_columns': ['content'],
},
'searchable_view_configured_by_metadata': {
'fts_table': 'searchable_fts',
'fts_pk': 'pk'
}
},
'queries': {
'pragma_cache_size': 'PRAGMA cache_size;',
'neighborhood_search': {
'sql': '''
"queries": {
"pragma_cache_size": "PRAGMA cache_size;",
"neighborhood_search": {
"sql": """
select neighborhood, facet_cities.name, state
from facetable
join facet_cities
on facetable.city_id = facet_cities.id
where neighborhood like '%' || :text || '%'
order by neighborhood;
''',
'title': 'Search neighborhoods',
'description_html': '<b>Demonstrating</b> simple like search',
""",
"title": "Search neighborhoods",
"description_html": "<b>Demonstrating</b> simple like search",
},
}
},
}
},
}
},
}
PLUGIN1 = '''
PLUGIN1 = """
from datasette import hookimpl
import base64
import pint
@@ -304,9 +272,9 @@ def render_cell(value, column, table, database, datasette):
table=table,
)
})
'''
"""
PLUGIN2 = '''
PLUGIN2 = """
from datasette import hookimpl
import jinja2
import json
@@ -349,9 +317,10 @@ def render_cell(value, database):
label=jinja2.escape(data["label"] or "") or "&nbsp;"
)
)
'''
"""
TABLES = '''
TABLES = (
"""
CREATE TABLE simple_primary_key (
id varchar(30) primary key,
content text
@@ -581,26 +550,42 @@ CREATE VIEW searchable_view AS
CREATE VIEW searchable_view_configured_by_metadata AS
SELECT * from searchable;
''' + '\n'.join([
'INSERT INTO no_primary_key VALUES ({i}, "a{i}", "b{i}", "c{i}");'.format(i=i + 1)
for i in range(201)
]) + '\n'.join([
'INSERT INTO compound_three_primary_keys VALUES ("{a}", "{b}", "{c}", "{content}");'.format(
a=a, b=b, c=c, content=content
) for a, b, c, content in generate_compound_rows(1001)
]) + '\n'.join([
'''INSERT INTO sortable VALUES (
"""
+ "\n".join(
[
'INSERT INTO no_primary_key VALUES ({i}, "a{i}", "b{i}", "c{i}");'.format(
i=i + 1
)
for i in range(201)
]
)
+ "\n".join(
[
'INSERT INTO compound_three_primary_keys VALUES ("{a}", "{b}", "{c}", "{content}");'.format(
a=a, b=b, c=c, content=content
)
for a, b, c, content in generate_compound_rows(1001)
]
)
+ "\n".join(
[
"""INSERT INTO sortable VALUES (
"{pk1}", "{pk2}", "{content}", {sortable},
{sortable_with_nulls}, {sortable_with_nulls_2}, "{text}");
'''.format(
**row
).replace('None', 'null') for row in generate_sortable_rows(201)
])
TABLE_PARAMETERIZED_SQL = [(
"insert into binary_data (data) values (?);", [b'this is binary data']
)]
""".format(
**row
).replace(
"None", "null"
)
for row in generate_sortable_rows(201)
]
)
)
TABLE_PARAMETERIZED_SQL = [
("insert into binary_data (data) values (?);", [b"this is binary data"])
]
if __name__ == '__main__':
if __name__ == "__main__":
# Can be called with data.db OR data.db metadata.json
db_filename = sys.argv[-1]
metadata_filename = None
@@ -615,9 +600,7 @@ if __name__ == '__main__':
conn.execute(sql, params)
print("Test tables written to {}".format(db_filename))
if metadata_filename:
open(metadata_filename, 'w').write(json.dumps(METADATA))
open(metadata_filename, "w").write(json.dumps(METADATA))
print("- metadata written to {}".format(metadata_filename))
else:
print("Usage: {} db_to_write.db [metadata_to_write.json]".format(
sys.argv[0]
))
print("Usage: {} db_to_write.db [metadata_to_write.json]".format(sys.argv[0]))

File diff suppressed because it is too large

View file

@@ -1,22 +1,26 @@
from .fixtures import ( # noqa
from .fixtures import ( # noqa
app_client,
app_client_csv_max_mb_one,
app_client_with_cors
app_client_with_cors,
)
EXPECTED_TABLE_CSV = '''id,content
EXPECTED_TABLE_CSV = """id,content
1,hello
2,world
3,
4,RENDER_CELL_DEMO
'''.replace('\n', '\r\n')
""".replace(
"\n", "\r\n"
)
EXPECTED_CUSTOM_CSV = '''content
EXPECTED_CUSTOM_CSV = """content
hello
world
'''.replace('\n', '\r\n')
""".replace(
"\n", "\r\n"
)
EXPECTED_TABLE_WITH_LABELS_CSV = '''
EXPECTED_TABLE_WITH_LABELS_CSV = """
pk,planet_int,on_earth,state,city_id,city_id_label,neighborhood,tags
1,1,1,CA,1,San Francisco,Mission,"[""tag1"", ""tag2""]"
2,1,1,CA,1,San Francisco,Dogpatch,"[""tag1"", ""tag3""]"
@@ -33,45 +37,47 @@ pk,planet_int,on_earth,state,city_id,city_id_label,neighborhood,tags
13,1,1,MI,3,Detroit,Corktown,[]
14,1,1,MI,3,Detroit,Mexicantown,[]
15,2,0,MC,4,Memnonia,Arcadia Planitia,[]
'''.lstrip().replace('\n', '\r\n')
""".lstrip().replace(
"\n", "\r\n"
)
def test_table_csv(app_client):
response = app_client.get('/fixtures/simple_primary_key.csv')
response = app_client.get("/fixtures/simple_primary_key.csv")
assert response.status == 200
assert not response.headers.get("Access-Control-Allow-Origin")
assert 'text/plain; charset=utf-8' == response.headers['Content-Type']
assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
assert EXPECTED_TABLE_CSV == response.text
def test_table_csv_cors_headers(app_client_with_cors):
response = app_client_with_cors.get('/fixtures/simple_primary_key.csv')
response = app_client_with_cors.get("/fixtures/simple_primary_key.csv")
assert response.status == 200
assert "*" == response.headers["Access-Control-Allow-Origin"]
def test_table_csv_with_labels(app_client):
response = app_client.get('/fixtures/facetable.csv?_labels=1')
response = app_client.get("/fixtures/facetable.csv?_labels=1")
assert response.status == 200
assert 'text/plain; charset=utf-8' == response.headers['Content-Type']
assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
assert EXPECTED_TABLE_WITH_LABELS_CSV == response.text
def test_custom_sql_csv(app_client):
response = app_client.get(
'/fixtures.csv?sql=select+content+from+simple_primary_key+limit+2'
"/fixtures.csv?sql=select+content+from+simple_primary_key+limit+2"
)
assert response.status == 200
assert 'text/plain; charset=utf-8' == response.headers['Content-Type']
assert "text/plain; charset=utf-8" == response.headers["Content-Type"]
assert EXPECTED_CUSTOM_CSV == response.text
def test_table_csv_download(app_client):
response = app_client.get('/fixtures/simple_primary_key.csv?_dl=1')
response = app_client.get("/fixtures/simple_primary_key.csv?_dl=1")
assert response.status == 200
assert 'text/csv; charset=utf-8' == response.headers['Content-Type']
assert "text/csv; charset=utf-8" == response.headers["Content-Type"]
expected_disposition = 'attachment; filename="simple_primary_key.csv"'
assert expected_disposition == response.headers['Content-Disposition']
assert expected_disposition == response.headers["Content-Disposition"]
def test_max_csv_mb(app_client_csv_max_mb_one):
@@ -88,12 +94,8 @@ def test_max_csv_mb(app_client_csv_max_mb_one):
def test_table_csv_stream(app_client):
# Without _stream should return header + 100 rows:
response = app_client.get(
"/fixtures/compound_three_primary_keys.csv?_size=max"
)
response = app_client.get("/fixtures/compound_three_primary_keys.csv?_size=max")
assert 101 == len([b for b in response.body.split(b"\r\n") if b])
# With _stream=1 should return header + 1001 rows
response = app_client.get(
"/fixtures/compound_three_primary_keys.csv?_stream=1"
)
response = app_client.get("/fixtures/compound_three_primary_keys.csv?_stream=1")
assert 1002 == len([b for b in response.body.split(b"\r\n") if b])

View file

@@ -9,13 +9,13 @@ from pathlib import Path
import pytest
import re
docs_path = Path(__file__).parent.parent / 'docs'
label_re = re.compile(r'\.\. _([^\s:]+):')
docs_path = Path(__file__).parent.parent / "docs"
label_re = re.compile(r"\.\. _([^\s:]+):")
def get_headings(filename, underline="-"):
content = (docs_path / filename).open().read()
heading_re = re.compile(r'(\w+)(\([^)]*\))?\n\{}+\n'.format(underline))
heading_re = re.compile(r"(\w+)(\([^)]*\))?\n\{}+\n".format(underline))
return set(h[0] for h in heading_re.findall(content))
@@ -24,38 +24,37 @@ def get_labels(filename):
return set(label_re.findall(content))
@pytest.mark.parametrize('config', app.CONFIG_OPTIONS)
@pytest.mark.parametrize("config", app.CONFIG_OPTIONS)
def test_config_options_are_documented(config):
assert config.name in get_headings("config.rst")
@pytest.mark.parametrize("name,filename", (
("serve", "datasette-serve-help.txt"),
("package", "datasette-package-help.txt"),
("publish now", "datasette-publish-now-help.txt"),
("publish heroku", "datasette-publish-heroku-help.txt"),
("publish cloudrun", "datasette-publish-cloudrun-help.txt"),
))
@pytest.mark.parametrize(
"name,filename",
(
("serve", "datasette-serve-help.txt"),
("package", "datasette-package-help.txt"),
("publish now", "datasette-publish-now-help.txt"),
("publish heroku", "datasette-publish-heroku-help.txt"),
("publish cloudrun", "datasette-publish-cloudrun-help.txt"),
),
)
def test_help_includes(name, filename):
expected = open(str(docs_path / filename)).read()
runner = CliRunner()
result = runner.invoke(cli, name.split() + ["--help"], terminal_width=88)
actual = "$ datasette {} --help\n\n{}".format(
name, result.output
)
actual = "$ datasette {} --help\n\n{}".format(name, result.output)
# actual has "Usage: cli package [OPTIONS] FILES"
# because it doesn't know that cli will be aliased to datasette
expected = expected.replace("Usage: datasette", "Usage: cli")
assert expected == actual
@pytest.mark.parametrize('plugin', [
name for name in dir(app.pm.hook) if not name.startswith('_')
])
@pytest.mark.parametrize(
"plugin", [name for name in dir(app.pm.hook) if not name.startswith("_")]
)
def test_plugin_hooks_are_documented(plugin):
headings = [
s.split("(")[0] for s in get_headings("plugins.rst", "~")
]
headings = [s.split("(")[0] for s in get_headings("plugins.rst", "~")]
assert plugin in headings

View file

@@ -2,102 +2,57 @@ from datasette.filters import Filters
import pytest
@pytest.mark.parametrize('args,expected_where,expected_params', [
(
@pytest.mark.parametrize(
"args,expected_where,expected_params",
[
((("name_english__contains", "foo"),), ['"name_english" like :p0'], ["%foo%"]),
(
('name_english__contains', 'foo'),
(("foo", "bar"), ("bar__contains", "baz")),
['"bar" like :p0', '"foo" = :p1'],
["%baz%", "bar"],
),
['"name_english" like :p0'],
['%foo%']
),
(
(
('foo', 'bar'),
('bar__contains', 'baz'),
(("foo__startswith", "bar"), ("bar__endswith", "baz")),
['"bar" like :p0', '"foo" like :p1'],
["%baz", "bar%"],
),
['"bar" like :p0', '"foo" = :p1'],
['%baz%', 'bar']
),
(
(
('foo__startswith', 'bar'),
('bar__endswith', 'baz'),
(("foo__lt", "1"), ("bar__gt", "2"), ("baz__gte", "3"), ("bax__lte", "4")),
['"bar" > :p0', '"bax" <= :p1', '"baz" >= :p2', '"foo" < :p3'],
[2, 4, 3, 1],
),
['"bar" like :p0', '"foo" like :p1'],
['%baz', 'bar%']
),
(
(
('foo__lt', '1'),
('bar__gt', '2'),
('baz__gte', '3'),
('bax__lte', '4'),
(("foo__like", "2%2"), ("zax__glob", "3*")),
['"foo" like :p0', '"zax" glob :p1'],
["2%2", "3*"],
),
['"bar" > :p0', '"bax" <= :p1', '"baz" >= :p2', '"foo" < :p3'],
[2, 4, 3, 1]
),
(
# Multiple like arguments:
(
('foo__like', '2%2'),
('zax__glob', '3*'),
(("foo__like", "2%2"), ("foo__like", "3%3")),
['"foo" like :p0', '"foo" like :p1'],
["2%2", "3%3"],
),
['"foo" like :p0', '"zax" glob :p1'],
['2%2', '3*']
),
# Multiple like arguments:
(
(
('foo__like', '2%2'),
('foo__like', '3%3'),
(("foo__isnull", "1"), ("baz__isnull", "1"), ("bar__gt", "10")),
['"bar" > :p0', '"baz" is null', '"foo" is null'],
[10],
),
['"foo" like :p0', '"foo" like :p1'],
['2%2', '3%3']
),
(
((("foo__in", "1,2,3"),), ["foo in (:p0, :p1, :p2)"], ["1", "2", "3"]),
# date
((("foo__date", "1988-01-01"),), ["date(foo) = :p0"], ["1988-01-01"]),
# JSON array variants of __in (useful for unexpected characters)
((("foo__in", "[1,2,3]"),), ["foo in (:p0, :p1, :p2)"], [1, 2, 3]),
(
('foo__isnull', '1'),
('baz__isnull', '1'),
('bar__gt', '10'),
(("foo__in", '["dog,cat", "cat[dog]"]'),),
["foo in (:p0, :p1)"],
["dog,cat", "cat[dog]"],
),
['"bar" > :p0', '"baz" is null', '"foo" is null'],
[10]
),
(
(
('foo__in', '1,2,3'),
),
['foo in (:p0, :p1, :p2)'],
["1", "2", "3"]
),
# date
(
(
("foo__date", "1988-01-01"),
),
["date(foo) = :p0"],
["1988-01-01"]
),
# JSON array variants of __in (useful for unexpected characters)
(
(
('foo__in', '[1,2,3]'),
),
['foo in (:p0, :p1, :p2)'],
[1, 2, 3]
),
(
(
('foo__in', '["dog,cat", "cat[dog]"]'),
),
['foo in (:p0, :p1)'],
["dog,cat", "cat[dog]"]
),
])
],
)
def test_build_where(args, expected_where, expected_params):
f = Filters(sorted(args))
sql_bits, actual_params = f.build_where_clauses("table")
assert expected_where == sql_bits
assert {
'p{}'.format(i): param
for i, param in enumerate(expected_params)
"p{}".format(i): param for i, param in enumerate(expected_params)
} == actual_params

File diff suppressed because it is too large

View file

@@ -5,7 +5,7 @@ import pytest
import tempfile
TABLES = '''
TABLES = """
CREATE TABLE "election_results" (
"county" INTEGER,
"party" INTEGER,
@@ -32,13 +32,13 @@ CREATE TABLE "office" (
"id" INTEGER PRIMARY KEY ,
"name" TEXT
);
'''
"""
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def ds_instance():
with tempfile.TemporaryDirectory() as tmpdir:
filepath = os.path.join(tmpdir, 'fixtures.db')
filepath = os.path.join(tmpdir, "fixtures.db")
conn = sqlite3.connect(filepath)
conn.executescript(TABLES)
yield Datasette([filepath])
@@ -46,58 +46,47 @@ def ds_instance():
def test_inspect_hidden_tables(ds_instance):
info = ds_instance.inspect()
tables = info['fixtures']['tables']
tables = info["fixtures"]["tables"]
expected_hidden = (
'election_results_fts',
'election_results_fts_content',
'election_results_fts_docsize',
'election_results_fts_segdir',
'election_results_fts_segments',
'election_results_fts_stat',
)
expected_visible = (
'election_results',
'county',
'party',
'office',
"election_results_fts",
"election_results_fts_content",
"election_results_fts_docsize",
"election_results_fts_segdir",
"election_results_fts_segments",
"election_results_fts_stat",
)
expected_visible = ("election_results", "county", "party", "office")
assert sorted(expected_hidden) == sorted(
[table for table in tables if tables[table]['hidden']]
[table for table in tables if tables[table]["hidden"]]
)
assert sorted(expected_visible) == sorted(
[table for table in tables if not tables[table]['hidden']]
[table for table in tables if not tables[table]["hidden"]]
)
def test_inspect_foreign_keys(ds_instance):
info = ds_instance.inspect()
tables = info['fixtures']['tables']
for table_name in ('county', 'party', 'office'):
assert 0 == tables[table_name]['count']
foreign_keys = tables[table_name]['foreign_keys']
assert [] == foreign_keys['outgoing']
assert [{
'column': 'id',
'other_column': table_name,
'other_table': 'election_results'
}] == foreign_keys['incoming']
tables = info["fixtures"]["tables"]
for table_name in ("county", "party", "office"):
assert 0 == tables[table_name]["count"]
foreign_keys = tables[table_name]["foreign_keys"]
assert [] == foreign_keys["outgoing"]
assert [
{
"column": "id",
"other_column": table_name,
"other_table": "election_results",
}
] == foreign_keys["incoming"]
election_results = tables['election_results']
assert 0 == election_results['count']
assert sorted([{
'column': 'county',
'other_column': 'id',
'other_table': 'county'
}, {
'column': 'party',
'other_column': 'id',
'other_table': 'party'
}, {
'column': 'office',
'other_column': 'id',
'other_table': 'office'
}], key=lambda d: d['column']) == sorted(
election_results['foreign_keys']['outgoing'],
key=lambda d: d['column']
)
assert [] == election_results['foreign_keys']['incoming']
election_results = tables["election_results"]
assert 0 == election_results["count"]
assert sorted(
[
{"column": "county", "other_column": "id", "other_table": "county"},
{"column": "party", "other_column": "id", "other_table": "party"},
{"column": "office", "other_column": "id", "other_table": "office"},
],
key=lambda d: d["column"],
) == sorted(election_results["foreign_keys"]["outgoing"], key=lambda d: d["column"])
assert [] == election_results["foreign_keys"]["incoming"]
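
Pieced together from the assertions above, the inspect() metadata for one of the lookup tables has roughly this shape (illustrative only, reconstructed from these tests rather than from separate documentation):

    {
        "fixtures": {
            "tables": {
                "county": {
                    "count": 0,
                    "hidden": False,
                    "foreign_keys": {
                        "outgoing": [],
                        "incoming": [
                            {
                                "column": "id",
                                "other_column": "county",
                                "other_table": "election_results",
                            }
                        ],
                    },
                },
                # ... other tables ...
            }
        }
    }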

View file

@@ -1,7 +1,5 @@
from bs4 import BeautifulSoup as Soup
from .fixtures import ( # noqa
app_client,
)
from .fixtures import app_client # noqa
import base64
import json
import re
@@ -13,41 +11,26 @@ def test_plugins_dir_plugin(app_client):
response = app_client.get(
"/fixtures.json?sql=select+convert_units(100%2C+'m'%2C+'ft')"
)
assert pytest.approx(328.0839) == response.json['rows'][0][0]
assert pytest.approx(328.0839) == response.json["rows"][0][0]
@pytest.mark.parametrize(
"path,expected_decoded_object",
[
(
"/",
{
"template": "index.html",
"database": None,
"table": None,
},
),
("/", {"template": "index.html", "database": None, "table": None}),
(
"/fixtures/",
{
"template": "database.html",
"database": "fixtures",
"table": None,
},
{"template": "database.html", "database": "fixtures", "table": None},
),
(
"/fixtures/sortable",
{
"template": "table.html",
"database": "fixtures",
"table": "sortable",
},
{"template": "table.html", "database": "fixtures", "table": "sortable"},
),
],
)
def test_plugin_extra_css_urls(app_client, path, expected_decoded_object):
response = app_client.get(path)
links = Soup(response.body, 'html.parser').findAll('link')
links = Soup(response.body, "html.parser").findAll("link")
special_href = [
l for l in links if l.attrs["href"].endswith("/extra-css-urls-demo.css")
][0]["href"]
@@ -59,47 +42,43 @@ def test_plugin_extra_css_urls(app_client, path, expected_decoded_object):
def test_plugin_extra_js_urls(app_client):
response = app_client.get('/')
scripts = Soup(response.body, 'html.parser').findAll('script')
response = app_client.get("/")
scripts = Soup(response.body, "html.parser").findAll("script")
assert [
s for s in scripts
if s.attrs == {
'integrity': 'SRIHASH',
'crossorigin': 'anonymous',
'src': 'https://example.com/jquery.js'
s
for s in scripts
if s.attrs
== {
"integrity": "SRIHASH",
"crossorigin": "anonymous",
"src": "https://example.com/jquery.js",
}
]
def test_plugins_with_duplicate_js_urls(app_client):
# If two plugins both require jQuery, jQuery should be loaded only once
response = app_client.get(
"/fixtures"
)
response = app_client.get("/fixtures")
# This test is a little tricky, as if the user has any other plugins in
# their current virtual environment those may affect what comes back too.
# What matters is that https://example.com/jquery.js is only there once
# and it comes before plugin1.js and plugin2.js which could be in either
# order
scripts = Soup(response.body, 'html.parser').findAll('script')
srcs = [s['src'] for s in scripts if s.get('src')]
scripts = Soup(response.body, "html.parser").findAll("script")
srcs = [s["src"] for s in scripts if s.get("src")]
# No duplicates allowed:
assert len(srcs) == len(set(srcs))
# jquery.js loaded once:
assert 1 == srcs.count('https://example.com/jquery.js')
assert 1 == srcs.count("https://example.com/jquery.js")
# plugin1.js and plugin2.js are both there:
assert 1 == srcs.count('https://example.com/plugin1.js')
assert 1 == srcs.count('https://example.com/plugin2.js')
assert 1 == srcs.count("https://example.com/plugin1.js")
assert 1 == srcs.count("https://example.com/plugin2.js")
# jquery comes before them both
assert srcs.index(
'https://example.com/jquery.js'
) < srcs.index(
'https://example.com/plugin1.js'
assert srcs.index("https://example.com/jquery.js") < srcs.index(
"https://example.com/plugin1.js"
)
assert srcs.index(
'https://example.com/jquery.js'
) < srcs.index(
'https://example.com/plugin2.js'
assert srcs.index("https://example.com/jquery.js") < srcs.index(
"https://example.com/plugin2.js"
)
@@ -107,13 +86,9 @@ def test_plugins_render_cell_link_from_json(app_client):
sql = """
select '{"href": "http://example.com/", "label":"Example"}'
""".strip()
path = "/fixtures?" + urllib.parse.urlencode({
"sql": sql,
})
path = "/fixtures?" + urllib.parse.urlencode({"sql": sql})
response = app_client.get(path)
td = Soup(
response.body, "html.parser"
).find("table").find("tbody").find("td")
td = Soup(response.body, "html.parser").find("table").find("tbody").find("td")
a = td.find("a")
assert a is not None, str(a)
assert a.attrs["href"] == "http://example.com/"
@@ -129,10 +104,7 @@ def test_plugins_render_cell_demo(app_client):
"column": "content",
"table": "simple_primary_key",
"database": "fixtures",
"config": {
"depth": "table",
"special": "this-is-simple_primary_key"
}
"config": {"depth": "table", "special": "this-is-simple_primary_key"},
} == json.loads(td.string)

View file

@@ -35,7 +35,14 @@ def test_publish_cloudrun(mock_call, mock_output, mock_which):
result = runner.invoke(cli.cli, ["publish", "cloudrun", "test.db"])
assert 0 == result.exit_code
tag = "gcr.io/{}/datasette".format(mock_output.return_value)
mock_call.assert_has_calls([
mock.call("gcloud builds submit --tag {}".format(tag), shell=True),
mock.call("gcloud beta run deploy --allow-unauthenticated --image {}".format(tag), shell=True)])
mock_call.assert_has_calls(
[
mock.call("gcloud builds submit --tag {}".format(tag), shell=True),
mock.call(
"gcloud beta run deploy --allow-unauthenticated --image {}".format(
tag
),
shell=True,
),
]
)

View file

@@ -57,7 +57,9 @@ def test_publish_heroku(mock_call, mock_check_output, mock_which):
open("test.db", "w").write("data")
result = runner.invoke(cli.cli, ["publish", "heroku", "test.db"])
assert 0 == result.exit_code, result.output
mock_call.assert_called_once_with(["heroku", "builds:create", "-a", "f", "--include-vcs-ignore"])
mock_call.assert_called_once_with(
["heroku", "builds:create", "-a", "f", "--include-vcs-ignore"]
)
@mock.patch("shutil.which")

View file

@@ -13,72 +13,78 @@ import tempfile
from unittest.mock import patch
@pytest.mark.parametrize('path,expected', [
('foo', ['foo']),
('foo,bar', ['foo', 'bar']),
('123,433,112', ['123', '433', '112']),
('123%2C433,112', ['123,433', '112']),
('123%2F433%2F112', ['123/433/112']),
])
@pytest.mark.parametrize(
"path,expected",
[
("foo", ["foo"]),
("foo,bar", ["foo", "bar"]),
("123,433,112", ["123", "433", "112"]),
("123%2C433,112", ["123,433", "112"]),
("123%2F433%2F112", ["123/433/112"]),
],
)
def test_urlsafe_components(path, expected):
assert expected == utils.urlsafe_components(path)
@pytest.mark.parametrize('path,added_args,expected', [
('/foo', {'bar': 1}, '/foo?bar=1'),
('/foo?bar=1', {'baz': 2}, '/foo?bar=1&baz=2'),
('/foo?bar=1&bar=2', {'baz': 3}, '/foo?bar=1&bar=2&baz=3'),
('/foo?bar=1', {'bar': None}, '/foo'),
# Test order is preserved
('/?_facet=prim_state&_facet=area_name', (
('prim_state', 'GA'),
), '/?_facet=prim_state&_facet=area_name&prim_state=GA'),
('/?_facet=state&_facet=city&state=MI', (
('city', 'Detroit'),
), '/?_facet=state&_facet=city&state=MI&city=Detroit'),
('/?_facet=state&_facet=city', (
('_facet', 'planet_int'),
), '/?_facet=state&_facet=city&_facet=planet_int'),
])
@pytest.mark.parametrize(
"path,added_args,expected",
[
("/foo", {"bar": 1}, "/foo?bar=1"),
("/foo?bar=1", {"baz": 2}, "/foo?bar=1&baz=2"),
("/foo?bar=1&bar=2", {"baz": 3}, "/foo?bar=1&bar=2&baz=3"),
("/foo?bar=1", {"bar": None}, "/foo"),
# Test order is preserved
(
"/?_facet=prim_state&_facet=area_name",
(("prim_state", "GA"),),
"/?_facet=prim_state&_facet=area_name&prim_state=GA",
),
(
"/?_facet=state&_facet=city&state=MI",
(("city", "Detroit"),),
"/?_facet=state&_facet=city&state=MI&city=Detroit",
),
(
"/?_facet=state&_facet=city",
(("_facet", "planet_int"),),
"/?_facet=state&_facet=city&_facet=planet_int",
),
],
)
def test_path_with_added_args(path, added_args, expected):
request = Request(
path.encode('utf8'),
{}, '1.1', 'GET', None
)
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
actual = utils.path_with_added_args(request, added_args)
assert expected == actual
@pytest.mark.parametrize('path,args,expected', [
('/foo?bar=1', {'bar'}, '/foo'),
('/foo?bar=1&baz=2', {'bar'}, '/foo?baz=2'),
('/foo?bar=1&bar=2&bar=3', {'bar': '2'}, '/foo?bar=1&bar=3'),
])
@pytest.mark.parametrize(
"path,args,expected",
[
("/foo?bar=1", {"bar"}, "/foo"),
("/foo?bar=1&baz=2", {"bar"}, "/foo?baz=2"),
("/foo?bar=1&bar=2&bar=3", {"bar": "2"}, "/foo?bar=1&bar=3"),
],
)
def test_path_with_removed_args(path, args, expected):
request = Request(
path.encode('utf8'),
{}, '1.1', 'GET', None
)
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
actual = utils.path_with_removed_args(request, args)
assert expected == actual
# Run the test again but this time use the path= argument
request = Request(
"/".encode('utf8'),
{}, '1.1', 'GET', None
)
request = Request("/".encode("utf8"), {}, "1.1", "GET", None)
actual = utils.path_with_removed_args(request, args, path=path)
assert expected == actual
@pytest.mark.parametrize('path,args,expected', [
('/foo?bar=1', {'bar': 2}, '/foo?bar=2'),
('/foo?bar=1&baz=2', {'bar': None}, '/foo?baz=2'),
])
@pytest.mark.parametrize(
"path,args,expected",
[
("/foo?bar=1", {"bar": 2}, "/foo?bar=2"),
("/foo?bar=1&baz=2", {"bar": None}, "/foo?baz=2"),
],
)
def test_path_with_replaced_args(path, args, expected):
request = Request(
path.encode('utf8'),
{}, '1.1', 'GET', None
)
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
actual = utils.path_with_replaced_args(request, args)
assert expected == actual
@@ -93,17 +99,8 @@ def test_path_with_replaced_args(path, args, expected):
utils.CustomRow(
["searchable_id", "tag"],
[
(
"searchable_id",
{"value": 1, "label": "1"},
),
(
"tag",
{
"value": "feline",
"label": "feline",
},
),
("searchable_id", {"value": 1, "label": "1"}),
("tag", {"value": "feline", "label": "feline"}),
],
),
["searchable_id", "tag"],
@@ -116,47 +113,54 @@ def test_path_from_row_pks(row, pks, expected_path):
assert expected_path == actual_path
@pytest.mark.parametrize('obj,expected', [
({
'Description': 'Soft drinks',
'Picture': b"\x15\x1c\x02\xc7\xad\x05\xfe",
'CategoryID': 1,
}, """
@pytest.mark.parametrize(
"obj,expected",
[
(
{
"Description": "Soft drinks",
"Picture": b"\x15\x1c\x02\xc7\xad\x05\xfe",
"CategoryID": 1,
},
"""
{"CategoryID": 1, "Description": "Soft drinks", "Picture": {"$base64": true, "encoded": "FRwCx60F/g=="}}
""".strip()),
])
""".strip(),
)
],
)
def test_custom_json_encoder(obj, expected):
actual = json.dumps(
obj,
cls=utils.CustomJSONEncoder,
sort_keys=True
)
actual = json.dumps(obj, cls=utils.CustomJSONEncoder, sort_keys=True)
assert expected == actual
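
The expected string pins down exactly how bytes values are serialized: as a {"$base64": true, "encoded": ...} wrapper. A minimal sketch of an encoder with that behavior (the real utils.CustomJSONEncoder may handle more types than this):

    import base64
    import json


    class Base64BytesEncoder(json.JSONEncoder):
        """Sketch: serialize bytes as {"$base64": true, "encoded": "..."}."""

        def default(self, obj):
            if isinstance(obj, bytes):
                return {
                    "$base64": True,
                    "encoded": base64.b64encode(obj).decode("ascii"),
                }
            return super().default(obj)


    # json.dumps({"Picture": b"\x15\x1c\x02\xc7\xad\x05\xfe"}, cls=Base64BytesEncoder)
    # -> '{"Picture": {"$base64": true, "encoded": "FRwCx60F/g=="}}'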
@pytest.mark.parametrize('bad_sql', [
'update blah;',
'PRAGMA case_sensitive_like = true'
"SELECT * FROM pragma_index_info('idx52')",
])
@pytest.mark.parametrize(
"bad_sql",
[
"update blah;",
"PRAGMA case_sensitive_like = true" "SELECT * FROM pragma_index_info('idx52')",
],
)
def test_validate_sql_select_bad(bad_sql):
with pytest.raises(utils.InvalidSql):
utils.validate_sql_select(bad_sql)
@pytest.mark.parametrize('good_sql', [
'select count(*) from airports',
'select foo from bar',
'select 1 + 1',
'SELECT\nblah FROM foo',
'WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x+1 FROM cnt LIMIT 10) SELECT x FROM cnt;'
])
@pytest.mark.parametrize(
"good_sql",
[
"select count(*) from airports",
"select foo from bar",
"select 1 + 1",
"SELECT\nblah FROM foo",
"WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x+1 FROM cnt LIMIT 10) SELECT x FROM cnt;",
],
)
def test_validate_sql_select_good(good_sql):
utils.validate_sql_select(good_sql)
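
Black's output also makes an existing quirk easy to spot: there is no comma between the last two bad_sql strings, so Python concatenates them into a single test case — visible as one long string in the reformatted version. The behavior these tests pin down (SELECT or WITH statements only, no PRAGMA) could be implemented along these lines — a sketch, not necessarily the exact utils.validate_sql_select body:

    import re


    class InvalidSql(Exception):
        pass


    def validate_sql_select_sketch(sql):
        # Queries must start with SELECT or WITH (allowing CTEs),
        # and PRAGMA is rejected anywhere in the statement
        normalized = sql.strip().lower()
        if not normalized.startswith("select") and not normalized.startswith("with"):
            raise InvalidSql("Statement must be a SELECT")
        if re.search(r"pragma", normalized):
            raise InvalidSql("Statement may not contain PRAGMA")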
def test_detect_fts():
sql = '''
sql = """
CREATE TABLE "Dumb_Table" (
"TreeID" INTEGER,
"qSpecies" TEXT
@@ -173,34 +177,40 @@ def test_detect_fts():
CREATE VIEW Test_View AS SELECT * FROM Dumb_Table;
CREATE VIRTUAL TABLE "Street_Tree_List_fts" USING FTS4 ("qAddress", "qCaretaker", "qSpecies", content="Street_Tree_List");
CREATE VIRTUAL TABLE r USING rtree(a, b, c);
'''
conn = utils.sqlite3.connect(':memory:')
"""
conn = utils.sqlite3.connect(":memory:")
conn.executescript(sql)
assert None is utils.detect_fts(conn, 'Dumb_Table')
assert None is utils.detect_fts(conn, 'Test_View')
assert None is utils.detect_fts(conn, 'r')
assert 'Street_Tree_List_fts' == utils.detect_fts(conn, 'Street_Tree_List')
assert None is utils.detect_fts(conn, "Dumb_Table")
assert None is utils.detect_fts(conn, "Test_View")
assert None is utils.detect_fts(conn, "r")
assert "Street_Tree_List_fts" == utils.detect_fts(conn, "Street_Tree_List")
@pytest.mark.parametrize('url,expected', [
('http://www.google.com/', True),
('https://example.com/', True),
('www.google.com', False),
('http://www.google.com/ is a search engine', False),
])
@pytest.mark.parametrize(
"url,expected",
[
("http://www.google.com/", True),
("https://example.com/", True),
("www.google.com", False),
("http://www.google.com/ is a search engine", False),
],
)
def test_is_url(url, expected):
assert expected == utils.is_url(url)
@pytest.mark.parametrize('s,expected', [
('simple', 'simple'),
('MixedCase', 'MixedCase'),
('-no-leading-hyphens', 'no-leading-hyphens-65bea6'),
('_no-leading-underscores', 'no-leading-underscores-b921bc'),
('no spaces', 'no-spaces-7088d7'),
('-', '336d5e'),
('no $ characters', 'no--characters-59e024'),
])
@pytest.mark.parametrize(
"s,expected",
[
("simple", "simple"),
("MixedCase", "MixedCase"),
("-no-leading-hyphens", "no-leading-hyphens-65bea6"),
("_no-leading-underscores", "no-leading-underscores-b921bc"),
("no spaces", "no-spaces-7088d7"),
("-", "336d5e"),
("no $ characters", "no--characters-59e024"),
],
)
def test_to_css_class(s, expected):
assert expected == utils.to_css_class(s)
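
The expected values here encode the whole algorithm: strings that are already valid CSS class names pass through unchanged; anything else is sanitized and suffixed with the first six hex digits of its md5. One implementation consistent with every case above (a sketch, not necessarily the exact utils.to_css_class body):

    import hashlib
    import re


    def to_css_class_sketch(s):
        # Already a valid CSS class name: return unchanged
        if re.match(r"^[a-zA-Z]+[a-zA-Z0-9-]*$", s):
            return s
        # Otherwise keep it unique: md5 suffix of the original string
        md5_suffix = hashlib.md5(s.encode("utf8")).hexdigest()[:6]
        # Hyphenate whitespace, drop invalid characters, trim leading - and _
        s = "-".join(s.split())
        s = re.sub(r"[^a-zA-Z0-9_-]", "", s)
        s = s.lstrip("-_")
        return "{}-{}".format(s, md5_suffix) if s else md5_suffix

Worked through the trickiest case: "no $ characters" becomes "no-$-characters" after hyphenation, then "no--characters" once the "$" is dropped, giving "no--characters-59e024".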
@@ -208,11 +218,11 @@ def test_to_css_class(s, expected):
def test_temporary_docker_directory_uses_hard_link():
with tempfile.TemporaryDirectory() as td:
os.chdir(td)
open('hello', 'w').write('world')
open("hello", "w").write("world")
# Default usage of this should use symlink
with utils.temporary_docker_directory(
files=['hello'],
name='t',
files=["hello"],
name="t",
metadata=None,
extra_options=None,
branch=None,
@@ -223,23 +233,23 @@ def test_temporary_docker_directory_uses_hard_link():
spatialite=False,
version_note=None,
) as temp_docker:
hello = os.path.join(temp_docker, 'hello')
assert 'world' == open(hello).read()
hello = os.path.join(temp_docker, "hello")
assert "world" == open(hello).read()
# It should be a hard link
assert 2 == os.stat(hello).st_nlink
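
The st_nlink check is what distinguishes the two strategies: a hard link gives the underlying inode a link count of 2, while a copy leaves each file at 1. A quick standalone illustration (assumes a filesystem with hard-link support):

    import os
    import tempfile

    with tempfile.TemporaryDirectory() as td:
        src = os.path.join(td, "original")
        open(src, "w").write("world")
        assert 1 == os.stat(src).st_nlink
        os.link(src, os.path.join(td, "linked"))
        # Both directory entries now point at the same inode
        assert 2 == os.stat(src).st_nlink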
@patch('os.link')
@patch("os.link")
def test_temporary_docker_directory_uses_copy_if_hard_link_fails(mock_link):
# Copy instead if os.link raises OSError (normally due to different device)
mock_link.side_effect = OSError
with tempfile.TemporaryDirectory() as td:
os.chdir(td)
open('hello', 'w').write('world')
open("hello", "w").write("world")
# Default usage of this should use symlink
with utils.temporary_docker_directory(
files=['hello'],
name='t',
files=["hello"],
name="t",
metadata=None,
extra_options=None,
branch=None,
@@ -250,49 +260,53 @@ def test_temporary_docker_directory_uses_copy_if_hard_link_fails(mock_link):
spatialite=False,
version_note=None,
) as temp_docker:
hello = os.path.join(temp_docker, 'hello')
assert 'world' == open(hello).read()
hello = os.path.join(temp_docker, "hello")
assert "world" == open(hello).read()
# It should be a copy, not a hard link
assert 1 == os.stat(hello).st_nlink
def test_temporary_docker_directory_quotes_args():
with tempfile.TemporaryDirectory() as td:
with tempfile.TemporaryDirectory() as td:
os.chdir(td)
open('hello', 'w').write('world')
open("hello", "w").write("world")
with utils.temporary_docker_directory(
files=['hello'],
name='t',
files=["hello"],
name="t",
metadata=None,
extra_options='--$HOME',
extra_options="--$HOME",
branch=None,
template_dir=None,
plugins_dir=None,
static=[],
install=[],
spatialite=False,
version_note='$PWD',
version_note="$PWD",
) as temp_docker:
df = os.path.join(temp_docker, 'Dockerfile')
df = os.path.join(temp_docker, "Dockerfile")
df_contents = open(df).read()
assert "'$PWD'" in df_contents
assert "'--$HOME'" in df_contents
def test_compound_keys_after_sql():
assert '((a > :p0))' == utils.compound_keys_after_sql(['a'])
assert '''
assert "((a > :p0))" == utils.compound_keys_after_sql(["a"])
assert """
((a > :p0)
or
(a = :p0 and b > :p1))
'''.strip() == utils.compound_keys_after_sql(['a', 'b'])
assert '''
""".strip() == utils.compound_keys_after_sql(
["a", "b"]
)
assert """
((a > :p0)
or
(a = :p0 and b > :p1)
or
(a = :p0 and b = :p1 and c > :p2))
'''.strip() == utils.compound_keys_after_sql(['a', 'b', 'c'])
""".strip() == utils.compound_keys_after_sql(
["a", "b", "c"]
)
async def table_exists(table):
@@ -314,7 +328,7 @@ async def test_resolve_table_and_format(
table_and_format, expected_table, expected_format
):
actual_table, actual_format = await utils.resolve_table_and_format(
table_and_format, table_exists, ['json']
table_and_format, table_exists, ["json"]
)
assert expected_table == actual_table
assert expected_format == actual_format
@@ -322,9 +336,11 @@ def test_table_columns():
def test_table_columns():
conn = sqlite3.connect(":memory:")
conn.executescript("""
conn.executescript(
"""
create table places (id integer primary key, name text, bob integer)
""")
"""
)
assert ["id", "name", "bob"] == utils.table_columns(conn, "places")
@@ -347,10 +363,7 @@ def test_table_columns():
],
)
def test_path_with_format(path, format, extra_qs, expected):
request = Request(
path.encode('utf8'),
{}, '1.1', 'GET', None
)
request = Request(path.encode("utf8"), {}, "1.1", "GET", None)
actual = utils.path_with_format(request, format, extra_qs)
assert expected == actual
@@ -358,13 +371,13 @@ def test_path_with_format(path, format, extra_qs, expected):
@pytest.mark.parametrize(
"bytes,expected",
[
(120, '120 bytes'),
(1024, '1.0 KB'),
(1024 * 1024, '1.0 MB'),
(1024 * 1024 * 1024, '1.0 GB'),
(1024 * 1024 * 1024 * 1.3, '1.3 GB'),
(1024 * 1024 * 1024 * 1024, '1.0 TB'),
]
(120, "120 bytes"),
(1024, "1.0 KB"),
(1024 * 1024, "1.0 MB"),
(1024 * 1024 * 1024, "1.0 GB"),
(1024 * 1024 * 1024 * 1.3, "1.3 GB"),
(1024 * 1024 * 1024 * 1024, "1.0 TB"),
],
)
def test_format_bytes(bytes, expected):
assert expected == utils.format_bytes(bytes)
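
The parametrized cases fully determine the rounding behavior: divide by 1024 until the value drops below 1024, then render with one decimal place (plain bytes stay integral). A sketch consistent with those cases (not necessarily the exact utils.format_bytes body):

    def format_bytes_sketch(bytes):
        current = float(bytes)
        for unit in ("bytes", "KB", "MB", "GB", "TB"):
            if current < 1024:
                break
            current = current / 1024
        if unit == "bytes":
            return "{} bytes".format(int(current))
        return "{:.1f} {}".format(current, unit)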