Refactored everything to use new qs instead of request

I don't like this, I think I will go back to the request object but with my
own custom request object that has a request.qs property.
This commit is contained in:
Simon Willison 2018-06-16 21:45:30 -07:00
commit 68302b6ca2
No known key found for this signature in database
GPG key ID: 17E2DEA2588B7F52
7 changed files with 113 additions and 103 deletions

View file

@ -7,9 +7,9 @@
{{ column.name }}
{% else %}
{% if column.name == sort %}
<a href="{{ path_with_replaced_args(request, {'_sort_desc': column.name, '_sort': None, '_next': None}) }}" rel="nofollow">{{ column.name }}&nbsp;</a>
<a href="{{ path_with_replaced_args(qs, {'_sort_desc': column.name, '_sort': None, '_next': None}) }}" rel="nofollow">{{ column.name }}&nbsp;</a>
{% else %}
<a href="{{ path_with_replaced_args(request, {'_sort': column.name, '_sort_desc': None, '_next': None}) }}" rel="nofollow">{{ column.name }}{% if column.name == sort_desc %}&nbsp;▲{% endif %}</a>
<a href="{{ path_with_replaced_args(qs, {'_sort': column.name, '_sort_desc': None, '_next': None}) }}" rel="nofollow">{{ column.name }}{% if column.name == sort_desc %}&nbsp;▲{% endif %}</a>
{% endif %}
{% endif %}
</th>

View file

@ -111,7 +111,7 @@
<p class="facet-info-name">
<strong>{{ facet_info.name }}</strong>
{% if facet_hideable(facet_info.name) %}
<a href="{{ path_with_removed_args(request, {'_facet': facet_info['name']}) }}" class="cross">&#x2716;</a>
<a href="{{ path_with_removed_args(qs, {'_facet': facet_info['name']}) }}" class="cross">&#x2716;</a>
{% endif %}
</p>
<ul>

View file

