diff --git a/datasette/templates/_rows_and_columns.html b/datasette/templates/_rows_and_columns.html
index c7a72253..06ac9bd5 100644
--- a/datasette/templates/_rows_and_columns.html
+++ b/datasette/templates/_rows_and_columns.html
@@ -7,9 +7,9 @@
{{ column.name }}
{% else %}
{% if column.name == sort %}
- {{ column.name }} ▼
+ {{ column.name }} ▼
{% else %}
- {{ column.name }}{% if column.name == sort_desc %} ▲{% endif %}
+ {{ column.name }}{% if column.name == sort_desc %} ▲{% endif %}
{% endif %}
{% endif %}
diff --git a/datasette/templates/table.html b/datasette/templates/table.html
index eda37bc7..92684a71 100644
--- a/datasette/templates/table.html
+++ b/datasette/templates/table.html
@@ -111,7 +111,7 @@
{{ facet_info.name }}
{% if facet_hideable(facet_info.name) %}
- ✖
+ ✖
{% endif %}
diff --git a/datasette/utils.py b/datasette/utils.py
index 0281a177..61dbe910 100644
--- a/datasette/utils.py
+++ b/datasette/utils.py
@@ -170,13 +170,13 @@ def validate_sql_select(sql):
raise InvalidSql(msg)
-def path_with_added_args(request, args, path=None):
- path = path or request.path
+def path_with_added_args(qs, args, path=None):
+ path = path or qs.path
if isinstance(args, dict):
args = args.items()
args_to_remove = {k for k, v in args if v is None}
current = []
- for key, value in urllib.parse.parse_qsl(request.query_string):
+ for key, value in urllib.parse.parse_qsl(str(qs)):
if key not in args_to_remove:
current.append((key, value))
current.extend([
@@ -190,9 +190,9 @@ def path_with_added_args(request, args, path=None):
return path + query_string
-def path_with_removed_args(request, args, path=None):
+def path_with_removed_args(qs, args, path=None):
# args can be a dict or a set
- path = path or request.path
+ path = path or qs.path
current = []
if isinstance(args, set):
def should_remove(key, value):
@@ -201,7 +201,7 @@ def path_with_removed_args(request, args, path=None):
# Must match key AND value
def should_remove(key, value):
return args.get(key) == value
- for key, value in urllib.parse.parse_qsl(request.query_string):
+ for key, value in urllib.parse.parse_qsl(str(qs)):
if not should_remove(key, value):
current.append((key, value))
query_string = urllib.parse.urlencode(current)
@@ -210,13 +210,13 @@ def path_with_removed_args(request, args, path=None):
return path + query_string
-def path_with_replaced_args(request, args, path=None):
- path = path or request.path
+def path_with_replaced_args(qs, args, path=None):
+ path = path or qs.path
if isinstance(args, dict):
args = args.items()
keys_to_replace = {p[0] for p in args}
current = []
- for key, value in urllib.parse.parse_qsl(request.query_string):
+ for key, value in urllib.parse.parse_qsl(str(qs)):
if key not in keys_to_replace:
current.append((key, value))
current.extend([p for p in args if p[1] is not None])
@@ -783,23 +783,23 @@ def resolve_table_and_format(table_and_format, table_exists):
return table_and_format, None
-def path_with_format(request, format, extra_qs=None):
- qs = extra_qs or {}
- path = request.path
- if "." in request.path:
- qs["_format"] = format
+def path_with_format(qs, format, extra_qs=None):
+ new_qs = extra_qs or {}
+ path = qs.path
+ if "." in qs.path:
+ new_qs["_format"] = format
else:
path = "{}.{}".format(path, format)
- if qs:
- extra = urllib.parse.urlencode(sorted(qs.items()))
- if request.query_string:
+ if new_qs:
+ extra = urllib.parse.urlencode(sorted(new_qs.items()))
+ if qs.data:
path = "{}?{}&{}".format(
- path, request.query_string, extra
+ path, str(qs), extra
)
else:
path = "{}?{}".format(path, extra)
- elif request.query_string:
- path = "{}?{}".format(path, request.query_string)
+ elif qs.data:
+ path = "{}?{}".format(path, str(qs))
return path
@@ -831,11 +831,12 @@ class ValueAsBooleanError(ValueError):
class Querystring:
- def __init__(self, qs=None):
+ def __init__(self, path, qs=None):
+ self.path = path
self.prev = None
self.data = []
if qs:
- self.data = urllib.parse.parse_qsl(qs)
+ self.data = urllib.parse.parse_qsl(qs, keep_blank_values=True)
def first(self, key):
for item in self.data:
@@ -843,6 +844,12 @@ class Querystring:
return item[1]
raise KeyError
+ def first_or_none(self, key):
+ try:
+ return self.first(key)
+ except KeyError:
+ return None
+
def last(self, key):
for item in reversed(self.data):
if item[0] == key:
@@ -854,10 +861,11 @@ class Querystring:
for item in self.data:
if item[0] == key:
result.append(item[1])
- if not result:
- raise KeyError
return result
+ def first_dict(self):
+ return {k: v[0] for k, v in self.data}
+
def append(self, key, value):
self.data.append((key, value))
diff --git a/datasette/views/base.py b/datasette/views/base.py
index 055d174a..36ef0540 100644
--- a/datasette/views/base.py
+++ b/datasette/views/base.py
@@ -16,6 +16,7 @@ from datasette.utils import (
CustomJSONEncoder,
InterruptedError,
InvalidSql,
+ Querystring,
path_from_row_pks,
path_with_added_args,
path_with_format,
@@ -85,9 +86,9 @@ class BaseView(RenderMixin):
r.headers["Access-Control-Allow-Origin"] = "*"
return r
- def redirect(self, request, path, forward_querystring=True):
- if request.query_string and "?" not in path and forward_querystring:
- path = "{}?{}".format(path, request.query_string)
+ def redirect(self, qs, path, forward_querystring=True):
+ if qs.data and "?" not in qs.path and forward_querystring:
+ path = "{}?{}".format(path, str(qs))
r = response.redirect(path)
r.headers["Link"] = "<{}>; rel=preload".format(path)
if self.ds.cors:
@@ -144,15 +145,15 @@ class BaseView(RenderMixin):
async def get(self, request, db_name, **kwargs):
name, hash, should_redirect = self.resolve_db_name(db_name, **kwargs)
+ qs = Querystring(request.path, request.query_string)
if should_redirect:
- return self.redirect(request, should_redirect)
+ return self.redirect(qs, should_redirect)
+ return await self.view_get(qs, name, hash, **kwargs)
- return await self.view_get(request, name, hash, **kwargs)
-
- async def as_csv(self, request, name, hash, **kwargs):
+ async def as_csv(self, qs, name, hash, **kwargs):
try:
response_or_template_contexts = await self.data(
- request, name, hash, **kwargs
+ qs, name, hash, **kwargs
)
if isinstance(response_or_template_contexts, response.HTTPResponse):
return response_or_template_contexts
@@ -198,7 +199,7 @@ class BaseView(RenderMixin):
content_type = "text/plain; charset=utf-8"
headers = {}
- if request.args.get("_dl", None):
+ if qs.first_or_none("_dl"):
content_type = "text/csv; charset=utf-8"
disposition = 'attachment; filename="{}.csv"'.format(
kwargs.get('table', name)
@@ -211,9 +212,9 @@ class BaseView(RenderMixin):
content_type=content_type
)
- async def view_get(self, request, name, hash, **kwargs):
+ async def view_get(self, qs, name, hash, **kwargs):
# If ?_format= is provided, use that as the format
- _format = request.args.get("_format", None)
+ _format = qs.first_or_none("_format")
if not _format:
_format = (kwargs.pop("as_format", None) or "").lstrip(".")
if "table_and_format" in kwargs:
@@ -228,7 +229,7 @@ class BaseView(RenderMixin):
del kwargs["table_and_format"]
if _format == "csv":
- return await self.as_csv(request, name, hash, **kwargs)
+ return await self.as_csv(qs, name, hash, **kwargs)
if _format is None:
# HTML views default to expanding all forign key labels
@@ -240,7 +241,7 @@ class BaseView(RenderMixin):
templates = []
try:
response_or_template_contexts = await self.data(
- request, name, hash, **kwargs
+ qs, name, hash, **kwargs
)
if isinstance(response_or_template_contexts, response.HTTPResponse):
return response_or_template_contexts
@@ -272,26 +273,25 @@ class BaseView(RenderMixin):
# Special case for .jsono extension - redirect to _shape=objects
if _format == "jsono":
return self.redirect(
- request,
+ qs,
path_with_added_args(
- request,
+ qs,
{"_shape": "objects"},
- path=request.path.rsplit(".jsono", 1)[0] + ".json",
+ path=qs.path.rsplit(".jsono", 1)[0] + ".json",
),
forward_querystring=False,
)
# Handle the _json= parameter which may modify data["rows"]
json_cols = []
- if "_json" in request.args:
- json_cols = request.args["_json"]
+ json_cols = qs.getlist("_json")
if json_cols and "rows" in data and "columns" in data:
data["rows"] = convert_specific_columns_to_json(
data["rows"], data["columns"], json_cols,
)
# Deal with the _shape option
- shape = request.args.get("_shape", "arrays")
+ shape = qs.first_or_none("_shape") or "arrays"
if shape == "arrayfirst":
data = [row[0] for row in data["rows"]]
elif shape in ("objects", "object", "array"):
@@ -353,11 +353,11 @@ class BaseView(RenderMixin):
**data,
**extras,
**{
- "url_json": path_with_format(request, "json"),
- "url_csv": path_with_format(request, "csv", {
+ "url_json": path_with_format(qs, "json"),
+ "url_csv": path_with_format(qs, "csv", {
"_size": "max"
}),
- "url_csv_dl": path_with_format(request, "csv", {
+ "url_csv_dl": path_with_format(qs, "csv", {
"_dl": "1",
"_size": "max"
}),
@@ -372,7 +372,7 @@ class BaseView(RenderMixin):
r.status = status_code
# Set far-future cache expiry
if self.ds.cache_headers:
- ttl = request.args.get("_ttl", None)
+ ttl = qs.first_or_none("_ttl")
if ttl is None or not ttl.isdigit():
ttl = self.ds.config["default_cache_ttl"]
else:
@@ -386,9 +386,9 @@ class BaseView(RenderMixin):
return r
async def custom_sql(
- self, request, name, hash, sql, editable=True, canned_query=None
+ self, qs, name, hash, sql, editable=True, canned_query=None
):
- params = request.raw_args
+ params = qs.first_dict()
if "sql" in params:
params.pop("sql")
if "_shape" in params:
diff --git a/datasette/views/database.py b/datasette/views/database.py
index 2f3f41d3..b52b20d6 100644
--- a/datasette/views/database.py
+++ b/datasette/views/database.py
@@ -9,13 +9,13 @@ from .base import BaseView, DatasetteError
class DatabaseView(BaseView):
- async def data(self, request, name, hash, default_labels=False):
- if request.args.get("sql"):
+ async def data(self, qs, name, hash, default_labels=False):
+ if qs.first_or_none("sql"):
if not self.ds.config["allow_sql"]:
raise DatasetteError("sql= is not allowed", status=400)
- sql = request.raw_args.pop("sql")
+ sql = qs.first("sql")
validate_sql_select(sql)
- return await self.custom_sql(request, name, hash, sql)
+ return await self.custom_sql(qs, name, hash, sql)
info = self.ds.inspect()[name]
metadata = self.ds.metadata.get("databases", {}).get(name, {})
@@ -34,7 +34,7 @@ class DatabaseView(BaseView):
"config": self.ds.config,
}, {
"database_hash": hash,
- "show_hidden": request.args.get("_show_hidden"),
+ "show_hidden": qs.first_or_none("_show_hidden"),
"editable": True,
"metadata": metadata,
}, (
@@ -44,7 +44,7 @@ class DatabaseView(BaseView):
class DatabaseDownload(BaseView):
- async def view_get(self, request, name, hash, **kwargs):
+ async def view_get(self, qs, name, hash, **kwargs):
if not self.ds.config["allow_download"]:
raise DatasetteError("Database download is forbidden", status=403)
filepath = self.ds.inspect()[name]["file"]
diff --git a/datasette/views/table.py b/datasette/views/table.py
index d5a2cfb0..c139cf30 100644
--- a/datasette/views/table.py
+++ b/datasette/views/table.py
@@ -1,4 +1,3 @@
-from collections import namedtuple
import sqlite3
import urllib
@@ -220,11 +219,11 @@ class RowTableShared(BaseView):
class TableView(RowTableShared):
- async def data(self, request, name, hash, table, default_labels=False):
+ async def data(self, qs, name, hash, table, default_labels=False):
canned_query = self.ds.get_canned_query(name, table)
if canned_query is not None:
return await self.custom_sql(
- request,
+ qs,
name,
hash,
canned_query["sql"],
@@ -255,7 +254,7 @@ class TableView(RowTableShared):
# We roll our own query_string decoder because by default Sanic
# drops anything with an empty value e.g. ?name__exact=
args = RequestParameters(
- urllib.parse.parse_qs(request.query_string, keep_blank_values=True)
+ urllib.parse.parse_qs(str(qs), keep_blank_values=True)
)
# Special args start with _ and do not contain a __
@@ -275,17 +274,17 @@ class TableView(RowTableShared):
redirect_params = filters_should_redirect(special_args)
if redirect_params:
return self.redirect(
- request,
- path_with_added_args(request, redirect_params),
+ qs,
+ path_with_added_args(qs, redirect_params),
forward_querystring=False,
)
# Spot ?_sort_by_desc and redirect to _sort_desc=(_sort)
if "_sort_by_desc" in special_args:
return self.redirect(
- request,
+ qs,
path_with_added_args(
- request,
+ qs,
{
"_sort_desc": special_args.get("_sort"),
"_sort_by_desc": None,
@@ -458,11 +457,11 @@ class TableView(RowTableShared):
table_name=escape_sqlite(table),
where=where_clause,
)
- return await self.custom_sql(request, name, hash, sql, editable=True)
+ return await self.custom_sql(qs, name, hash, sql, editable=True)
extra_args = {}
# Handle ?_size=500
- page_size = request.raw_args.get("_size")
+ page_size = qs.first_or_none("_size")
if page_size:
if page_size == "max":
page_size = self.max_returned_rows
@@ -492,8 +491,8 @@ class TableView(RowTableShared):
offset=offset,
)
- if request.raw_args.get("_timelimit"):
- extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"])
+ if qs.first_or_none("_timelimit"):
+ extra_args["custom_time_limit"] = int(qs.first("_timelimit"))
results = await self.ds.execute(
name, sql, params, truncate=True, **extra_args
@@ -503,10 +502,10 @@ class TableView(RowTableShared):
facet_size = self.ds.config["default_facet_size"]
metadata_facets = table_metadata.get("facets", [])
facets = metadata_facets[:]
- if request.args.get("_facet") and not self.ds.config["allow_facet"]:
+ if qs.first_or_none("_facet") and not self.ds.config["allow_facet"]:
raise DatasetteError("_facet= is not allowed", status=400)
try:
- facets.extend(request.args["_facet"])
+ facets.extend(qs.getlist("_facet"))
except KeyError:
pass
facet_results = {}
@@ -544,11 +543,11 @@ class TableView(RowTableShared):
selected = str(other_args.get(column)) == str(row["value"])
if selected:
toggle_path = path_with_removed_args(
- request, {column: str(row["value"])}
+ qs, {column: str(row["value"])}
)
else:
toggle_path = path_with_added_args(
- request, {column: row["value"]}
+ qs, {column: row["value"]}
)
facet_results_values.append({
"value": row["value"],
@@ -558,7 +557,7 @@ class TableView(RowTableShared):
),
"count": row["count"],
"toggle_url": urllib.parse.urljoin(
- request.url, toggle_path
+ qs.path, toggle_path
),
"selected": selected,
})
@@ -581,8 +580,8 @@ class TableView(RowTableShared):
except ValueError:
all_labels = default_labels
# Check for explicit _label=
- if "_label" in request.args:
- columns_to_expand = request.args["_label"]
+ if qs.first_or_none("_label"):
+ columns_to_expand = qs.first("_label")
if columns_to_expand is None and all_labels:
# expand all columns with foreign keys
columns_to_expand = [
@@ -644,7 +643,7 @@ class TableView(RowTableShared):
else:
added_args = {"_next": next_value}
next_url = urllib.parse.urljoin(
- request.url, path_with_replaced_args(request, added_args)
+ qs.path, path_with_replaced_args(qs, added_args)
)
rows = rows[:page_size]
@@ -694,7 +693,7 @@ class TableView(RowTableShared):
suggested_facets.append({
'name': facet_column,
'toggle_url': path_with_added_args(
- request, {'_facet': facet_column}
+ qs, {'_facet': facet_column}
),
})
except InterruptedError:
@@ -744,7 +743,7 @@ class TableView(RowTableShared):
"is_sortable": any(c["sortable"] for c in display_columns),
"path_with_replaced_args": path_with_replaced_args,
"path_with_removed_args": path_with_removed_args,
- "request": request,
+ "qs": qs,
"sort": sort,
"sort_desc": sort_desc,
"disable_sort": is_view,
@@ -788,7 +787,7 @@ class TableView(RowTableShared):
class RowView(RowTableShared):
- async def data(self, request, name, hash, table, pk_path, default_labels=False):
+ async def data(self, qs, name, hash, table, pk_path, default_labels=False):
pk_values = urlsafe_components(pk_path)
info = self.ds.inspect()[name]
table_info = info["tables"].get(table) or {}
@@ -854,7 +853,7 @@ class RowView(RowTableShared):
"units": self.table_metadata(name, table).get("units", {}),
}
- if "foreign_key_tables" in (request.raw_args.get("_extras") or "").split(","):
+ if "foreign_key_tables" in (qs.first_or_none("_extras") or "").split(","):
data["foreign_key_tables"] = await self.foreign_key_tables(
name, table, pk_values
)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index d12bf927..85ee2d84 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -6,7 +6,6 @@ from datasette import utils
import json
import os
import pytest
-from sanic.request import Request
import sqlite3
import tempfile
from unittest.mock import patch
@@ -40,11 +39,12 @@ def test_urlsafe_components(path, expected):
), '/?_facet=state&_facet=city&_facet=planet_int'),
])
def test_path_with_added_args(path, added_args, expected):
- request = Request(
- path.encode('utf8'),
- {}, '1.1', 'GET', None
- )
- actual = utils.path_with_added_args(request, added_args)
+ try:
+ path, qsbits = path.split('?', 1)
+ except ValueError:
+ qsbits = ''
+ qs = utils.Querystring(path, qsbits)
+ actual = utils.path_with_added_args(qs, added_args)
assert expected == actual
@@ -54,11 +54,12 @@ def test_path_with_added_args(path, added_args, expected):
('/foo?bar=1&bar=2&bar=3', {'bar': '2'}, '/foo?bar=1&bar=3'),
])
def test_path_with_removed_args(path, args, expected):
- request = Request(
- path.encode('utf8'),
- {}, '1.1', 'GET', None
- )
- actual = utils.path_with_removed_args(request, args)
+ try:
+ path, qsbits = path.split('?', 1)
+ except ValueError:
+ qsbits = ''
+ qs = utils.Querystring(path, qsbits)
+ actual = utils.path_with_removed_args(qs, args)
assert expected == actual
@@ -67,11 +68,12 @@ def test_path_with_removed_args(path, args, expected):
('/foo?bar=1&baz=2', {'bar': None}, '/foo?baz=2'),
])
def test_path_with_replaced_args(path, args, expected):
- request = Request(
- path.encode('utf8'),
- {}, '1.1', 'GET', None
- )
- actual = utils.path_with_replaced_args(request, args)
+ try:
+ path, qsbits = path.split('?', 1)
+ except ValueError:
+ qsbits = ''
+ qs = utils.Querystring(path, qsbits)
+ actual = utils.path_with_replaced_args(qs, args)
assert expected == actual
@@ -344,9 +346,10 @@ def test_resolve_table_and_format(
],
)
def test_path_with_format(path, format, extra_qs, expected):
- request = Request(
- path.encode('utf8'),
- {}, '1.1', 'GET', None
- )
- actual = utils.path_with_format(request, format, extra_qs)
+ try:
+ path, qsbits = path.split('?', 1)
+ except ValueError:
+ qsbits = ''
+ qs = utils.Querystring(path, qsbits)
+ actual = utils.path_with_format(qs, format, extra_qs)
assert expected == actual