Refactored everything to use new qs instead of request

I don't like this, I think I will go back to the request object but with my
own custom request object that has a request.qs property.
This commit is contained in:
Simon Willison 2018-06-16 21:45:30 -07:00
commit 68302b6ca2
No known key found for this signature in database
GPG key ID: 17E2DEA2588B7F52
7 changed files with 113 additions and 103 deletions

View file

@ -7,9 +7,9 @@
{{ column.name }} {{ column.name }}
{% else %} {% else %}
{% if column.name == sort %} {% if column.name == sort %}
<a href="{{ path_with_replaced_args(request, {'_sort_desc': column.name, '_sort': None, '_next': None}) }}" rel="nofollow">{{ column.name }}&nbsp;</a> <a href="{{ path_with_replaced_args(qs, {'_sort_desc': column.name, '_sort': None, '_next': None}) }}" rel="nofollow">{{ column.name }}&nbsp;</a>
{% else %} {% else %}
<a href="{{ path_with_replaced_args(request, {'_sort': column.name, '_sort_desc': None, '_next': None}) }}" rel="nofollow">{{ column.name }}{% if column.name == sort_desc %}&nbsp;▲{% endif %}</a> <a href="{{ path_with_replaced_args(qs, {'_sort': column.name, '_sort_desc': None, '_next': None}) }}" rel="nofollow">{{ column.name }}{% if column.name == sort_desc %}&nbsp;▲{% endif %}</a>
{% endif %} {% endif %}
{% endif %} {% endif %}
</th> </th>

View file

@ -111,7 +111,7 @@
<p class="facet-info-name"> <p class="facet-info-name">
<strong>{{ facet_info.name }}</strong> <strong>{{ facet_info.name }}</strong>
{% if facet_hideable(facet_info.name) %} {% if facet_hideable(facet_info.name) %}
<a href="{{ path_with_removed_args(request, {'_facet': facet_info['name']}) }}" class="cross">&#x2716;</a> <a href="{{ path_with_removed_args(qs, {'_facet': facet_info['name']}) }}" class="cross">&#x2716;</a>
{% endif %} {% endif %}
</p> </p>
<ul> <ul>

View file

