diff --git a/datasette/app.py b/datasette/app.py index 8c5480cf..2907d90e 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1211,6 +1211,13 @@ class DatasetteRouter: return await self.handle_404(request, send) async def handle_404(self, request, send, exception=None): + # If path contains % encoding, redirect to dash encoding + if "%" in request.path: + # Try the same path but with "%" replaced by "-" + # and "-" replaced with "-2D" + new_path = request.path.replace("-", "-2D").replace("%", "-") + await asgi_send_redirect(send, new_path) + return # If URL has a trailing slash, redirect to URL without it path = request.scope.get( "raw_path", request.scope["path"].encode("utf8") diff --git a/datasette/templates/_table.html b/datasette/templates/_table.html index d91a1a57..5332f831 100644 --- a/datasette/templates/_table.html +++ b/datasette/templates/_table.html @@ -4,7 +4,7 @@ {% for column in display_columns %} - + {% if not column.sortable %} {{ column.name }} {% else %} diff --git a/datasette/url_builder.py b/datasette/url_builder.py index 2bcda869..eebfe31e 100644 --- a/datasette/url_builder.py +++ b/datasette/url_builder.py @@ -1,4 +1,4 @@ -from .utils import path_with_format, HASH_LENGTH, PrefixedUrlString +from .utils import dash_encode, path_with_format, HASH_LENGTH, PrefixedUrlString import urllib @@ -31,20 +31,20 @@ class Urls: db = self.ds.databases[database] if self.ds.setting("hash_urls") and db.hash: path = self.path( - f"{urllib.parse.quote(database)}-{db.hash[:HASH_LENGTH]}", format=format + f"{dash_encode(database)}-{db.hash[:HASH_LENGTH]}", format=format ) else: - path = self.path(urllib.parse.quote(database), format=format) + path = self.path(dash_encode(database), format=format) return path def table(self, database, table, format=None): - path = f"{self.database(database)}/{urllib.parse.quote_plus(table)}" + path = f"{self.database(database)}/{dash_encode(table)}" if format is not None: path = path_with_format(path=path, format=format) return PrefixedUrlString(path) def query(self, database, query, format=None): - path = f"{self.database(database)}/{urllib.parse.quote_plus(query)}" + path = f"{self.database(database)}/{dash_encode(query)}" if format is not None: path = path_with_format(path=path, format=format) return PrefixedUrlString(path) diff --git a/datasette/utils/__init__.py b/datasette/utils/__init__.py index e17b4d7f..aef85fbb 100644 --- a/datasette/utils/__init__.py +++ b/datasette/utils/__init__.py @@ -111,12 +111,12 @@ async def await_me_maybe(value: typing.Any) -> typing.Any: def urlsafe_components(token): - """Splits token on commas and URL decodes each component""" - return [urllib.parse.unquote_plus(b) for b in token.split(",")] + """Splits token on commas and dash-decodes each component""" + return [dash_decode(b) for b in token.split(",")] def path_from_row_pks(row, pks, use_rowid, quote=True): - """Generate an optionally URL-quoted unique identifier + """Generate an optionally dash-quoted unique identifier for a row from its primary keys.""" if use_rowid: bits = [row["rowid"]] @@ -125,7 +125,7 @@ def path_from_row_pks(row, pks, use_rowid, quote=True): row[pk]["value"] if isinstance(row[pk], dict) else row[pk] for pk in pks ] if quote: - bits = [urllib.parse.quote_plus(str(bit)) for bit in bits] + bits = [dash_encode(str(bit)) for bit in bits] else: bits = [str(bit) for bit in bits] @@ -1139,3 +1139,36 @@ def add_cors_headers(headers): headers["Access-Control-Allow-Origin"] = "*" headers["Access-Control-Allow-Headers"] = "Authorization" headers["Access-Control-Expose-Headers"] = "Link" + + +_DASH_ENCODING_SAFE = frozenset( + b"ABCDEFGHIJKLMNOPQRSTUVWXYZ" + b"abcdefghijklmnopqrstuvwxyz" + b"0123456789_" + # This is the same as Python percent-encoding but I removed + # '.' and '-' and '~' +) + + +class DashEncoder(dict): + # Keeps a cache internally, via __missing__ + def __missing__(self, b): + # Handle a cache miss, store encoded string in cache and return. + res = chr(b) if b in _DASH_ENCODING_SAFE else "-{:02X}".format(b) + self[b] = res + return res + + +_dash_encoder = DashEncoder().__getitem__ + + +@documented +def dash_encode(s: str) -> str: + "Returns dash-encoded string - for example ``/foo/bar`` -> ``-2Ffoo-2Fbar``" + return "".join(_dash_encoder(char) for char in s.encode("utf-8")) + + +@documented +def dash_decode(s: str) -> str: + "Decodes a dash-encoded string, so ``-2Ffoo-2Fbar`` -> ``/foo/bar``" + return urllib.parse.unquote(s.replace("-", "%")) diff --git a/datasette/views/base.py b/datasette/views/base.py index c74d6141..7cd385b7 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -17,6 +17,8 @@ from datasette.utils import ( InvalidSql, LimitedWriter, call_with_supported_arguments, + dash_decode, + dash_encode, path_from_row_pks, path_with_added_args, path_with_removed_args, @@ -203,17 +205,17 @@ class DataView(BaseView): async def resolve_db_name(self, request, db_name, **kwargs): hash = None name = None - db_name = urllib.parse.unquote_plus(db_name) - if db_name not in self.ds.databases and "-" in db_name: + decoded_name = dash_decode(db_name) + if decoded_name not in self.ds.databases and "-" in db_name: # No matching DB found, maybe it's a name-hash? name_bit, hash_bit = db_name.rsplit("-", 1) - if name_bit not in self.ds.databases: + if dash_decode(name_bit) not in self.ds.databases: raise NotFound(f"Database not found: {name}") else: - name = name_bit + name = dash_decode(name_bit) hash = hash_bit else: - name = db_name + name = decoded_name try: db = self.ds.databases[name] @@ -233,9 +235,7 @@ class DataView(BaseView): return await db.table_exists(t) table, _format = await resolve_table_and_format( - table_and_format=urllib.parse.unquote_plus( - kwargs["table_and_format"] - ), + table_and_format=dash_decode(kwargs["table_and_format"]), table_exists=async_table_exists, allowed_formats=self.ds.renderers.keys(), ) @@ -243,11 +243,11 @@ class DataView(BaseView): if _format: kwargs["as_format"] = f".{_format}" elif kwargs.get("table"): - kwargs["table"] = urllib.parse.unquote_plus(kwargs["table"]) + kwargs["table"] = dash_decode(kwargs["table"]) should_redirect = self.ds.urls.path(f"{name}-{expected}") if kwargs.get("table"): - should_redirect += "/" + urllib.parse.quote_plus(kwargs["table"]) + should_redirect += "/" + dash_encode(kwargs["table"]) if kwargs.get("pk_path"): should_redirect += "/" + kwargs["pk_path"] if kwargs.get("as_format"): @@ -467,7 +467,7 @@ class DataView(BaseView): return await db.table_exists(t) table, _ext_format = await resolve_table_and_format( - table_and_format=urllib.parse.unquote_plus(args["table_and_format"]), + table_and_format=dash_decode(args["table_and_format"]), table_exists=async_table_exists, allowed_formats=self.ds.renderers.keys(), ) @@ -475,7 +475,7 @@ class DataView(BaseView): args["table"] = table del args["table_and_format"] elif "table" in args: - args["table"] = urllib.parse.unquote_plus(args["table"]) + args["table"] = dash_decode(args["table"]) return _format, args async def view_get(self, request, database, hash, correct_hash_provided, **kwargs): diff --git a/datasette/views/table.py b/datasette/views/table.py index be9e9c3b..1d81755e 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -12,6 +12,7 @@ from datasette.utils import ( MultiParams, append_querystring, compound_keys_after_sql, + dash_encode, escape_sqlite, filters_should_redirect, is_url, @@ -142,7 +143,7 @@ class RowTableShared(DataView): '{flat_pks}'.format( base_url=base_url, database=database, - table=urllib.parse.quote_plus(table), + table=dash_encode(table), flat_pks=str(markupsafe.escape(pk_path)), flat_pks_quoted=path_from_row_pks(row, pks, not pks), ) @@ -199,8 +200,8 @@ class RowTableShared(DataView): link_template.format( database=database, base_url=base_url, - table=urllib.parse.quote_plus(other_table), - link_id=urllib.parse.quote_plus(str(value)), + table=dash_encode(other_table), + link_id=dash_encode(str(value)), id=str(markupsafe.escape(value)), label=str(markupsafe.escape(label)) or "-", ) @@ -765,7 +766,7 @@ class TableView(RowTableShared): if prefix is None: prefix = "$null" else: - prefix = urllib.parse.quote_plus(str(prefix)) + prefix = dash_encode(str(prefix)) next_value = f"{prefix},{next_value}" added_args = {"_next": next_value} if sort: diff --git a/docs/internals.rst b/docs/internals.rst index 12ef5c54..d035e1f1 100644 --- a/docs/internals.rst +++ b/docs/internals.rst @@ -876,6 +876,32 @@ Utility function for calling ``await`` on a return value if it is awaitable, oth .. autofunction:: datasette.utils.await_me_maybe +.. _internals_dash_encoding: + +Dash encoding +------------- + +Datasette uses a custom encoding scheme in some places, called **dash encoding**. This is primarily used for table names and row primary keys, to avoid any confusion between ``/`` characters in those values and the Datasette URLs that reference them. + +Dash encoding uses the same algorithm as `URL percent-encoding `__, but with the ``-`` hyphen character used in place of ``%``. + +Any character other than ``ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz 0123456789_`` will be replaced by the numeric equivalent preceded by a hyphen. For example: + +- ``/`` becomes ``-2F`` +- ``.`` becomes ``-2E`` +- ``%`` becomes ``-25`` +- ``-`` becomes ``-2D`` +- Space character becomes ``-20`` +- ``polls/2022.primary`` becomes ``polls-2F2022-2Eprimary`` + +.. _internals_utils_dash_encode: + +.. autofunction:: datasette.utils.dash_encode + +.. _internals_utils_dash_decode: + +.. autofunction:: datasette.utils.dash_decode + .. _internals_tracer: datasette.tracer diff --git a/tests/fixtures.py b/tests/fixtures.py index 26f0cf7b..11f09c41 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -406,6 +406,7 @@ CREATE TABLE compound_primary_key ( ); INSERT INTO compound_primary_key VALUES ('a', 'b', 'c'); +INSERT INTO compound_primary_key VALUES ('a/b', '.c-d', 'c'); CREATE TABLE compound_three_primary_keys ( pk1 varchar(30), diff --git a/tests/test_api.py b/tests/test_api.py index 57471af2..dd916cf0 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -143,7 +143,7 @@ def test_database_page(app_client): "name": "compound_primary_key", "columns": ["pk1", "pk2", "content"], "primary_keys": ["pk1", "pk2"], - "count": 1, + "count": 2, "hidden": False, "fts_table": None, "foreign_keys": {"incoming": [], "outgoing": []}, @@ -942,7 +942,7 @@ def test_cors(app_client_with_cors, path, status_code): ) def test_database_with_space_in_name(app_client_two_attached_databases, path): response = app_client_two_attached_databases.get( - "/extra database" + path, follow_redirects=True + "/extra-20database" + path, follow_redirects=True ) assert response.status == 200 @@ -953,7 +953,7 @@ def test_common_prefix_database_names(app_client_conflicting_database_names): d["name"] for d in app_client_conflicting_database_names.get("/-/databases.json").json ] - for db_name, path in (("foo", "/foo.json"), ("foo-bar", "/foo-bar.json")): + for db_name, path in (("foo", "/foo.json"), ("foo-bar", "/foo-2Dbar.json")): data = app_client_conflicting_database_names.get(path).json assert db_name == data["database"] @@ -992,3 +992,16 @@ async def test_hidden_sqlite_stat1_table(): data = (await ds.client.get("/db.json?_show_hidden=1")).json() tables = [(t["name"], t["hidden"]) for t in data["tables"]] assert tables == [("normal", False), ("sqlite_stat1", True)] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("db_name", ("foo", r"fo%o", "f~/c.d")) +async def test_dash_encoded_database_names(db_name): + ds = Datasette() + ds.add_memory_database(db_name) + response = await ds.client.get("/.json") + assert db_name in response.json().keys() + path = response.json()[db_name]["path"] + # And the JSON for that database + response2 = await ds.client.get(path + ".json") + assert response2.status_code == 200 diff --git a/tests/test_cli.py b/tests/test_cli.py index 3fbfdee2..e30c2ad3 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -9,6 +9,7 @@ from datasette.app import SETTINGS from datasette.plugins import DEFAULT_PLUGINS from datasette.cli import cli, serve from datasette.version import __version__ +from datasette.utils import dash_encode from datasette.utils.sqlite import sqlite3 from click.testing import CliRunner import io @@ -294,12 +295,12 @@ def test_weird_database_names(ensure_eventloop, tmpdir, filename): assert result1.exit_code == 0, result1.output filename_no_stem = filename.rsplit(".", 1)[0] expected_link = '{}'.format( - urllib.parse.quote(filename_no_stem), filename_no_stem + dash_encode(filename_no_stem), filename_no_stem ) assert expected_link in result1.output # Now try hitting that database page result2 = runner.invoke( - cli, [db_path, "--get", "/{}".format(urllib.parse.quote(filename_no_stem))] + cli, [db_path, "--get", "/{}".format(dash_encode(filename_no_stem))] ) assert result2.exit_code == 0, result2.output diff --git a/tests/test_html.py b/tests/test_html.py index d5f4250d..3e24009e 100644 --- a/tests/test_html.py +++ b/tests/test_html.py @@ -29,7 +29,7 @@ def test_homepage(app_client_two_attached_databases): ) # Should be two attached databases assert [ - {"href": r"/extra%20database", "text": "extra database"}, + {"href": r"/extra-20database", "text": "extra database"}, {"href": "/fixtures", "text": "fixtures"}, ] == [{"href": a["href"], "text": a.text.strip()} for a in soup.select("h2 a")] # Database should show count text and attached tables @@ -44,8 +44,8 @@ def test_homepage(app_client_two_attached_databases): {"href": a["href"], "text": a.text.strip()} for a in links_p.findAll("a") ] assert [ - {"href": r"/extra%20database/searchable", "text": "searchable"}, - {"href": r"/extra%20database/searchable_view", "text": "searchable_view"}, + {"href": r"/extra-20database/searchable", "text": "searchable"}, + {"href": r"/extra-20database/searchable_view", "text": "searchable_view"}, ] == table_links @@ -140,7 +140,7 @@ def test_database_page(app_client): assert queries_ul is not None assert [ ( - "/fixtures/%F0%9D%90%9C%F0%9D%90%A2%F0%9D%90%AD%F0%9D%90%A2%F0%9D%90%9E%F0%9D%90%AC", + "/fixtures/-F0-9D-90-9C-F0-9D-90-A2-F0-9D-90-AD-F0-9D-90-A2-F0-9D-90-9E-F0-9D-90-AC", "𝐜𝐢𝐭𝐢𝐞𝐬", ), ("/fixtures/from_async_hook", "from_async_hook"), @@ -193,11 +193,11 @@ def test_row_redirects_with_url_hash(app_client_with_hash): def test_row_strange_table_name_with_url_hash(app_client_with_hash): - response = app_client_with_hash.get("/fixtures/table%2Fwith%2Fslashes.csv/3") + response = app_client_with_hash.get("/fixtures/table-2Fwith-2Fslashes-2Ecsv/3") assert response.status == 302 - assert response.headers["Location"].endswith("/table%2Fwith%2Fslashes.csv/3") + assert response.headers["Location"].endswith("/table-2Fwith-2Fslashes-2Ecsv/3") response = app_client_with_hash.get( - "/fixtures/table%2Fwith%2Fslashes.csv/3", follow_redirects=True + "/fixtures/table-2Fwith-2Fslashes-2Ecsv/3", follow_redirects=True ) assert response.status == 200 @@ -345,20 +345,38 @@ def test_row_links_from_other_tables(app_client, path, expected_text, expected_l assert link == expected_link -def test_row_html_compound_primary_key(app_client): - response = app_client.get("/fixtures/compound_primary_key/a,b") +@pytest.mark.parametrize( + "path,expected", + ( + ( + "/fixtures/compound_primary_key/a,b", + [ + [ + 'a', + 'b', + 'c', + ] + ], + ), + ( + "/fixtures/compound_primary_key/a-2Fb,-2Ec-2Dd", + [ + [ + 'a/b', + '.c-d', + 'c', + ] + ], + ), + ), +) +def test_row_html_compound_primary_key(app_client, path, expected): + response = app_client.get(path) assert response.status == 200 table = Soup(response.body, "html.parser").find("table") assert ["pk1", "pk2", "content"] == [ th.string.strip() for th in table.select("thead th") ] - expected = [ - [ - 'a', - 'b', - 'c', - ] - ] assert expected == [ [str(td) for td in tr.select("td")] for tr in table.select("tbody tr") ] @@ -934,3 +952,9 @@ def test_no_alternate_url_json(app_client, path): assert ( 'a', 'b', 'c', - ] + ], + [ + 'a/b,.c-d', + 'a/b', + '.c-d', + 'c', + ], ] - assert expected == [ + assert [ [str(td) for td in tr.select("td")] for tr in table.select("tbody tr") - ] + ] == expected def test_table_html_foreign_key_links(app_client): diff --git a/tests/test_utils.py b/tests/test_utils.py index e7d67045..1c3ab495 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -93,7 +93,7 @@ def test_path_with_replaced_args(path, args, expected): "row,pks,expected_path", [ ({"A": "foo", "B": "bar"}, ["A", "B"], "foo,bar"), - ({"A": "f,o", "B": "bar"}, ["A", "B"], "f%2Co,bar"), + ({"A": "f,o", "B": "bar"}, ["A", "B"], "f-2Co,bar"), ({"A": 123}, ["A"], "123"), ( utils.CustomRow( @@ -646,3 +646,21 @@ async def test_derive_named_parameters(sql, expected): db = ds.get_database("_memory") params = await utils.derive_named_parameters(db, sql) assert params == expected + + +@pytest.mark.parametrize( + "original,expected", + ( + ("abc", "abc"), + ("/foo/bar", "-2Ffoo-2Fbar"), + ("/-/bar", "-2F-2D-2Fbar"), + ("-/db-/table.csv", "-2D-2Fdb-2D-2Ftable-2Ecsv"), + (r"%~-/", "-25-7E-2D-2F"), + ("-25-7E-2D-2F", "-2D25-2D7E-2D2D-2D2F"), + ), +) +def test_dash_encoding(original, expected): + actual = utils.dash_encode(original) + assert actual == expected + # And test round-trip + assert original == utils.dash_decode(actual)