Compare commits

...

3 commits

Author SHA1 Message Date
Simon Willison
4fd36ba2f3
Work in progress 2018-06-17 12:21:44 -07:00
Simon Willison
68302b6ca2
Refactored everything to use new qs instead of request
I don't like this, I think I will go back to the request object but with my
own custom request object that has a request.qs property.
2018-06-16 21:45:30 -07:00
Simon Willison
f22cac8868
New Querystring manipulator 2018-06-16 19:12:10 -07:00
8 changed files with 221 additions and 99 deletions

View file

@ -16,6 +16,7 @@ from markupsafe import Markup
import pluggy
from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader
from sanic import Sanic, response
from sanic.request import Request as SanicRequest
from sanic.exceptions import InvalidUsage, NotFound
from .views.base import (
@ -498,4 +499,37 @@ class Datasette:
template = self.jinja_env.select_template(templates)
return response.html(template.render(info), status=status)
class AsgiApp():
def __init__(self, scope):
self.scope = scope
async def __call__(self, receive, send):
# Create Sanic request from scope
path = self.scope["path"].encode("utf8")
if self.scope["query_string"]:
path = b"{}?{}".format(path, self.scope["query_string"])
request = SanicRequest(
path,
{}, '1.1', 'GET', None
)
async def write_callback(response):
await send({
'type': 'http.response.start',
'status': 200,
'headers': [
[key.encode("utf-8"), value.encode("utf-8")]
for key, value in response.headers.items()
],
})
await send({
'type': 'http.response.body',
'body': response.body,
})
# TODO: Fix this
stream_callback = write_callback
await app.handle_request(request, write_callback, stream_callback)
app.AsgiApp = AsgiApp
return app

View file

@ -7,9 +7,9 @@
{{ column.name }}
{% else %}
{% if column.name == sort %}
<a href="{{ path_with_replaced_args(request, {'_sort_desc': column.name, '_sort': None, '_next': None}) }}" rel="nofollow">{{ column.name }}&nbsp;</a>
<a href="{{ path_with_replaced_args(qs, {'_sort_desc': column.name, '_sort': None, '_next': None}) }}" rel="nofollow">{{ column.name }}&nbsp;</a>
{% else %}
<a href="{{ path_with_replaced_args(request, {'_sort': column.name, '_sort_desc': None, '_next': None}) }}" rel="nofollow">{{ column.name }}{% if column.name == sort_desc %}&nbsp;▲{% endif %}</a>
<a href="{{ path_with_replaced_args(qs, {'_sort': column.name, '_sort_desc': None, '_next': None}) }}" rel="nofollow">{{ column.name }}{% if column.name == sort_desc %}&nbsp;▲{% endif %}</a>
{% endif %}
{% endif %}
</th>

View file

@ -111,7 +111,7 @@
<p class="facet-info-name">
<strong>{{ facet_info.name }}</strong>
{% if facet_hideable(facet_info.name) %}
<a href="{{ path_with_removed_args(request, {'_facet': facet_info['name']}) }}" class="cross">&#x2716;</a>
<a href="{{ path_with_removed_args(qs, {'_facet': facet_info['name']}) }}" class="cross">&#x2716;</a>
{% endif %}
</p>
<ul>

View file

@ -170,13 +170,13 @@ def validate_sql_select(sql):
raise InvalidSql(msg)
def path_with_added_args(request, args, path=None):
path = path or request.path
def path_with_added_args(qs, args, path=None):
path = path or qs.path
if isinstance(args, dict):
args = args.items()
args_to_remove = {k for k, v in args if v is None}
current = []
for key, value in urllib.parse.parse_qsl(request.query_string):
for key, value in urllib.parse.parse_qsl(str(qs)):
if key not in args_to_remove:
current.append((key, value))
current.extend([
@ -190,9 +190,9 @@ def path_with_added_args(request, args, path=None):
return path + query_string
def path_with_removed_args(request, args, path=None):
def path_with_removed_args(qs, args, path=None):
# args can be a dict or a set
path = path or request.path
path = path or qs.path
current = []
if isinstance(args, set):
def should_remove(key, value):
@ -201,7 +201,7 @@ def path_with_removed_args(request, args, path=None):
# Must match key AND value
def should_remove(key, value):
return args.get(key) == value
for key, value in urllib.parse.parse_qsl(request.query_string):
for key, value in urllib.parse.parse_qsl(str(qs)):
if not should_remove(key, value):
current.append((key, value))
query_string = urllib.parse.urlencode(current)
@ -210,13 +210,13 @@ def path_with_removed_args(request, args, path=None):
return path + query_string
def path_with_replaced_args(request, args, path=None):
path = path or request.path
def path_with_replaced_args(qs, args, path=None):
path = path or qs.path
if isinstance(args, dict):
args = args.items()
keys_to_replace = {p[0] for p in args}
current = []
for key, value in urllib.parse.parse_qsl(request.query_string):
for key, value in urllib.parse.parse_qsl(str(qs)):
if key not in keys_to_replace:
current.append((key, value))
current.extend([p for p in args if p[1] is not None])
@ -783,23 +783,23 @@ def resolve_table_and_format(table_and_format, table_exists):
return table_and_format, None
def path_with_format(request, format, extra_qs=None):
qs = extra_qs or {}
path = request.path
if "." in request.path:
qs["_format"] = format
def path_with_format(qs, format, extra_qs=None):
new_qs = extra_qs or {}
path = qs.path
if "." in qs.path:
new_qs["_format"] = format
else:
path = "{}.{}".format(path, format)
if qs:
extra = urllib.parse.urlencode(sorted(qs.items()))
if request.query_string:
if new_qs:
extra = urllib.parse.urlencode(sorted(new_qs.items()))
if qs.data:
path = "{}?{}&{}".format(
path, request.query_string, extra
path, str(qs), extra
)
else:
path = "{}?{}".format(path, extra)
elif request.query_string:
path = "{}?{}".format(path, request.query_string)
elif qs.data:
path = "{}?{}".format(path, str(qs))
return path
@ -828,3 +828,64 @@ def value_as_boolean(value):
class ValueAsBooleanError(ValueError):
pass
class Querystring:
def __init__(self, path, qs=None):
self.path = path
self.prev = None
self.data = []
if qs:
self.data = urllib.parse.parse_qsl(qs, keep_blank_values=True)
def first(self, key):
for item in self.data:
if item[0] == key:
return item[1]
raise KeyError
def first_or_none(self, key):
try:
return self.first(key)
except KeyError:
return None
def last(self, key):
for item in reversed(self.data):
if item[0] == key:
return item[1]
raise KeyError
def getlist(self, key):
result = []
for item in self.data:
if item[0] == key:
result.append(item[1])
return result
def first_dict(self):
return {k: v[0] for k, v in self.data}
def append(self, key, value):
self.data.append((key, value))
def remove(self, key):
self.data = [item for item in self.data if item[0] != key]
def replace(self, **kwargs):
for key, values in kwargs.items():
if not isinstance(values, list):
kwargs[key] = [values]
new_data = []
for key, value in self.data:
if key in kwargs:
new_data.append((key, kwargs[key]))
else:
new_data.append((key, value))
self.data = new_data
def __str__(self):
return urllib.parse.urlencode(self.data)
def __repr__(self):
return str(self)

View file

@ -16,6 +16,7 @@ from datasette.utils import (
CustomJSONEncoder,
InterruptedError,
InvalidSql,
Querystring,
path_from_row_pks,
path_with_added_args,
path_with_format,
@ -85,9 +86,9 @@ class BaseView(RenderMixin):
r.headers["Access-Control-Allow-Origin"] = "*"
return r
def redirect(self, request, path, forward_querystring=True):
if request.query_string and "?" not in path and forward_querystring:
path = "{}?{}".format(path, request.query_string)
def redirect(self, qs, path, forward_querystring=True):
if qs.data and "?" not in qs.path and forward_querystring:
path = "{}?{}".format(path, str(qs))
r = response.redirect(path)
r.headers["Link"] = "<{}>; rel=preload".format(path)
if self.ds.cors:
@ -142,17 +143,42 @@ class BaseView(RenderMixin):
def get_templates(self, database, table=None):
assert NotImplemented
async def asgi_get(self, receive, send):
kwargs = self.scope["url_route"]["kwargs"]
db_name = kwargs.pop("db_name")
name, hash, should_redirect = self.resolve_db_name(db_name, **kwargs)
qs = Querystring(
self.scope["path"], self.scope["query_string"].decode("utf-8")
)
if should_redirect:
response = self.redirect(qs, should_redirect)
else:
response = await self.view_get(qs, name, hash, **kwargs)
# Send response over send() channel
await send({
'type': 'http.response.start',
'status': 200,
'headers': [
[key.encode("utf-8"), value.encode("utf-8")]
for key, value in response.headers.items()
],
})
await send({
'type': 'http.response.body',
'body': response.body,
})
async def get(self, request, db_name, **kwargs):
name, hash, should_redirect = self.resolve_db_name(db_name, **kwargs)
qs = Querystring(request.path, request.query_string)
if should_redirect:
return self.redirect(request, should_redirect)
return self.redirect(qs, should_redirect)
return await self.view_get(qs, name, hash, **kwargs)
return await self.view_get(request, name, hash, **kwargs)
async def as_csv(self, request, name, hash, **kwargs):
async def as_csv(self, qs, name, hash, **kwargs):
try:
response_or_template_contexts = await self.data(
request, name, hash, **kwargs
qs, name, hash, **kwargs
)
if isinstance(response_or_template_contexts, response.HTTPResponse):
return response_or_template_contexts
@ -198,7 +224,7 @@ class BaseView(RenderMixin):
content_type = "text/plain; charset=utf-8"
headers = {}
if request.args.get("_dl", None):
if qs.first_or_none("_dl"):
content_type = "text/csv; charset=utf-8"
disposition = 'attachment; filename="{}.csv"'.format(
kwargs.get('table', name)
@ -211,9 +237,9 @@ class BaseView(RenderMixin):
content_type=content_type
)
async def view_get(self, request, name, hash, **kwargs):
async def view_get(self, qs, name, hash, **kwargs):
# If ?_format= is provided, use that as the format
_format = request.args.get("_format", None)
_format = qs.first_or_none("_format")
if not _format:
_format = (kwargs.pop("as_format", None) or "").lstrip(".")
if "table_and_format" in kwargs:
@ -228,7 +254,7 @@ class BaseView(RenderMixin):
del kwargs["table_and_format"]
if _format == "csv":
return await self.as_csv(request, name, hash, **kwargs)
return await self.as_csv(qs, name, hash, **kwargs)
if _format is None:
# HTML views default to expanding all forign key labels
@ -240,7 +266,7 @@ class BaseView(RenderMixin):
templates = []
try:
response_or_template_contexts = await self.data(
request, name, hash, **kwargs
qs, name, hash, **kwargs
)
if isinstance(response_or_template_contexts, response.HTTPResponse):
return response_or_template_contexts
@ -272,26 +298,25 @@ class BaseView(RenderMixin):
# Special case for .jsono extension - redirect to _shape=objects
if _format == "jsono":
return self.redirect(
request,
qs,
path_with_added_args(
request,
qs,
{"_shape": "objects"},
path=request.path.rsplit(".jsono", 1)[0] + ".json",
path=qs.path.rsplit(".jsono", 1)[0] + ".json",
),
forward_querystring=False,
)
# Handle the _json= parameter which may modify data["rows"]
json_cols = []
if "_json" in request.args:
json_cols = request.args["_json"]
json_cols = qs.getlist("_json")
if json_cols and "rows" in data and "columns" in data:
data["rows"] = convert_specific_columns_to_json(
data["rows"], data["columns"], json_cols,
)
# Deal with the _shape option
shape = request.args.get("_shape", "arrays")
shape = qs.first_or_none("_shape") or "arrays"
if shape == "arrayfirst":
data = [row[0] for row in data["rows"]]
elif shape in ("objects", "object", "array"):
@ -353,11 +378,11 @@ class BaseView(RenderMixin):
**data,
**extras,
**{
"url_json": path_with_format(request, "json"),
"url_csv": path_with_format(request, "csv", {
"url_json": path_with_format(qs, "json"),
"url_csv": path_with_format(qs, "csv", {
"_size": "max"
}),
"url_csv_dl": path_with_format(request, "csv", {
"url_csv_dl": path_with_format(qs, "csv", {
"_dl": "1",
"_size": "max"
}),
@ -372,7 +397,7 @@ class BaseView(RenderMixin):
r.status = status_code
# Set far-future cache expiry
if self.ds.cache_headers:
ttl = request.args.get("_ttl", None)
ttl = qs.first_or_none("_ttl")
if ttl is None or not ttl.isdigit():
ttl = self.ds.config["default_cache_ttl"]
else:
@ -386,9 +411,9 @@ class BaseView(RenderMixin):
return r
async def custom_sql(
self, request, name, hash, sql, editable=True, canned_query=None
self, qs, name, hash, sql, editable=True, canned_query=None
):
params = request.raw_args
params = qs.first_dict()
if "sql" in params:
params.pop("sql")
if "_shape" in params:

View file

@ -9,13 +9,13 @@ from .base import BaseView, DatasetteError
class DatabaseView(BaseView):
async def data(self, request, name, hash, default_labels=False):
if request.args.get("sql"):
async def data(self, qs, name, hash, default_labels=False):
if qs.first_or_none("sql"):
if not self.ds.config["allow_sql"]:
raise DatasetteError("sql= is not allowed", status=400)
sql = request.raw_args.pop("sql")
sql = qs.first("sql")
validate_sql_select(sql)
return await self.custom_sql(request, name, hash, sql)
return await self.custom_sql(qs, name, hash, sql)
info = self.ds.inspect()[name]
metadata = self.ds.metadata.get("databases", {}).get(name, {})
@ -34,7 +34,7 @@ class DatabaseView(BaseView):
"config": self.ds.config,
}, {
"database_hash": hash,
"show_hidden": request.args.get("_show_hidden"),
"show_hidden": qs.first_or_none("_show_hidden"),
"editable": True,
"metadata": metadata,
}, (
@ -44,7 +44,7 @@ class DatabaseView(BaseView):
class DatabaseDownload(BaseView):
async def view_get(self, request, name, hash, **kwargs):
async def view_get(self, qs, name, hash, **kwargs):
if not self.ds.config["allow_download"]:
raise DatasetteError("Database download is forbidden", status=403)
filepath = self.ds.inspect()[name]["file"]

View file

@ -1,4 +1,3 @@
from collections import namedtuple
import sqlite3
import urllib
@ -220,11 +219,11 @@ class RowTableShared(BaseView):
class TableView(RowTableShared):
async def data(self, request, name, hash, table, default_labels=False):
async def data(self, qs, name, hash, table, default_labels=False):
canned_query = self.ds.get_canned_query(name, table)
if canned_query is not None:
return await self.custom_sql(
request,
qs,
name,
hash,
canned_query["sql"],
@ -255,7 +254,7 @@ class TableView(RowTableShared):
# We roll our own query_string decoder because by default Sanic
# drops anything with an empty value e.g. ?name__exact=
args = RequestParameters(
urllib.parse.parse_qs(request.query_string, keep_blank_values=True)
urllib.parse.parse_qs(str(qs), keep_blank_values=True)
)
# Special args start with _ and do not contain a __
@ -275,17 +274,17 @@ class TableView(RowTableShared):
redirect_params = filters_should_redirect(special_args)
if redirect_params:
return self.redirect(
request,
path_with_added_args(request, redirect_params),
qs,
path_with_added_args(qs, redirect_params),
forward_querystring=False,
)
# Spot ?_sort_by_desc and redirect to _sort_desc=(_sort)
if "_sort_by_desc" in special_args:
return self.redirect(
request,
qs,
path_with_added_args(
request,
qs,
{
"_sort_desc": special_args.get("_sort"),
"_sort_by_desc": None,
@ -458,11 +457,11 @@ class TableView(RowTableShared):
table_name=escape_sqlite(table),
where=where_clause,
)
return await self.custom_sql(request, name, hash, sql, editable=True)
return await self.custom_sql(qs, name, hash, sql, editable=True)
extra_args = {}
# Handle ?_size=500
page_size = request.raw_args.get("_size")
page_size = qs.first_or_none("_size")
if page_size:
if page_size == "max":
page_size = self.max_returned_rows
@ -492,8 +491,8 @@ class TableView(RowTableShared):
offset=offset,
)
if request.raw_args.get("_timelimit"):
extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"])
if qs.first_or_none("_timelimit"):
extra_args["custom_time_limit"] = int(qs.first("_timelimit"))
results = await self.ds.execute(
name, sql, params, truncate=True, **extra_args
@ -503,10 +502,10 @@ class TableView(RowTableShared):
facet_size = self.ds.config["default_facet_size"]
metadata_facets = table_metadata.get("facets", [])
facets = metadata_facets[:]
if request.args.get("_facet") and not self.ds.config["allow_facet"]:
if qs.first_or_none("_facet") and not self.ds.config["allow_facet"]:
raise DatasetteError("_facet= is not allowed", status=400)
try:
facets.extend(request.args["_facet"])
facets.extend(qs.getlist("_facet"))
except KeyError:
pass
facet_results = {}
@ -544,11 +543,11 @@ class TableView(RowTableShared):
selected = str(other_args.get(column)) == str(row["value"])
if selected:
toggle_path = path_with_removed_args(
request, {column: str(row["value"])}
qs, {column: str(row["value"])}
)
else:
toggle_path = path_with_added_args(
request, {column: row["value"]}
qs, {column: row["value"]}
)
facet_results_values.append({
"value": row["value"],
@ -558,7 +557,7 @@ class TableView(RowTableShared):
),
"count": row["count"],
"toggle_url": urllib.parse.urljoin(
request.url, toggle_path
qs.path, toggle_path
),
"selected": selected,
})
@ -581,8 +580,8 @@ class TableView(RowTableShared):
except ValueError:
all_labels = default_labels
# Check for explicit _label=
if "_label" in request.args:
columns_to_expand = request.args["_label"]
if qs.first_or_none("_label"):
columns_to_expand = qs.first("_label")
if columns_to_expand is None and all_labels:
# expand all columns with foreign keys
columns_to_expand = [
@ -644,7 +643,7 @@ class TableView(RowTableShared):
else:
added_args = {"_next": next_value}
next_url = urllib.parse.urljoin(
request.url, path_with_replaced_args(request, added_args)
qs.path, path_with_replaced_args(qs, added_args)
)
rows = rows[:page_size]
@ -694,7 +693,7 @@ class TableView(RowTableShared):
suggested_facets.append({
'name': facet_column,
'toggle_url': path_with_added_args(
request, {'_facet': facet_column}
qs, {'_facet': facet_column}
),
})
except InterruptedError:
@ -744,7 +743,7 @@ class TableView(RowTableShared):
"is_sortable": any(c["sortable"] for c in display_columns),
"path_with_replaced_args": path_with_replaced_args,
"path_with_removed_args": path_with_removed_args,
"request": request,
"qs": qs,
"sort": sort,
"sort_desc": sort_desc,
"disable_sort": is_view,
@ -788,7 +787,7 @@ class TableView(RowTableShared):
class RowView(RowTableShared):
async def data(self, request, name, hash, table, pk_path, default_labels=False):
async def data(self, qs, name, hash, table, pk_path, default_labels=False):
pk_values = urlsafe_components(pk_path)
info = self.ds.inspect()[name]
table_info = info["tables"].get(table) or {}
@ -854,7 +853,7 @@ class RowView(RowTableShared):
"units": self.table_metadata(name, table).get("units", {}),
}
if "foreign_key_tables" in (request.raw_args.get("_extras") or "").split(","):
if "foreign_key_tables" in (qs.first_or_none("_extras") or "").split(","):
data["foreign_key_tables"] = await self.foreign_key_tables(
name, table, pk_values
)

View file

@ -6,7 +6,6 @@ from datasette import utils
import json
import os
import pytest
from sanic.request import Request
import sqlite3
import tempfile
from unittest.mock import patch
@ -40,11 +39,12 @@ def test_urlsafe_components(path, expected):
), '/?_facet=state&_facet=city&_facet=planet_int'),
])
def test_path_with_added_args(path, added_args, expected):
request = Request(
path.encode('utf8'),
{}, '1.1', 'GET', None
)
actual = utils.path_with_added_args(request, added_args)
try:
path, qsbits = path.split('?', 1)
except ValueError:
qsbits = ''
qs = utils.Querystring(path, qsbits)
actual = utils.path_with_added_args(qs, added_args)
assert expected == actual
@ -54,11 +54,12 @@ def test_path_with_added_args(path, added_args, expected):
('/foo?bar=1&bar=2&bar=3', {'bar': '2'}, '/foo?bar=1&bar=3'),
])
def test_path_with_removed_args(path, args, expected):
request = Request(
path.encode('utf8'),
{}, '1.1', 'GET', None
)
actual = utils.path_with_removed_args(request, args)
try:
path, qsbits = path.split('?', 1)
except ValueError:
qsbits = ''
qs = utils.Querystring(path, qsbits)
actual = utils.path_with_removed_args(qs, args)
assert expected == actual
@ -67,11 +68,12 @@ def test_path_with_removed_args(path, args, expected):
('/foo?bar=1&baz=2', {'bar': None}, '/foo?baz=2'),
])
def test_path_with_replaced_args(path, args, expected):
request = Request(
path.encode('utf8'),
{}, '1.1', 'GET', None
)
actual = utils.path_with_replaced_args(request, args)
try:
path, qsbits = path.split('?', 1)
except ValueError:
qsbits = ''
qs = utils.Querystring(path, qsbits)
actual = utils.path_with_replaced_args(qs, args)
assert expected == actual
@ -344,9 +346,10 @@ def test_resolve_table_and_format(
],
)
def test_path_with_format(path, format, extra_qs, expected):
request = Request(
path.encode('utf8'),
{}, '1.1', 'GET', None
)
actual = utils.path_with_format(request, format, extra_qs)
try:
path, qsbits = path.split('?', 1)
except ValueError:
qsbits = ''
qs = utils.Querystring(path, qsbits)
actual = utils.path_with_format(qs, format, extra_qs)
assert expected == actual