@ -170,13 +170,13 @@ def validate_sql_select(sql):
raise InvalidSql(msg) raise InvalidSql(msg)
def path_with_added_args(request, args, path=None): def path_with_added_args(qs, args, path=None):
path = path or request.path path = path or qs.path
if isinstance(args, dict): if isinstance(args, dict):
args = args.items() args = args.items()
args_to_remove = {k for k, v in args if v is None} args_to_remove = {k for k, v in args if v is None}
current = [] current = []
for key, value in urllib.parse.parse_qsl(request.query_string): for key, value in urllib.parse.parse_qsl(str(qs)):
if key not in args_to_remove: if key not in args_to_remove:
current.append((key, value)) current.append((key, value))
current.extend([ current.extend([
@ -190,9 +190,9 @@ def path_with_added_args(request, args, path=None):
return path + query_string return path + query_string
def path_with_removed_args(request, args, path=None): def path_with_removed_args(qs, args, path=None):
# args can be a dict or a set # args can be a dict or a set
path = path or request.path path = path or qs.path
current = [] current = []
if isinstance(args, set): if isinstance(args, set):
def should_remove(key, value): def should_remove(key, value):
@ -201,7 +201,7 @@ def path_with_removed_args(request, args, path=None):
# Must match key AND value # Must match key AND value
def should_remove(key, value): def should_remove(key, value):
return args.get(key) == value return args.get(key) == value
for key, value in urllib.parse.parse_qsl(request.query_string): for key, value in urllib.parse.parse_qsl(str(qs)):
if not should_remove(key, value): if not should_remove(key, value):
current.append((key, value)) current.append((key, value))
query_string = urllib.parse.urlencode(current) query_string = urllib.parse.urlencode(current)
@ -210,13 +210,13 @@ def path_with_removed_args(request, args, path=None):
return path + query_string return path + query_string
def path_with_replaced_args(request, args, path=None): def path_with_replaced_args(qs, args, path=None):
path = path or request.path path = path or qs.path
if isinstance(args, dict): if isinstance(args, dict):
args = args.items() args = args.items()
keys_to_replace = {p[0] for p in args} keys_to_replace = {p[0] for p in args}
current = [] current = []
for key, value in urllib.parse.parse_qsl(request.query_string): for key, value in urllib.parse.parse_qsl(str(qs)):
if key not in keys_to_replace: if key not in keys_to_replace:
current.append((key, value)) current.append((key, value))
current.extend([p for p in args if p[1] is not None]) current.extend([p for p in args if p[1] is not None])
@ -783,23 +783,23 @@ def resolve_table_and_format(table_and_format, table_exists):
return table_and_format, None return table_and_format, None
def path_with_format(request, format, extra_qs=None): def path_with_format(qs, format, extra_qs=None):
qs = extra_qs or {} new_qs = extra_qs or {}
path = request.path path = qs.path
if "." in request.path: if "." in qs.path:
qs["_format"] = format new_qs["_format"] = format
else: else:
path = "{}.{}".format(path, format) path = "{}.{}".format(path, format)
if qs: if new_qs:
extra = urllib.parse.urlencode(sorted(qs.items())) extra = urllib.parse.urlencode(sorted(new_qs.items()))
if request.query_string: if qs.data:
path = "{}?{}&{}".format( path = "{}?{}&{}".format(
path, request.query_string, extra path, str(qs), extra
) )
else: else:
path = "{}?{}".format(path, extra) path = "{}?{}".format(path, extra)
elif request.query_string: elif qs.data:
path = "{}?{}".format(path, request.query_string) path = "{}?{}".format(path, str(qs))
return path return path
@ -831,11 +831,12 @@ class ValueAsBooleanError(ValueError):
class Querystring: class Querystring:
def __init__(self, qs=None): def __init__(self, path, qs=None):
self.path = path
self.prev = None self.prev = None
self.data = [] self.data = []
if qs: if qs:
self.data = urllib.parse.parse_qsl(qs) self.data = urllib.parse.parse_qsl(qs, keep_blank_values=True)
def first(self, key): def first(self, key):
for item in self.data: for item in self.data:
@ -843,6 +844,12 @@ class Querystring:
return item[1] return item[1]
raise KeyError raise KeyError
def first_or_none(self, key):
try:
return self.first(key)
except KeyError:
return None
def last(self, key): def last(self, key):
for item in reversed(self.data): for item in reversed(self.data):
if item[0] == key: if item[0] == key:
@ -854,10 +861,11 @@ class Querystring:
for item in self.data: for item in self.data:
if item[0] == key: if item[0] == key:
result.append(item[1]) result.append(item[1])
if not result:
raise KeyError
return result return result
def first_dict(self):
return {k: v[0] for k, v in self.data}
def append(self, key, value): def append(self, key, value):
self.data.append((key, value)) self.data.append((key, value))

View file

@ -16,6 +16,7 @@ from datasette.utils import (
CustomJSONEncoder, CustomJSONEncoder,
InterruptedError, InterruptedError,
InvalidSql, InvalidSql,
Querystring,
path_from_row_pks, path_from_row_pks,
path_with_added_args, path_with_added_args,
path_with_format, path_with_format,
@ -85,9 +86,9 @@ class BaseView(RenderMixin):
r.headers["Access-Control-Allow-Origin"] = "*" r.headers["Access-Control-Allow-Origin"] = "*"
return r return r
def redirect(self, request, path, forward_querystring=True): def redirect(self, qs, path, forward_querystring=True):
if request.query_string and "?" not in path and forward_querystring: if qs.data and "?" not in qs.path and forward_querystring:
path = "{}?{}".format(path, request.query_string) path = "{}?{}".format(path, str(qs))
r = response.redirect(path) r = response.redirect(path)
r.headers["Link"] = "<{}>; rel=preload".format(path) r.headers["Link"] = "<{}>; rel=preload".format(path)
if self.ds.cors: if self.ds.cors:
@ -144,15 +145,15 @@ class BaseView(RenderMixin):
async def get(self, request, db_name, **kwargs): async def get(self, request, db_name, **kwargs):
name, hash, should_redirect = self.resolve_db_name(db_name, **kwargs) name, hash, should_redirect = self.resolve_db_name(db_name, **kwargs)
qs = Querystring(request.path, request.query_string)
if should_redirect: if should_redirect:
return self.redirect(request, should_redirect) return self.redirect(qs, should_redirect)
return await self.view_get(qs, name, hash, **kwargs)
return await self.view_get(request, name, hash, **kwargs) async def as_csv(self, qs, name, hash, **kwargs):
async def as_csv(self, request, name, hash, **kwargs):
try: try:
response_or_template_contexts = await self.data( response_or_template_contexts = await self.data(
request, name, hash, **kwargs qs, name, hash, **kwargs
) )
if isinstance(response_or_template_contexts, response.HTTPResponse): if isinstance(response_or_template_contexts, response.HTTPResponse):
return response_or_template_contexts return response_or_template_contexts
@ -198,7 +199,7 @@ class BaseView(RenderMixin):
content_type = "text/plain; charset=utf-8" content_type = "text/plain; charset=utf-8"
headers = {} headers = {}
if request.args.get("_dl", None): if qs.first_or_none("_dl"):
content_type = "text/csv; charset=utf-8" content_type = "text/csv; charset=utf-8"
disposition = 'attachment; filename="{}.csv"'.format( disposition = 'attachment; filename="{}.csv"'.format(
kwargs.get('table', name) kwargs.get('table', name)
@ -211,9 +212,9 @@ class BaseView(RenderMixin):
content_type=content_type content_type=content_type
) )
async def view_get(self, request, name, hash, **kwargs): async def view_get(self, qs, name, hash, **kwargs):
# If ?_format= is provided, use that as the format # If ?_format= is provided, use that as the format
_format = request.args.get("_format", None) _format = qs.first_or_none("_format")
if not _format: if not _format:
_format = (kwargs.pop("as_format", None) or "").lstrip(".") _format = (kwargs.pop("as_format", None) or "").lstrip(".")
if "table_and_format" in kwargs: if "table_and_format" in kwargs:
@ -228,7 +229,7 @@ class BaseView(RenderMixin):
del kwargs["table_and_format"] del kwargs["table_and_format"]
if _format == "csv": if _format == "csv":
return await self.as_csv(request, name, hash, **kwargs) return await self.as_csv(qs, name, hash, **kwargs)
if _format is None: if _format is None:
# HTML views default to expanding all forign key labels # HTML views default to expanding all forign key labels
@ -240,7 +241,7 @@ class BaseView(RenderMixin):
templates = [] templates = []
try: try:
response_or_template_contexts = await self.data( response_or_template_contexts = await self.data(
request, name, hash, **kwargs qs, name, hash, **kwargs
) )
if isinstance(response_or_template_contexts, response.HTTPResponse): if isinstance(response_or_template_contexts, response.HTTPResponse):
return response_or_template_contexts return response_or_template_contexts
@ -272,26 +273,25 @@ class BaseView(RenderMixin):
# Special case for .jsono extension - redirect to _shape=objects # Special case for .jsono extension - redirect to _shape=objects
if _format == "jsono": if _format == "jsono":
return self.redirect( return self.redirect(
request, qs,
path_with_added_args( path_with_added_args(
request, qs,
{"_shape": "objects"}, {"_shape": "objects"},
path=request.path.rsplit(".jsono", 1)[0] + ".json", path=qs.path.rsplit(".jsono", 1)[0] + ".json",
), ),
forward_querystring=False, forward_querystring=False,
) )
# Handle the _json= parameter which may modify data["rows"] # Handle the _json= parameter which may modify data["rows"]
json_cols = [] json_cols = []
if "_json" in request.args: json_cols = qs.getlist("_json")
json_cols = request.args["_json"]
if json_cols and "rows" in data and "columns" in data: if json_cols and "rows" in data and "columns" in data:
data["rows"] = convert_specific_columns_to_json( data["rows"] = convert_specific_columns_to_json(
data["rows"], data["columns"], json_cols, data["rows"], data["columns"], json_cols,
) )
# Deal with the _shape option # Deal with the _shape option
shape = request.args.get("_shape", "arrays") shape = qs.first_or_none("_shape") or "arrays"
if shape == "arrayfirst": if shape == "arrayfirst":
data = [row[0] for row in data["rows"]] data = [row[0] for row in data["rows"]]
elif shape in ("objects", "object", "array"): elif shape in ("objects", "object", "array"):
@ -353,11 +353,11 @@ class BaseView(RenderMixin):
**data, **data,
**extras, **extras,
**{ **{
"url_json": path_with_format(request, "json"), "url_json": path_with_format(qs, "json"),
"url_csv": path_with_format(request, "csv", { "url_csv": path_with_format(qs, "csv", {
"_size": "max" "_size": "max"
}), }),
"url_csv_dl": path_with_format(request, "csv", { "url_csv_dl": path_with_format(qs, "csv", {
"_dl": "1", "_dl": "1",
"_size": "max" "_size": "max"
}), }),
@ -372,7 +372,7 @@ class BaseView(RenderMixin):
r.status = status_code r.status = status_code
# Set far-future cache expiry # Set far-future cache expiry
if self.ds.cache_headers: if self.ds.cache_headers:
ttl = request.args.get("_ttl", None) ttl = qs.first_or_none("_ttl")
if ttl is None or not ttl.isdigit(): if ttl is None or not ttl.isdigit():
ttl = self.ds.config["default_cache_ttl"] ttl = self.ds.config["default_cache_ttl"]
else: else:
@ -386,9 +386,9 @@ class BaseView(RenderMixin):
return r return r
async def custom_sql( async def custom_sql(
self, request, name, hash, sql, editable=True, canned_query=None self, qs, name, hash, sql, editable=True, canned_query=None
): ):
params = request.raw_args params = qs.first_dict()
if "sql" in params: if "sql" in params:
params.pop("sql") params.pop("sql")
if "_shape" in params: if "_shape" in params:

View file

@ -9,13 +9,13 @@ from .base import BaseView, DatasetteError
class DatabaseView(BaseView): class DatabaseView(BaseView):
async def data(self, request, name, hash, default_labels=False): async def data(self, qs, name, hash, default_labels=False):
if request.args.get("sql"): if qs.first_or_none("sql"):
if not self.ds.config["allow_sql"]: if not self.ds.config["allow_sql"]:
raise DatasetteError("sql= is not allowed", status=400) raise DatasetteError("sql= is not allowed", status=400)
sql = request.raw_args.pop("sql") sql = qs.first("sql")
validate_sql_select(sql) validate_sql_select(sql)
return await self.custom_sql(request, name, hash, sql) return await self.custom_sql(qs, name, hash, sql)
info = self.ds.inspect()[name] info = self.ds.inspect()[name]
metadata = self.ds.metadata.get("databases", {}).get(name, {}) metadata = self.ds.metadata.get("databases", {}).get(name, {})
@ -34,7 +34,7 @@ class DatabaseView(BaseView):
"config": self.ds.config, "config": self.ds.config,
}, { }, {
"database_hash": hash, "database_hash": hash,
"show_hidden": request.args.get("_show_hidden"), "show_hidden": qs.first_or_none("_show_hidden"),
"editable": True, "editable": True,
"metadata": metadata, "metadata": metadata,
}, ( }, (
@ -44,7 +44,7 @@ class DatabaseView(BaseView):
class DatabaseDownload(BaseView): class DatabaseDownload(BaseView):
async def view_get(self, request, name, hash, **kwargs): async def view_get(self, qs, name, hash, **kwargs):
if not self.ds.config["allow_download"]: if not self.ds.config["allow_download"]:
raise DatasetteError("Database download is forbidden", status=403) raise DatasetteError("Database download is forbidden", status=403)
filepath = self.ds.inspect()[name]["file"] filepath = self.ds.inspect()[name]["file"]

View file

@ -1,4 +1,3 @@
from collections import namedtuple
import sqlite3 import sqlite3
import urllib import urllib
@ -220,11 +219,11 @@ class RowTableShared(BaseView):
class TableView(RowTableShared): class TableView(RowTableShared):
async def data(self, request, name, hash, table, default_labels=False): async def data(self, qs, name, hash, table, default_labels=False):
canned_query = self.ds.get_canned_query(name, table) canned_query = self.ds.get_canned_query(name, table)
if canned_query is not None: if canned_query is not None:
return await self.custom_sql( return await self.custom_sql(
request, qs,
name, name,
hash, hash,
canned_query["sql"], canned_query["sql"],
@ -255,7 +254,7 @@ class TableView(RowTableShared):
# We roll our own query_string decoder because by default Sanic # We roll our own query_string decoder because by default Sanic
# drops anything with an empty value e.g. ?name__exact= # drops anything with an empty value e.g. ?name__exact=
args = RequestParameters( args = RequestParameters(
urllib.parse.parse_qs(request.query_string, keep_blank_values=True) urllib.parse.parse_qs(str(qs), keep_blank_values=True)
) )
# Special args start with _ and do not contain a __ # Special args start with _ and do not contain a __
@ -275,17 +274,17 @@ class TableView(RowTableShared):
redirect_params = filters_should_redirect(special_args) redirect_params = filters_should_redirect(special_args)
if redirect_params: if redirect_params:
return self.redirect( return self.redirect(
request, qs,
path_with_added_args(request, redirect_params), path_with_added_args(qs, redirect_params),
forward_querystring=False, forward_querystring=False,
) )
# Spot ?_sort_by_desc and redirect to _sort_desc=(_sort) # Spot ?_sort_by_desc and redirect to _sort_desc=(_sort)
if "_sort_by_desc" in special_args: if "_sort_by_desc" in special_args:
return self.redirect( return self.redirect(
request, qs,
path_with_added_args( path_with_added_args(
request, qs,
{ {
"_sort_desc": special_args.get("_sort"), "_sort_desc": special_args.get("_sort"),
"_sort_by_desc": None, "_sort_by_desc": None,
@ -458,11 +457,11 @@ class TableView(RowTableShared):
table_name=escape_sqlite(table), table_name=escape_sqlite(table),
where=where_clause, where=where_clause,
) )
return await self.custom_sql(request, name, hash, sql, editable=True) return await self.custom_sql(qs, name, hash, sql, editable=True)
extra_args = {} extra_args = {}
# Handle ?_size=500 # Handle ?_size=500
page_size = request.raw_args.get("_size") page_size = qs.first_or_none("_size")
if page_size: if page_size:
if page_size == "max": if page_size == "max":
page_size = self.max_returned_rows page_size = self.max_returned_rows
@ -492,8 +491,8 @@ class TableView(RowTableShared):
offset=offset, offset=offset,
) )
if request.raw_args.get("_timelimit"): if qs.first_or_none("_timelimit"):
extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"]) extra_args["custom_time_limit"] = int(qs.first("_timelimit"))
results = await self.ds.execute( results = await self.ds.execute(
name, sql, params, truncate=True, **extra_args name, sql, params, truncate=True, **extra_args
@ -503,10 +502,10 @@ class TableView(RowTableShared):
facet_size = self.ds.config["default_facet_size"] facet_size = self.ds.config["default_facet_size"]
metadata_facets = table_metadata.get("facets", []) metadata_facets = table_metadata.get("facets", [])
facets = metadata_facets[:] facets = metadata_facets[:]
if request.args.get("_facet") and not self.ds.config["allow_facet"]: if qs.first_or_none("_facet") and not self.ds.config["allow_facet"]:
raise DatasetteError("_facet= is not allowed", status=400) raise DatasetteError("_facet= is not allowed", status=400)
try: try:
facets.extend(request.args["_facet"]) facets.extend(qs.getlist("_facet"))
except KeyError: except KeyError:
pass pass
facet_results = {} facet_results = {}
@ -544,11 +543,11 @@ class TableView(RowTableShared):
selected = str(other_args.get(column)) == str(row["value"]) selected = str(other_args.get(column)) == str(row["value"])
if selected: if selected:
toggle_path = path_with_removed_args( toggle_path = path_with_removed_args(
request, {column: str(row["value"])} qs, {column: str(row["value"])}
) )
else: else:
toggle_path = path_with_added_args( toggle_path = path_with_added_args(
request, {column: row["value"]} qs, {column: row["value"]}
) )
facet_results_values.append({ facet_results_values.append({
"value": row["value"], "value": row["value"],
@ -558,7 +557,7 @@ class TableView(RowTableShared):
), ),
"count": row["count"], "count": row["count"],
"toggle_url": urllib.parse.urljoin( "toggle_url": urllib.parse.urljoin(
request.url, toggle_path qs.path, toggle_path
), ),
"selected": selected, "selected": selected,
}) })
@ -581,8 +580,8 @@ class TableView(RowTableShared):
except ValueError: except ValueError:
all_labels = default_labels all_labels = default_labels
# Check for explicit _label= # Check for explicit _label=
if "_label" in request.args: if qs.first_or_none("_label"):
columns_to_expand = request.args["_label"] columns_to_expand = qs.first("_label")
if columns_to_expand is None and all_labels: if columns_to_expand is None and all_labels:
# expand all columns with foreign keys # expand all columns with foreign keys
columns_to_expand = [ columns_to_expand = [
@ -644,7 +643,7 @@ class TableView(RowTableShared):
else: else:
added_args = {"_next": next_value} added_args = {"_next": next_value}
next_url = urllib.parse.urljoin( next_url = urllib.parse.urljoin(
request.url, path_with_replaced_args(request, added_args) qs.path, path_with_replaced_args(qs, added_args)
) )
rows = rows[:page_size] rows = rows[:page_size]
@ -694,7 +693,7 @@ class TableView(RowTableShared):
suggested_facets.append({ suggested_facets.append({
'name': facet_column, 'name': facet_column,
'toggle_url': path_with_added_args( 'toggle_url': path_with_added_args(
request, {'_facet': facet_column} qs, {'_facet': facet_column}
), ),
}) })
except InterruptedError: except InterruptedError:
@ -744,7 +743,7 @@ class TableView(RowTableShared):
"is_sortable": any(c["sortable"] for c in display_columns), "is_sortable": any(c["sortable"] for c in display_columns),
"path_with_replaced_args": path_with_replaced_args, "path_with_replaced_args": path_with_replaced_args,
"path_with_removed_args": path_with_removed_args, "path_with_removed_args": path_with_removed_args,
"request": request, "qs": qs,
"sort": sort, "sort": sort,
"sort_desc": sort_desc, "sort_desc": sort_desc,
"disable_sort": is_view, "disable_sort": is_view,
@ -788,7 +787,7 @@ class TableView(RowTableShared):
class RowView(RowTableShared): class RowView(RowTableShared):
async def data(self, request, name, hash, table, pk_path, default_labels=False): async def data(self, qs, name, hash, table, pk_path, default_labels=False):
pk_values = urlsafe_components(pk_path) pk_values = urlsafe_components(pk_path)
info = self.ds.inspect()[name] info = self.ds.inspect()[name]
table_info = info["tables"].get(table) or {} table_info = info["tables"].get(table) or {}
@ -854,7 +853,7 @@ class RowView(RowTableShared):
"units": self.table_metadata(name, table).get("units", {}), "units": self.table_metadata(name, table).get("units", {}),
} }
if "foreign_key_tables" in (request.raw_args.get("_extras") or "").split(","): if "foreign_key_tables" in (qs.first_or_none("_extras") or "").split(","):
data["foreign_key_tables"] = await self.foreign_key_tables( data["foreign_key_tables"] = await self.foreign_key_tables(
name, table, pk_values name, table, pk_values
) )

View file

@ -6,7 +6,6 @@ from datasette import utils
import json import json
import os import os
import pytest import pytest
from sanic.request import Request
import sqlite3 import sqlite3
import tempfile import tempfile
from unittest.mock import patch from unittest.mock import patch
@ -40,11 +39,12 @@ def test_urlsafe_components(path, expected):
), '/?_facet=state&_facet=city&_facet=planet_int'), ), '/?_facet=state&_facet=city&_facet=planet_int'),
]) ])
def test_path_with_added_args(path, added_args, expected): def test_path_with_added_args(path, added_args, expected):
request = Request( try:
path.encode('utf8'), path, qsbits = path.split('?', 1)
{}, '1.1', 'GET', None except ValueError:
) qsbits = ''
actual = utils.path_with_added_args(request, added_args) qs = utils.Querystring(path, qsbits)
actual = utils.path_with_added_args(qs, added_args)
assert expected == actual assert expected == actual
@ -54,11 +54,12 @@ def test_path_with_added_args(path, added_args, expected):
('/foo?bar=1&bar=2&bar=3', {'bar': '2'}, '/foo?bar=1&bar=3'), ('/foo?bar=1&bar=2&bar=3', {'bar': '2'}, '/foo?bar=1&bar=3'),
]) ])
def test_path_with_removed_args(path, args, expected): def test_path_with_removed_args(path, args, expected):
request = Request( try:
path.encode('utf8'), path, qsbits = path.split('?', 1)
{}, '1.1', 'GET', None except ValueError:
) qsbits = ''
actual = utils.path_with_removed_args(request, args) qs = utils.Querystring(path, qsbits)
actual = utils.path_with_removed_args(qs, args)
assert expected == actual assert expected == actual
@ -67,11 +68,12 @@ def test_path_with_removed_args(path, args, expected):
('/foo?bar=1&baz=2', {'bar': None}, '/foo?baz=2'), ('/foo?bar=1&baz=2', {'bar': None}, '/foo?baz=2'),
]) ])
def test_path_with_replaced_args(path, args, expected): def test_path_with_replaced_args(path, args, expected):
request = Request( try:
path.encode('utf8'), path, qsbits = path.split('?', 1)
{}, '1.1', 'GET', None except ValueError:
) qsbits = ''
actual = utils.path_with_replaced_args(request, args) qs = utils.Querystring(path, qsbits)
actual = utils.path_with_replaced_args(qs, args)
assert expected == actual assert expected == actual
@ -344,9 +346,10 @@ def test_resolve_table_and_format(
], ],
) )
def test_path_with_format(path, format, extra_qs, expected): def test_path_with_format(path, format, extra_qs, expected):
request = Request( try:
path.encode('utf8'), path, qsbits = path.split('?', 1)
{}, '1.1', 'GET', None except ValueError:
) qsbits = ''
actual = utils.path_with_format(request, format, extra_qs) qs = utils.Querystring(path, qsbits)
actual = utils.path_with_format(qs, format, extra_qs)
assert expected == actual assert expected == actual