Compare commits

...

3 commits

Author SHA1 Message Date
Simon Willison
4fd36ba2f3
Work in progress 2018-06-17 12:21:44 -07:00
Simon Willison
68302b6ca2
Refactored everything to use new qs instead of request
I don't like this, I think I will go back to the request object but with my
own custom request object that has a request.qs property.
2018-06-16 21:45:30 -07:00
Simon Willison
f22cac8868
New Querystring manipulator 2018-06-16 19:12:10 -07:00
8 changed files with 221 additions and 99 deletions

View file

@ -16,6 +16,7 @@ from markupsafe import Markup
import pluggy import pluggy
from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader
from sanic import Sanic, response from sanic import Sanic, response
from sanic.request import Request as SanicRequest
from sanic.exceptions import InvalidUsage, NotFound from sanic.exceptions import InvalidUsage, NotFound
from .views.base import ( from .views.base import (
@ -498,4 +499,37 @@ class Datasette:
template = self.jinja_env.select_template(templates) template = self.jinja_env.select_template(templates)
return response.html(template.render(info), status=status) return response.html(template.render(info), status=status)
class AsgiApp():
def __init__(self, scope):
self.scope = scope
async def __call__(self, receive, send):
# Create Sanic request from scope
path = self.scope["path"].encode("utf8")
if self.scope["query_string"]:
path = b"{}?{}".format(path, self.scope["query_string"])
request = SanicRequest(
path,
{}, '1.1', 'GET', None
)
async def write_callback(response):
await send({
'type': 'http.response.start',
'status': 200,
'headers': [
[key.encode("utf-8"), value.encode("utf-8")]
for key, value in response.headers.items()
],
})
await send({
'type': 'http.response.body',
'body': response.body,
})
# TODO: Fix this
stream_callback = write_callback
await app.handle_request(request, write_callback, stream_callback)
app.AsgiApp = AsgiApp
return app return app

View file

@ -7,9 +7,9 @@
{{ column.name }} {{ column.name }}
{% else %} {% else %}
{% if column.name == sort %} {% if column.name == sort %}
<a href="{{ path_with_replaced_args(request, {'_sort_desc': column.name, '_sort': None, '_next': None}) }}" rel="nofollow">{{ column.name }}&nbsp;</a> <a href="{{ path_with_replaced_args(qs, {'_sort_desc': column.name, '_sort': None, '_next': None}) }}" rel="nofollow">{{ column.name }}&nbsp;</a>
{% else %} {% else %}
<a href="{{ path_with_replaced_args(request, {'_sort': column.name, '_sort_desc': None, '_next': None}) }}" rel="nofollow">{{ column.name }}{% if column.name == sort_desc %}&nbsp;▲{% endif %}</a> <a href="{{ path_with_replaced_args(qs, {'_sort': column.name, '_sort_desc': None, '_next': None}) }}" rel="nofollow">{{ column.name }}{% if column.name == sort_desc %}&nbsp;▲{% endif %}</a>
{% endif %} {% endif %}
{% endif %} {% endif %}
</th> </th>

View file

@ -111,7 +111,7 @@
<p class="facet-info-name"> <p class="facet-info-name">
<strong>{{ facet_info.name }}</strong> <strong>{{ facet_info.name }}</strong>
{% if facet_hideable(facet_info.name) %} {% if facet_hideable(facet_info.name) %}
<a href="{{ path_with_removed_args(request, {'_facet': facet_info['name']}) }}" class="cross">&#x2716;</a> <a href="{{ path_with_removed_args(qs, {'_facet': facet_info['name']}) }}" class="cross">&#x2716;</a>
{% endif %} {% endif %}
</p> </p>
<ul> <ul>

View file

@ -170,13 +170,13 @@ def validate_sql_select(sql):
raise InvalidSql(msg) raise InvalidSql(msg)
def path_with_added_args(request, args, path=None): def path_with_added_args(qs, args, path=None):
path = path or request.path path = path or qs.path
if isinstance(args, dict): if isinstance(args, dict):
args = args.items() args = args.items()
args_to_remove = {k for k, v in args if v is None} args_to_remove = {k for k, v in args if v is None}
current = [] current = []
for key, value in urllib.parse.parse_qsl(request.query_string): for key, value in urllib.parse.parse_qsl(str(qs)):
if key not in args_to_remove: if key not in args_to_remove:
current.append((key, value)) current.append((key, value))
current.extend([ current.extend([
@ -190,9 +190,9 @@ def path_with_added_args(request, args, path=None):
return path + query_string return path + query_string
def path_with_removed_args(request, args, path=None): def path_with_removed_args(qs, args, path=None):
# args can be a dict or a set # args can be a dict or a set
path = path or request.path path = path or qs.path
current = [] current = []
if isinstance(args, set): if isinstance(args, set):
def should_remove(key, value): def should_remove(key, value):
@ -201,7 +201,7 @@ def path_with_removed_args(request, args, path=None):
# Must match key AND value # Must match key AND value
def should_remove(key, value): def should_remove(key, value):
return args.get(key) == value return args.get(key) == value
for key, value in urllib.parse.parse_qsl(request.query_string): for key, value in urllib.parse.parse_qsl(str(qs)):
if not should_remove(key, value): if not should_remove(key, value):
current.append((key, value)) current.append((key, value))
query_string = urllib.parse.urlencode(current) query_string = urllib.parse.urlencode(current)
@ -210,13 +210,13 @@ def path_with_removed_args(request, args, path=None):
return path + query_string return path + query_string
def path_with_replaced_args(request, args, path=None): def path_with_replaced_args(qs, args, path=None):
path = path or request.path path = path or qs.path
if isinstance(args, dict): if isinstance(args, dict):
args = args.items() args = args.items()
keys_to_replace = {p[0] for p in args} keys_to_replace = {p[0] for p in args}
current = [] current = []
for key, value in urllib.parse.parse_qsl(request.query_string): for key, value in urllib.parse.parse_qsl(str(qs)):
if key not in keys_to_replace: if key not in keys_to_replace:
current.append((key, value)) current.append((key, value))
current.extend([p for p in args if p[1] is not None]) current.extend([p for p in args if p[1] is not None])
@ -783,23 +783,23 @@ def resolve_table_and_format(table_and_format, table_exists):
return table_and_format, None return table_and_format, None
def path_with_format(request, format, extra_qs=None): def path_with_format(qs, format, extra_qs=None):
qs = extra_qs or {} new_qs = extra_qs or {}
path = request.path path = qs.path
if "." in request.path: if "." in qs.path:
qs["_format"] = format new_qs["_format"] = format
else: else:
path = "{}.{}".format(path, format) path = "{}.{}".format(path, format)
if qs: if new_qs:
extra = urllib.parse.urlencode(sorted(qs.items())) extra = urllib.parse.urlencode(sorted(new_qs.items()))
if request.query_string: if qs.data:
path = "{}?{}&{}".format( path = "{}?{}&{}".format(
path, request.query_string, extra path, str(qs), extra
) )
else: else:
path = "{}?{}".format(path, extra) path = "{}?{}".format(path, extra)
elif request.query_string: elif qs.data:
path = "{}?{}".format(path, request.query_string) path = "{}?{}".format(path, str(qs))
return path return path
@ -828,3 +828,64 @@ def value_as_boolean(value):
class ValueAsBooleanError(ValueError): class ValueAsBooleanError(ValueError):
pass pass
class Querystring:
def __init__(self, path, qs=None):
self.path = path
self.prev = None
self.data = []
if qs:
self.data = urllib.parse.parse_qsl(qs, keep_blank_values=True)
def first(self, key):
for item in self.data:
if item[0] == key:
return item[1]
raise KeyError
def first_or_none(self, key):
try:
return self.first(key)
except KeyError:
return None
def last(self, key):
for item in reversed(self.data):
if item[0] == key:
return item[1]
raise KeyError
def getlist(self, key):
result = []
for item in self.data:
if item[0] == key:
result.append(item[1])
return result
def first_dict(self):
return {k: v[0] for k, v in self.data}
def append(self, key, value):
self.data.append((key, value))
def remove(self, key):
self.data = [item for item in self.data if item[0] != key]
def replace(self, **kwargs):
for key, values in kwargs.items():
if not isinstance(values, list):
kwargs[key] = [values]
new_data = []
for key, value in self.data:
if key in kwargs:
new_data.append((key, kwargs[key]))
else:
new_data.append((key, value))
self.data = new_data
def __str__(self):
return urllib.parse.urlencode(self.data)
def __repr__(self):
return str(self)

View file

@ -16,6 +16,7 @@ from datasette.utils import (
CustomJSONEncoder, CustomJSONEncoder,
InterruptedError, InterruptedError,
InvalidSql, InvalidSql,
Querystring,
path_from_row_pks, path_from_row_pks,
path_with_added_args, path_with_added_args,
path_with_format, path_with_format,
@ -85,9 +86,9 @@ class BaseView(RenderMixin):
r.headers["Access-Control-Allow-Origin"] = "*" r.headers["Access-Control-Allow-Origin"] = "*"
return r return r
def redirect(self, request, path, forward_querystring=True): def redirect(self, qs, path, forward_querystring=True):
if request.query_string and "?" not in path and forward_querystring: if qs.data and "?" not in qs.path and forward_querystring:
path = "{}?{}".format(path, request.query_string) path = "{}?{}".format(path, str(qs))
r = response.redirect(path) r = response.redirect(path)
r.headers["Link"] = "<{}>; rel=preload".format(path) r.headers["Link"] = "<{}>; rel=preload".format(path)
if self.ds.cors: if self.ds.cors:
@ -142,17 +143,42 @@ class BaseView(RenderMixin):
def get_templates(self, database, table=None): def get_templates(self, database, table=None):
assert NotImplemented assert NotImplemented
async def asgi_get(self, receive, send):
kwargs = self.scope["url_route"]["kwargs"]
db_name = kwargs.pop("db_name")
name, hash, should_redirect = self.resolve_db_name(db_name, **kwargs)
qs = Querystring(
self.scope["path"], self.scope["query_string"].decode("utf-8")
)
if should_redirect:
response = self.redirect(qs, should_redirect)
else:
response = await self.view_get(qs, name, hash, **kwargs)
# Send response over send() channel
await send({
'type': 'http.response.start',
'status': 200,
'headers': [
[key.encode("utf-8"), value.encode("utf-8")]
for key, value in response.headers.items()
],
})
await send({
'type': 'http.response.body',
'body': response.body,
})
async def get(self, request, db_name, **kwargs): async def get(self, request, db_name, **kwargs):
name, hash, should_redirect = self.resolve_db_name(db_name, **kwargs) name, hash, should_redirect = self.resolve_db_name(db_name, **kwargs)
qs = Querystring(request.path, request.query_string)
if should_redirect: if should_redirect:
return self.redirect(request, should_redirect) return self.redirect(qs, should_redirect)
return await self.view_get(qs, name, hash, **kwargs)
return await self.view_get(request, name, hash, **kwargs) async def as_csv(self, qs, name, hash, **kwargs):
async def as_csv(self, request, name, hash, **kwargs):
try: try:
response_or_template_contexts = await self.data( response_or_template_contexts = await self.data(
request, name, hash, **kwargs qs, name, hash, **kwargs
) )
if isinstance(response_or_template_contexts, response.HTTPResponse): if isinstance(response_or_template_contexts, response.HTTPResponse):
return response_or_template_contexts return response_or_template_contexts
@ -198,7 +224,7 @@ class BaseView(RenderMixin):
content_type = "text/plain; charset=utf-8" content_type = "text/plain; charset=utf-8"
headers = {} headers = {}
if request.args.get("_dl", None): if qs.first_or_none("_dl"):
content_type = "text/csv; charset=utf-8" content_type = "text/csv; charset=utf-8"
disposition = 'attachment; filename="{}.csv"'.format( disposition = 'attachment; filename="{}.csv"'.format(
kwargs.get('table', name) kwargs.get('table', name)
@ -211,9 +237,9 @@ class BaseView(RenderMixin):
content_type=content_type content_type=content_type
) )
async def view_get(self, request, name, hash, **kwargs): async def view_get(self, qs, name, hash, **kwargs):
# If ?_format= is provided, use that as the format # If ?_format= is provided, use that as the format
_format = request.args.get("_format", None) _format = qs.first_or_none("_format")
if not _format: if not _format:
_format = (kwargs.pop("as_format", None) or "").lstrip(".") _format = (kwargs.pop("as_format", None) or "").lstrip(".")
if "table_and_format" in kwargs: if "table_and_format" in kwargs:
@ -228,7 +254,7 @@ class BaseView(RenderMixin):
del kwargs["table_and_format"] del kwargs["table_and_format"]
if _format == "csv": if _format == "csv":
return await self.as_csv(request, name, hash, **kwargs) return await self.as_csv(qs, name, hash, **kwargs)
if _format is None: if _format is None:
# HTML views default to expanding all forign key labels # HTML views default to expanding all forign key labels
@ -240,7 +266,7 @@ class BaseView(RenderMixin):
templates = [] templates = []
try: try:
response_or_template_contexts = await self.data( response_or_template_contexts = await self.data(
request, name, hash, **kwargs qs, name, hash, **kwargs
) )
if isinstance(response_or_template_contexts, response.HTTPResponse): if isinstance(response_or_template_contexts, response.HTTPResponse):
return response_or_template_contexts return response_or_template_contexts
@ -272,26 +298,25 @@ class BaseView(RenderMixin):
# Special case for .jsono extension - redirect to _shape=objects # Special case for .jsono extension - redirect to _shape=objects
if _format == "jsono": if _format == "jsono":
return self.redirect( return self.redirect(
request, qs,
path_with_added_args( path_with_added_args(
request, qs,
{"_shape": "objects"}, {"_shape": "objects"},
path=request.path.rsplit(".jsono", 1)[0] + ".json", path=qs.path.rsplit(".jsono", 1)[0] + ".json",
), ),
forward_querystring=False, forward_querystring=False,
) )
# Handle the _json= parameter which may modify data["rows"] # Handle the _json= parameter which may modify data["rows"]
json_cols = [] json_cols = []
if "_json" in request.args: json_cols = qs.getlist("_json")
json_cols = request.args["_json"]
if json_cols and "rows" in data and "columns" in data: if json_cols and "rows" in data and "columns" in data:
data["rows"] = convert_specific_columns_to_json( data["rows"] = convert_specific_columns_to_json(
data["rows"], data["columns"], json_cols, data["rows"], data["columns"], json_cols,
) )
# Deal with the _shape option # Deal with the _shape option
shape = request.args.get("_shape", "arrays") shape = qs.first_or_none("_shape") or "arrays"
if shape == "arrayfirst": if shape == "arrayfirst":
data = [row[0] for row in data["rows"]] data = [row[0] for row in data["rows"]]
elif shape in ("objects", "object", "array"): elif shape in ("objects", "object", "array"):
@ -353,11 +378,11 @@ class BaseView(RenderMixin):
**data, **data,
**extras, **extras,
**{ **{
"url_json": path_with_format(request, "json"), "url_json": path_with_format(qs, "json"),
"url_csv": path_with_format(request, "csv", { "url_csv": path_with_format(qs, "csv", {
"_size": "max" "_size": "max"
}), }),
"url_csv_dl": path_with_format(request, "csv", { "url_csv_dl": path_with_format(qs, "csv", {
"_dl": "1", "_dl": "1",
"_size": "max" "_size": "max"
}), }),
@ -372,7 +397,7 @@ class BaseView(RenderMixin):
r.status = status_code r.status = status_code
# Set far-future cache expiry # Set far-future cache expiry
if self.ds.cache_headers: if self.ds.cache_headers:
ttl = request.args.get("_ttl", None) ttl = qs.first_or_none("_ttl")
if ttl is None or not ttl.isdigit(): if ttl is None or not ttl.isdigit():
ttl = self.ds.config["default_cache_ttl"] ttl = self.ds.config["default_cache_ttl"]
else: else:
@ -386,9 +411,9 @@ class BaseView(RenderMixin):
return r return r
async def custom_sql( async def custom_sql(
self, request, name, hash, sql, editable=True, canned_query=None self, qs, name, hash, sql, editable=True, canned_query=None
): ):
params = request.raw_args params = qs.first_dict()
if "sql" in params: if "sql" in params:
params.pop("sql") params.pop("sql")
if "_shape" in params: if "_shape" in params:

View file

@ -9,13 +9,13 @@ from .base import BaseView, DatasetteError
class DatabaseView(BaseView): class DatabaseView(BaseView):
async def data(self, request, name, hash, default_labels=False): async def data(self, qs, name, hash, default_labels=False):
if request.args.get("sql"): if qs.first_or_none("sql"):
if not self.ds.config["allow_sql"]: if not self.ds.config["allow_sql"]:
raise DatasetteError("sql= is not allowed", status=400) raise DatasetteError("sql= is not allowed", status=400)
sql = request.raw_args.pop("sql") sql = qs.first("sql")
validate_sql_select(sql) validate_sql_select(sql)
return await self.custom_sql(request, name, hash, sql) return await self.custom_sql(qs, name, hash, sql)
info = self.ds.inspect()[name] info = self.ds.inspect()[name]
metadata = self.ds.metadata.get("databases", {}).get(name, {}) metadata = self.ds.metadata.get("databases", {}).get(name, {})
@ -34,7 +34,7 @@ class DatabaseView(BaseView):
"config": self.ds.config, "config": self.ds.config,
}, { }, {
"database_hash": hash, "database_hash": hash,
"show_hidden": request.args.get("_show_hidden"), "show_hidden": qs.first_or_none("_show_hidden"),
"editable": True, "editable": True,
"metadata": metadata, "metadata": metadata,
}, ( }, (
@ -44,7 +44,7 @@ class DatabaseView(BaseView):
class DatabaseDownload(BaseView): class DatabaseDownload(BaseView):
async def view_get(self, request, name, hash, **kwargs): async def view_get(self, qs, name, hash, **kwargs):
if not self.ds.config["allow_download"]: if not self.ds.config["allow_download"]:
raise DatasetteError("Database download is forbidden", status=403) raise DatasetteError("Database download is forbidden", status=403)
filepath = self.ds.inspect()[name]["file"] filepath = self.ds.inspect()[name]["file"]

View file

@ -1,4 +1,3 @@
from collections import namedtuple
import sqlite3 import sqlite3
import urllib import urllib
@ -220,11 +219,11 @@ class RowTableShared(BaseView):
class TableView(RowTableShared): class TableView(RowTableShared):
async def data(self, request, name, hash, table, default_labels=False): async def data(self, qs, name, hash, table, default_labels=False):
canned_query = self.ds.get_canned_query(name, table) canned_query = self.ds.get_canned_query(name, table)
if canned_query is not None: if canned_query is not None:
return await self.custom_sql( return await self.custom_sql(
request, qs,
name, name,
hash, hash,
canned_query["sql"], canned_query["sql"],
@ -255,7 +254,7 @@ class TableView(RowTableShared):
# We roll our own query_string decoder because by default Sanic # We roll our own query_string decoder because by default Sanic
# drops anything with an empty value e.g. ?name__exact= # drops anything with an empty value e.g. ?name__exact=
args = RequestParameters( args = RequestParameters(
urllib.parse.parse_qs(request.query_string, keep_blank_values=True) urllib.parse.parse_qs(str(qs), keep_blank_values=True)
) )
# Special args start with _ and do not contain a __ # Special args start with _ and do not contain a __
@ -275,17 +274,17 @@ class TableView(RowTableShared):
redirect_params = filters_should_redirect(special_args) redirect_params = filters_should_redirect(special_args)
if redirect_params: if redirect_params:
return self.redirect( return self.redirect(
request, qs,
path_with_added_args(request, redirect_params), path_with_added_args(qs, redirect_params),
forward_querystring=False, forward_querystring=False,
) )
# Spot ?_sort_by_desc and redirect to _sort_desc=(_sort) # Spot ?_sort_by_desc and redirect to _sort_desc=(_sort)
if "_sort_by_desc" in special_args: if "_sort_by_desc" in special_args:
return self.redirect( return self.redirect(
request, qs,
path_with_added_args( path_with_added_args(
request, qs,
{ {
"_sort_desc": special_args.get("_sort"), "_sort_desc": special_args.get("_sort"),
"_sort_by_desc": None, "_sort_by_desc": None,
@ -458,11 +457,11 @@ class TableView(RowTableShared):
table_name=escape_sqlite(table), table_name=escape_sqlite(table),
where=where_clause, where=where_clause,
) )
return await self.custom_sql(request, name, hash, sql, editable=True) return await self.custom_sql(qs, name, hash, sql, editable=True)
extra_args = {} extra_args = {}
# Handle ?_size=500 # Handle ?_size=500
page_size = request.raw_args.get("_size") page_size = qs.first_or_none("_size")
if page_size: if page_size:
if page_size == "max": if page_size == "max":
page_size = self.max_returned_rows page_size = self.max_returned_rows
@ -492,8 +491,8 @@ class TableView(RowTableShared):
offset=offset, offset=offset,
) )
if request.raw_args.get("_timelimit"): if qs.first_or_none("_timelimit"):
extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"]) extra_args["custom_time_limit"] = int(qs.first("_timelimit"))
results = await self.ds.execute( results = await self.ds.execute(
name, sql, params, truncate=True, **extra_args name, sql, params, truncate=True, **extra_args
@ -503,10 +502,10 @@ class TableView(RowTableShared):
facet_size = self.ds.config["default_facet_size"] facet_size = self.ds.config["default_facet_size"]
metadata_facets = table_metadata.get("facets", []) metadata_facets = table_metadata.get("facets", [])
facets = metadata_facets[:] facets = metadata_facets[:]
if request.args.get("_facet") and not self.ds.config["allow_facet"]: if qs.first_or_none("_facet") and not self.ds.config["allow_facet"]:
raise DatasetteError("_facet= is not allowed", status=400) raise DatasetteError("_facet= is not allowed", status=400)
try: try:
facets.extend(request.args["_facet"]) facets.extend(qs.getlist("_facet"))
except KeyError: except KeyError:
pass pass
facet_results = {} facet_results = {}
@ -544,11 +543,11 @@ class TableView(RowTableShared):
selected = str(other_args.get(column)) == str(row["value"]) selected = str(other_args.get(column)) == str(row["value"])
if selected: if selected:
toggle_path = path_with_removed_args( toggle_path = path_with_removed_args(
request, {column: str(row["value"])} qs, {column: str(row["value"])}
) )
else: else:
toggle_path = path_with_added_args( toggle_path = path_with_added_args(
request, {column: row["value"]} qs, {column: row["value"]}
) )
facet_results_values.append({ facet_results_values.append({
"value": row["value"], "value": row["value"],
@ -558,7 +557,7 @@ class TableView(RowTableShared):
), ),
"count": row["count"], "count": row["count"],
"toggle_url": urllib.parse.urljoin( "toggle_url": urllib.parse.urljoin(
request.url, toggle_path qs.path, toggle_path
), ),
"selected": selected, "selected": selected,
}) })
@ -581,8 +580,8 @@ class TableView(RowTableShared):
except ValueError: except ValueError:
all_labels = default_labels all_labels = default_labels
# Check for explicit _label= # Check for explicit _label=
if "_label" in request.args: if qs.first_or_none("_label"):
columns_to_expand = request.args["_label"] columns_to_expand = qs.first("_label")
if columns_to_expand is None and all_labels: if columns_to_expand is None and all_labels:
# expand all columns with foreign keys # expand all columns with foreign keys
columns_to_expand = [ columns_to_expand = [
@ -644,7 +643,7 @@ class TableView(RowTableShared):
else: else:
added_args = {"_next": next_value} added_args = {"_next": next_value}
next_url = urllib.parse.urljoin( next_url = urllib.parse.urljoin(
request.url, path_with_replaced_args(request, added_args) qs.path, path_with_replaced_args(qs, added_args)
) )
rows = rows[:page_size] rows = rows[:page_size]
@ -694,7 +693,7 @@ class TableView(RowTableShared):
suggested_facets.append({ suggested_facets.append({
'name': facet_column, 'name': facet_column,
'toggle_url': path_with_added_args( 'toggle_url': path_with_added_args(
request, {'_facet': facet_column} qs, {'_facet': facet_column}
), ),
}) })
except InterruptedError: except InterruptedError:
@ -744,7 +743,7 @@ class TableView(RowTableShared):
"is_sortable": any(c["sortable"] for c in display_columns), "is_sortable": any(c["sortable"] for c in display_columns),
"path_with_replaced_args": path_with_replaced_args, "path_with_replaced_args": path_with_replaced_args,
"path_with_removed_args": path_with_removed_args, "path_with_removed_args": path_with_removed_args,
"request": request, "qs": qs,
"sort": sort, "sort": sort,
"sort_desc": sort_desc, "sort_desc": sort_desc,
"disable_sort": is_view, "disable_sort": is_view,
@ -788,7 +787,7 @@ class TableView(RowTableShared):
class RowView(RowTableShared): class RowView(RowTableShared):
async def data(self, request, name, hash, table, pk_path, default_labels=False): async def data(self, qs, name, hash, table, pk_path, default_labels=False):
pk_values = urlsafe_components(pk_path) pk_values = urlsafe_components(pk_path)
info = self.ds.inspect()[name] info = self.ds.inspect()[name]
table_info = info["tables"].get(table) or {} table_info = info["tables"].get(table) or {}
@ -854,7 +853,7 @@ class RowView(RowTableShared):
"units": self.table_metadata(name, table).get("units", {}), "units": self.table_metadata(name, table).get("units", {}),
} }
if "foreign_key_tables" in (request.raw_args.get("_extras") or "").split(","): if "foreign_key_tables" in (qs.first_or_none("_extras") or "").split(","):
data["foreign_key_tables"] = await self.foreign_key_tables( data["foreign_key_tables"] = await self.foreign_key_tables(
name, table, pk_values name, table, pk_values
) )

View file

@ -6,7 +6,6 @@ from datasette import utils
import json import json
import os import os
import pytest import pytest
from sanic.request import Request
import sqlite3 import sqlite3
import tempfile import tempfile
from unittest.mock import patch from unittest.mock import patch
@ -40,11 +39,12 @@ def test_urlsafe_components(path, expected):
), '/?_facet=state&_facet=city&_facet=planet_int'), ), '/?_facet=state&_facet=city&_facet=planet_int'),
]) ])
def test_path_with_added_args(path, added_args, expected): def test_path_with_added_args(path, added_args, expected):
request = Request( try:
path.encode('utf8'), path, qsbits = path.split('?', 1)
{}, '1.1', 'GET', None except ValueError:
) qsbits = ''
actual = utils.path_with_added_args(request, added_args) qs = utils.Querystring(path, qsbits)
actual = utils.path_with_added_args(qs, added_args)
assert expected == actual assert expected == actual
@ -54,11 +54,12 @@ def test_path_with_added_args(path, added_args, expected):
('/foo?bar=1&bar=2&bar=3', {'bar': '2'}, '/foo?bar=1&bar=3'), ('/foo?bar=1&bar=2&bar=3', {'bar': '2'}, '/foo?bar=1&bar=3'),
]) ])
def test_path_with_removed_args(path, args, expected): def test_path_with_removed_args(path, args, expected):
request = Request( try:
path.encode('utf8'), path, qsbits = path.split('?', 1)
{}, '1.1', 'GET', None except ValueError:
) qsbits = ''
actual = utils.path_with_removed_args(request, args) qs = utils.Querystring(path, qsbits)
actual = utils.path_with_removed_args(qs, args)
assert expected == actual assert expected == actual
@ -67,11 +68,12 @@ def test_path_with_removed_args(path, args, expected):
('/foo?bar=1&baz=2', {'bar': None}, '/foo?baz=2'), ('/foo?bar=1&baz=2', {'bar': None}, '/foo?baz=2'),
]) ])
def test_path_with_replaced_args(path, args, expected): def test_path_with_replaced_args(path, args, expected):
request = Request( try:
path.encode('utf8'), path, qsbits = path.split('?', 1)
{}, '1.1', 'GET', None except ValueError:
) qsbits = ''
actual = utils.path_with_replaced_args(request, args) qs = utils.Querystring(path, qsbits)
actual = utils.path_with_replaced_args(qs, args)
assert expected == actual assert expected == actual
@ -344,9 +346,10 @@ def test_resolve_table_and_format(
], ],
) )
def test_path_with_format(path, format, extra_qs, expected): def test_path_with_format(path, format, extra_qs, expected):
request = Request( try:
path.encode('utf8'), path, qsbits = path.split('?', 1)
{}, '1.1', 'GET', None except ValueError:
) qsbits = ''
actual = utils.path_with_format(request, format, extra_qs) qs = utils.Querystring(path, qsbits)
actual = utils.path_with_format(qs, format, extra_qs)
assert expected == actual assert expected == actual