@ -170,13 +170,13 @@ def validate_sql_select(sql):
raise InvalidSql(msg)
def path_with_added_args(request, args, path=None):
path = path or request.path
def path_with_added_args(qs, args, path=None):
path = path or qs.path
if isinstance(args, dict):
args = args.items()
args_to_remove = {k for k, v in args if v is None}
current = []
for key, value in urllib.parse.parse_qsl(request.query_string):
for key, value in urllib.parse.parse_qsl(str(qs)):
if key not in args_to_remove:
current.append((key, value))
current.extend([
@ -190,9 +190,9 @@ def path_with_added_args(request, args, path=None):
return path + query_string
def path_with_removed_args(request, args, path=None):
def path_with_removed_args(qs, args, path=None):
# args can be a dict or a set
path = path or request.path
path = path or qs.path
current = []
if isinstance(args, set):
def should_remove(key, value):
@ -201,7 +201,7 @@ def path_with_removed_args(request, args, path=None):
# Must match key AND value
def should_remove(key, value):
return args.get(key) == value
for key, value in urllib.parse.parse_qsl(request.query_string):
for key, value in urllib.parse.parse_qsl(str(qs)):
if not should_remove(key, value):
current.append((key, value))
query_string = urllib.parse.urlencode(current)
@ -210,13 +210,13 @@ def path_with_removed_args(request, args, path=None):
return path + query_string
def path_with_replaced_args(request, args, path=None):
path = path or request.path
def path_with_replaced_args(qs, args, path=None):
path = path or qs.path
if isinstance(args, dict):
args = args.items()
keys_to_replace = {p[0] for p in args}
current = []
for key, value in urllib.parse.parse_qsl(request.query_string):
for key, value in urllib.parse.parse_qsl(str(qs)):
if key not in keys_to_replace:
current.append((key, value))
current.extend([p for p in args if p[1] is not None])
@ -783,23 +783,23 @@ def resolve_table_and_format(table_and_format, table_exists):
return table_and_format, None
def path_with_format(request, format, extra_qs=None):
qs = extra_qs or {}
path = request.path
if "." in request.path:
qs["_format"] = format
def path_with_format(qs, format, extra_qs=None):
new_qs = extra_qs or {}
path = qs.path
if "." in qs.path:
new_qs["_format"] = format
else:
path = "{}.{}".format(path, format)
if qs:
extra = urllib.parse.urlencode(sorted(qs.items()))
if request.query_string:
if new_qs:
extra = urllib.parse.urlencode(sorted(new_qs.items()))
if qs.data:
path = "{}?{}&{}".format(
path, request.query_string, extra
path, str(qs), extra
)
else:
path = "{}?{}".format(path, extra)
elif request.query_string:
path = "{}?{}".format(path, request.query_string)
elif qs.data:
path = "{}?{}".format(path, str(qs))
return path
@ -831,11 +831,12 @@ class ValueAsBooleanError(ValueError):
class Querystring:
def __init__(self, qs=None):
def __init__(self, path, qs=None):
self.path = path
self.prev = None
self.data = []
if qs:
self.data = urllib.parse.parse_qsl(qs)
self.data = urllib.parse.parse_qsl(qs, keep_blank_values=True)
def first(self, key):
for item in self.data:
@ -843,6 +844,12 @@ class Querystring:
return item[1]
raise KeyError
def first_or_none(self, key):
try:
return self.first(key)
except KeyError:
return None
def last(self, key):
for item in reversed(self.data):
if item[0] == key:
@ -854,10 +861,11 @@ class Querystring:
for item in self.data:
if item[0] == key:
result.append(item[1])
if not result:
raise KeyError
return result
def first_dict(self):
return {k: v[0] for k, v in self.data}
def append(self, key, value):
self.data.append((key, value))

View file

@ -16,6 +16,7 @@ from datasette.utils import (
CustomJSONEncoder,
InterruptedError,
InvalidSql,
Querystring,
path_from_row_pks,
path_with_added_args,
path_with_format,
@ -85,9 +86,9 @@ class BaseView(RenderMixin):
r.headers["Access-Control-Allow-Origin"] = "*"
return r
def redirect(self, request, path, forward_querystring=True):
if request.query_string and "?" not in path and forward_querystring:
path = "{}?{}".format(path, request.query_string)
def redirect(self, qs, path, forward_querystring=True):
if qs.data and "?" not in qs.path and forward_querystring:
path = "{}?{}".format(path, str(qs))
r = response.redirect(path)
r.headers["Link"] = "<{}>; rel=preload".format(path)
if self.ds.cors:
@ -144,15 +145,15 @@ class BaseView(RenderMixin):
async def get(self, request, db_name, **kwargs):
name, hash, should_redirect = self.resolve_db_name(db_name, **kwargs)
qs = Querystring(request.path, request.query_string)
if should_redirect:
return self.redirect(request, should_redirect)
return self.redirect(qs, should_redirect)
return await self.view_get(qs, name, hash, **kwargs)
return await self.view_get(request, name, hash, **kwargs)
async def as_csv(self, request, name, hash, **kwargs):
async def as_csv(self, qs, name, hash, **kwargs):
try:
response_or_template_contexts = await self.data(
request, name, hash, **kwargs
qs, name, hash, **kwargs
)
if isinstance(response_or_template_contexts, response.HTTPResponse):
return response_or_template_contexts
@ -198,7 +199,7 @@ class BaseView(RenderMixin):
content_type = "text/plain; charset=utf-8"
headers = {}
if request.args.get("_dl", None):
if qs.first_or_none("_dl"):
content_type = "text/csv; charset=utf-8"
disposition = 'attachment; filename="{}.csv"'.format(
kwargs.get('table', name)
@ -211,9 +212,9 @@ class BaseView(RenderMixin):
content_type=content_type
)
async def view_get(self, request, name, hash, **kwargs):
async def view_get(self, qs, name, hash, **kwargs):
# If ?_format= is provided, use that as the format
_format = request.args.get("_format", None)
_format = qs.first_or_none("_format")
if not _format:
_format = (kwargs.pop("as_format", None) or "").lstrip(".")
if "table_and_format" in kwargs:
@ -228,7 +229,7 @@ class BaseView(RenderMixin):
del kwargs["table_and_format"]
if _format == "csv":
return await self.as_csv(request, name, hash, **kwargs)
return await self.as_csv(qs, name, hash, **kwargs)
if _format is None:
# HTML views default to expanding all forign key labels
@ -240,7 +241,7 @@ class BaseView(RenderMixin):
templates = []
try:
response_or_template_contexts = await self.data(
request, name, hash, **kwargs
qs, name, hash, **kwargs
)
if isinstance(response_or_template_contexts, response.HTTPResponse):
return response_or_template_contexts
@ -272,26 +273,25 @@ class BaseView(RenderMixin):
# Special case for .jsono extension - redirect to _shape=objects
if _format == "jsono":
return self.redirect(
request,
qs,
path_with_added_args(
request,
qs,
{"_shape": "objects"},
path=request.path.rsplit(".jsono", 1)[0] + ".json",
path=qs.path.rsplit(".jsono", 1)[0] + ".json",
),
forward_querystring=False,
)
# Handle the _json= parameter which may modify data["rows"]
json_cols = []
if "_json" in request.args:
json_cols = request.args["_json"]
json_cols = qs.getlist("_json")
if json_cols and "rows" in data and "columns" in data:
data["rows"] = convert_specific_columns_to_json(
data["rows"], data["columns"], json_cols,
)
# Deal with the _shape option
shape = request.args.get("_shape", "arrays")
shape = qs.first_or_none("_shape") or "arrays"
if shape == "arrayfirst":
data = [row[0] for row in data["rows"]]
elif shape in ("objects", "object", "array"):
@ -353,11 +353,11 @@ class BaseView(RenderMixin):
**data,
**extras,
**{
"url_json": path_with_format(request, "json"),
"url_csv": path_with_format(request, "csv", {
"url_json": path_with_format(qs, "json"),
"url_csv": path_with_format(qs, "csv", {
"_size": "max"
}),
"url_csv_dl": path_with_format(request, "csv", {
"url_csv_dl": path_with_format(qs, "csv", {
"_dl": "1",
"_size": "max"
}),
@ -372,7 +372,7 @@ class BaseView(RenderMixin):
r.status = status_code
# Set far-future cache expiry
if self.ds.cache_headers:
ttl = request.args.get("_ttl", None)
ttl = qs.first_or_none("_ttl")
if ttl is None or not ttl.isdigit():
ttl = self.ds.config["default_cache_ttl"]
else:
@ -386,9 +386,9 @@ class BaseView(RenderMixin):
return r
async def custom_sql(
self, request, name, hash, sql, editable=True, canned_query=None
self, qs, name, hash, sql, editable=True, canned_query=None
):
params = request.raw_args
params = qs.first_dict()
if "sql" in params:
params.pop("sql")
if "_shape" in params:

View file

@ -9,13 +9,13 @@ from .base import BaseView, DatasetteError
class DatabaseView(BaseView):
async def data(self, request, name, hash, default_labels=False):
if request.args.get("sql"):
async def data(self, qs, name, hash, default_labels=False):
if qs.first_or_none("sql"):
if not self.ds.config["allow_sql"]:
raise DatasetteError("sql= is not allowed", status=400)
sql = request.raw_args.pop("sql")
sql = qs.first("sql")
validate_sql_select(sql)
return await self.custom_sql(request, name, hash, sql)
return await self.custom_sql(qs, name, hash, sql)
info = self.ds.inspect()[name]
metadata = self.ds.metadata.get("databases", {}).get(name, {})
@ -34,7 +34,7 @@ class DatabaseView(BaseView):
"config": self.ds.config,
}, {
"database_hash": hash,
"show_hidden": request.args.get("_show_hidden"),
"show_hidden": qs.first_or_none("_show_hidden"),
"editable": True,
"metadata": metadata,
}, (
@ -44,7 +44,7 @@ class DatabaseView(BaseView):
class DatabaseDownload(BaseView):
async def view_get(self, request, name, hash, **kwargs):
async def view_get(self, qs, name, hash, **kwargs):
if not self.ds.config["allow_download"]:
raise DatasetteError("Database download is forbidden", status=403)
filepath = self.ds.inspect()[name]["file"]

View file

@ -1,4 +1,3 @@
from collections import namedtuple
import sqlite3
import urllib
@ -220,11 +219,11 @@ class RowTableShared(BaseView):
class TableView(RowTableShared):
async def data(self, request, name, hash, table, default_labels=False):
async def data(self, qs, name, hash, table, default_labels=False):
canned_query = self.ds.get_canned_query(name, table)
if canned_query is not None:
return await self.custom_sql(
request,
qs,
name,
hash,
canned_query["sql"],
@ -255,7 +254,7 @@ class TableView(RowTableShared):
# We roll our own query_string decoder because by default Sanic
# drops anything with an empty value e.g. ?name__exact=
args = RequestParameters(
urllib.parse.parse_qs(request.query_string, keep_blank_values=True)
urllib.parse.parse_qs(str(qs), keep_blank_values=True)
)
# Special args start with _ and do not contain a __
@ -275,17 +274,17 @@ class TableView(RowTableShared):
redirect_params = filters_should_redirect(special_args)
if redirect_params:
return self.redirect(
request,
path_with_added_args(request, redirect_params),
qs,
path_with_added_args(qs, redirect_params),
forward_querystring=False,
)
# Spot ?_sort_by_desc and redirect to _sort_desc=(_sort)
if "_sort_by_desc" in special_args:
return self.redirect(
request,
qs,
path_with_added_args(
request,
qs,
{
"_sort_desc": special_args.get("_sort"),
"_sort_by_desc": None,
@ -458,11 +457,11 @@ class TableView(RowTableShared):
table_name=escape_sqlite(table),
where=where_clause,
)
return await self.custom_sql(request, name, hash, sql, editable=True)
return await self.custom_sql(qs, name, hash, sql, editable=True)
extra_args = {}
# Handle ?_size=500
page_size = request.raw_args.get("_size")
page_size = qs.first_or_none("_size")
if page_size:
if page_size == "max":
page_size = self.max_returned_rows
@ -492,8 +491,8 @@ class TableView(RowTableShared):
offset=offset,
)
if request.raw_args.get("_timelimit"):
extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"])
if qs.first_or_none("_timelimit"):
extra_args["custom_time_limit"] = int(qs.first("_timelimit"))
results = await self.ds.execute(
name, sql, params, truncate=True, **extra_args
@ -503,10 +502,10 @@ class TableView(RowTableShared):
facet_size = self.ds.config["default_facet_size"]
metadata_facets = table_metadata.get("facets", [])
facets = metadata_facets[:]
if request.args.get("_facet") and not self.ds.config["allow_facet"]:
if qs.first_or_none("_facet") and not self.ds.config["allow_facet"]:
raise DatasetteError("_facet= is not allowed", status=400)
try:
facets.extend(request.args["_facet"])
facets.extend(qs.getlist("_facet"))
except KeyError:
pass
facet_results = {}
@ -544,11 +543,11 @@ class TableView(RowTableShared):
selected = str(other_args.get(column)) == str(row["value"])
if selected:
toggle_path = path_with_removed_args(
request, {column: str(row["value"])}
qs, {column: str(row["value"])}
)
else:
toggle_path = path_with_added_args(
request, {column: row["value"]}
qs, {column: row["value"]}
)
facet_results_values.append({
"value": row["value"],
@ -558,7 +557,7 @@ class TableView(RowTableShared):
),
"count": row["count"],
"toggle_url": urllib.parse.urljoin(
request.url, toggle_path
qs.path, toggle_path
),
"selected": selected,
})
@ -581,8 +580,8 @@ class TableView(RowTableShared):
except ValueError:
all_labels = default_labels
# Check for explicit _label=
if "_label" in request.args:
columns_to_expand = request.args["_label"]
if qs.first_or_none("_label"):
columns_to_expand = qs.first("_label")
if columns_to_expand is None and all_labels:
# expand all columns with foreign keys
columns_to_expand = [
@ -644,7 +643,7 @@ class TableView(RowTableShared):
else:
added_args = {"_next": next_value}
next_url = urllib.parse.urljoin(
request.url, path_with_replaced_args(request, added_args)
qs.path, path_with_replaced_args(qs, added_args)
)
rows = rows[:page_size]
@ -694,7 +693,7 @@ class TableView(RowTableShared):
suggested_facets.append({
'name': facet_column,
'toggle_url': path_with_added_args(
request, {'_facet': facet_column}
qs, {'_facet': facet_column}
),
})
except InterruptedError:
@ -744,7 +743,7 @@ class TableView(RowTableShared):
"is_sortable": any(c["sortable"] for c in display_columns),
"path_with_replaced_args": path_with_replaced_args,
"path_with_removed_args": path_with_removed_args,
"request": request,
"qs": qs,
"sort": sort,
"sort_desc": sort_desc,
"disable_sort": is_view,
@ -788,7 +787,7 @@ class TableView(RowTableShared):
class RowView(RowTableShared):
async def data(self, request, name, hash, table, pk_path, default_labels=False):
async def data(self, qs, name, hash, table, pk_path, default_labels=False):
pk_values = urlsafe_components(pk_path)
info = self.ds.inspect()[name]
table_info = info["tables"].get(table) or {}
@ -854,7 +853,7 @@ class RowView(RowTableShared):
"units": self.table_metadata(name, table).get("units", {}),
}
if "foreign_key_tables" in (request.raw_args.get("_extras") or "").split(","):
if "foreign_key_tables" in (qs.first_or_none("_extras") or "").split(","):
data["foreign_key_tables"] = await self.foreign_key_tables(
name, table, pk_values
)

View file

@ -6,7 +6,6 @@ from datasette import utils
import json
import os
import pytest
from sanic.request import Request
import sqlite3
import tempfile
from unittest.mock import patch
@ -40,11 +39,12 @@ def test_urlsafe_components(path, expected):
), '/?_facet=state&_facet=city&_facet=planet_int'),
])
def test_path_with_added_args(path, added_args, expected):
request = Request(
path.encode('utf8'),
{}, '1.1', 'GET', None
)
actual = utils.path_with_added_args(request, added_args)
try:
path, qsbits = path.split('?', 1)
except ValueError:
qsbits = ''
qs = utils.Querystring(path, qsbits)
actual = utils.path_with_added_args(qs, added_args)
assert expected == actual
@ -54,11 +54,12 @@ def test_path_with_added_args(path, added_args, expected):
('/foo?bar=1&bar=2&bar=3', {'bar': '2'}, '/foo?bar=1&bar=3'),
])
def test_path_with_removed_args(path, args, expected):
request = Request(
path.encode('utf8'),
{}, '1.1', 'GET', None
)
actual = utils.path_with_removed_args(request, args)
try:
path, qsbits = path.split('?', 1)
except ValueError:
qsbits = ''
qs = utils.Querystring(path, qsbits)
actual = utils.path_with_removed_args(qs, args)
assert expected == actual
@ -67,11 +68,12 @@ def test_path_with_removed_args(path, args, expected):
('/foo?bar=1&baz=2', {'bar': None}, '/foo?baz=2'),
])
def test_path_with_replaced_args(path, args, expected):
request = Request(
path.encode('utf8'),
{}, '1.1', 'GET', None
)
actual = utils.path_with_replaced_args(request, args)
try:
path, qsbits = path.split('?', 1)
except ValueError:
qsbits = ''
qs = utils.Querystring(path, qsbits)
actual = utils.path_with_replaced_args(qs, args)
assert expected == actual
@ -344,9 +346,10 @@ def test_resolve_table_and_format(
],
)
def test_path_with_format(path, format, extra_qs, expected):
request = Request(
path.encode('utf8'),
{}, '1.1', 'GET', None
)
actual = utils.path_with_format(request, format, extra_qs)
try:
path, qsbits = path.split('?', 1)
except ValueError:
qsbits = ''
qs = utils.Querystring(path, qsbits)
actual = utils.path_with_format(qs, format, extra_qs)
assert expected == actual