From 3a79ad98eafb9da527a3b9d9d8fbeb81936b02e7 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 14 Jun 2018 23:51:23 -0700 Subject: [PATCH] Basic CSV export, refs #266 Tables and custom SQL query results can now be exported as CSV. The easiest way to do this is to use the .csv extension, e.g. /test_tables/facet_cities.csv By default this is served as Content-Type: text/plain so you can see it in your browser. If you want to download the file (using text/csv and with an appropriate Content-Disposition: attachment header) you can do so like this: /test_tables/facet_cities.csv?_dl=1 We link to the CSV and downloadable CSV URLs from the table and query pages. The links use ?_size=max and so by default will return 1,000 rows. Also fixes #303 - table names ending in .json or .csv are now detected and URLs are generated that look like this instead: /test_tables/table%2Fwith%2Fslashes.csv?_format=csv The ?_format= option is available for everything else too, but we link to the .csv / .json versions in most cases because they are aesthetically pleasing. --- datasette/app.py | 21 ++++---- datasette/templates/query.html | 2 +- datasette/templates/table.html | 2 +- datasette/utils.py | 44 ++++++++++++--- datasette/views/base.py | 97 +++++++++++++++++++++++++++++----- datasette/views/index.py | 4 +- datasette/views/special.py | 4 +- datasette/views/table.py | 2 - tests/test_api.py | 2 +- tests/test_csv.py | 37 +++++++++++++ tests/test_html.py | 15 ++++++ tests/test_utils.py | 51 ++++++++++++++++++ 12 files changed, 243 insertions(+), 38 deletions(-) create mode 100644 tests/test_csv.py diff --git a/datasette/app.py b/datasette/app.py index 3ea68c86..6cfc3666 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -224,6 +224,9 @@ class Datasette: conn.execute('PRAGMA cache_size=-{}'.format(self.config["cache_size_kb"])) pm.hook.prepare_connection(conn=conn) + def table_exists(self, database, table): + return table in self.inspect().get(database, {}).get("tables") + def inspect(self): " Inspect the database and return a dictionary of table metadata " if self._inspect: @@ -395,7 +398,7 @@ class Datasette: self.jinja_env.filters["escape_sqlite"] = escape_sqlite self.jinja_env.filters["to_css_class"] = to_css_class pm.hook.prepare_jinja2_environment(env=self.jinja_env) - app.add_route(IndexView.as_view(self), "/") + app.add_route(IndexView.as_view(self), "/") # TODO: /favicon.ico and /-/static/ deserve far-future cache expires app.add_route(favicon, "/favicon.ico") app.static("/-/static/", str(app_root / "datasette" / "static")) @@ -408,37 +411,37 @@ class Datasette: app.static(modpath, plugin["static_path"]) app.add_route( JsonDataView.as_view(self, "inspect.json", self.inspect), - "/-/inspect", + "/-/inspect", ) app.add_route( JsonDataView.as_view(self, "metadata.json", lambda: self.metadata), - "/-/metadata", + "/-/metadata", ) app.add_route( JsonDataView.as_view(self, "versions.json", self.versions), - "/-/versions", + "/-/versions", ) app.add_route( JsonDataView.as_view(self, "plugins.json", self.plugins), - "/-/plugins", + "/-/plugins", ) app.add_route( JsonDataView.as_view(self, "config.json", lambda: self.config), - "/-/config", + "/-/config", ) app.add_route( - DatabaseView.as_view(self), "/" + DatabaseView.as_view(self), "/" ) app.add_route( DatabaseDownload.as_view(self), "/" ) app.add_route( TableView.as_view(self), - "//", + "//", ) app.add_route( RowView.as_view(self), - "///", + "///", ) self.register_custom_units() diff --git a/datasette/templates/query.html b/datasette/templates/query.html index 78c0f48c..e04df160 100644 --- a/datasette/templates/query.html +++ b/datasette/templates/query.html @@ -40,7 +40,7 @@ {% if rows %} -

This data as .json

+ diff --git a/datasette/templates/table.html b/datasette/templates/table.html index d695da38..eda37bc7 100644 --- a/datasette/templates/table.html +++ b/datasette/templates/table.html @@ -92,7 +92,7 @@

View and edit SQL

{% endif %} -

This data as .json

+ {% if suggested_facets %}

diff --git a/datasette/utils.py b/datasette/utils.py index 0b4fce7d..eb31475d 100644 --- a/datasette/utils.py +++ b/datasette/utils.py @@ -225,14 +225,6 @@ def path_with_replaced_args(request, args, path=None): return path + query_string -def path_with_ext(request, ext): - path = request.path - path += ext - if request.query_string: - path += '?' + request.query_string - return path - - _css_re = re.compile(r'''['"\n\\]''') _boring_keyword_re = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') @@ -772,3 +764,39 @@ def get_plugins(pm): plugin_info['version'] = distinfo.version plugins.append(plugin_info) return plugins + + +FORMATS = ('csv', 'json', 'jsono') + + +def resolve_table_and_format(table_and_format, table_exists): + if '.' in table_and_format: + # Check if a table exists with this exact name + if table_exists(table_and_format): + return table_and_format, None + # Check if table ends with a known format + for _format in FORMATS: + if table_and_format.endswith(".{}".format(_format)): + table = table_and_format[:-(len(_format) + 1)] + return table, _format + return table_and_format, None + + +def path_with_format(request, format, extra_qs=None): + qs = extra_qs or {} + path = request.path + if "." in request.path: + qs["_format"] = format + else: + path = "{}.{}".format(path, format) + if qs: + extra = urllib.parse.urlencode(sorted(qs.items())) + if request.query_string: + path = "{}?{}&{}".format( + path, request.query_string, extra + ) + else: + path = "{}?{}".format(path, extra) + elif request.query_string: + path = "{}?{}".format(path, request.query_string) + return path diff --git a/datasette/views/base.py b/datasette/views/base.py index d8aa8a14..f44aa5ce 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -1,8 +1,10 @@ import asyncio +import csv import json import re import sqlite3 import time +import urllib import pint from sanic import response @@ -16,7 +18,8 @@ from datasette.utils import ( InvalidSql, path_from_row_pks, path_with_added_args, - path_with_ext, + path_with_format, + resolve_table_and_format, to_css_class ) @@ -113,13 +116,23 @@ class BaseView(RenderMixin): expected = info["hash"][:HASH_LENGTH] if expected != hash: + if "table_and_format" in kwargs: + table, _format = resolve_table_and_format( + table_and_format=urllib.parse.unquote_plus( + kwargs["table_and_format"] + ), + table_exists=lambda t: self.ds.table_exists(name, t) + ) + kwargs["table"] = table + if _format: + kwargs["as_format"] = ".{}".format(_format) should_redirect = "/{}-{}".format(name, expected) if "table" in kwargs: - should_redirect += "/" + kwargs["table"] + should_redirect += "/" + urllib.parse.quote_plus(kwargs["table"]) if "pk_path" in kwargs: should_redirect += "/" + kwargs["pk_path"] - if "as_json" in kwargs: - should_redirect += kwargs["as_json"] + if "as_format" in kwargs: + should_redirect += kwargs["as_format"] if "as_db" in kwargs: should_redirect += kwargs["as_db"] return name, expected, should_redirect @@ -136,11 +149,65 @@ class BaseView(RenderMixin): return await self.view_get(request, name, hash, **kwargs) - async def view_get(self, request, name, hash, **kwargs): + async def as_csv(self, request, name, hash, **kwargs): try: - as_json = kwargs.pop("as_json") - except KeyError: - as_json = False + response_or_template_contexts = await self.data( + request, name, hash, **kwargs + ) + if isinstance(response_or_template_contexts, response.HTTPResponse): + return response_or_template_contexts + + else: + data, extra_template_data, templates = response_or_template_contexts + except (sqlite3.OperationalError, InvalidSql) as e: + raise DatasetteError(str(e), title="Invalid SQL", status=400) + + except (sqlite3.OperationalError) as e: + raise DatasetteError(str(e)) + + except DatasetteError: + raise + # Convert rows and columns to CSV + async def stream_fn(r): + writer = csv.writer(r) + writer.writerow(data["columns"]) + for row in data["rows"]: + writer.writerow(row) + + content_type = "text/plain; charset=utf-8" + headers = {} + if request.args.get("_dl", None): + content_type = "text/csv; charset=utf-8" + disposition = 'attachment; filename="{}.csv"'.format( + kwargs.get('table', name) + ) + headers["Content-Disposition"] = disposition + + return response.stream( + stream_fn, + headers=headers, + content_type=content_type + ) + + async def view_get(self, request, name, hash, **kwargs): + # If ?_format= is provided, use that as the format + _format = request.args.get("_format", None) + if not _format: + _format = (kwargs.pop("as_format", None) or "").lstrip(".") + if "table_and_format" in kwargs: + table, _ext_format = resolve_table_and_format( + table_and_format=urllib.parse.unquote_plus( + kwargs["table_and_format"] + ), + table_exists=lambda t: self.ds.table_exists(name, t) + ) + _format = _format or _ext_format + kwargs["table"] = table + del kwargs["table_and_format"] + + if _format == "csv": + return await self.as_csv(request, name, hash, **kwargs) + extra_template_data = {} start = time.time() status_code = 200 @@ -175,9 +242,9 @@ class BaseView(RenderMixin): value = self.ds.metadata.get(key) if value: data[key] = value - if as_json: + if _format in ("json", "jsono"): # Special case for .jsono extension - redirect to _shape=objects - if as_json == ".jsono": + if _format == "jsono": return self.redirect( request, path_with_added_args( @@ -260,8 +327,14 @@ class BaseView(RenderMixin): **data, **extras, **{ - "url_json": path_with_ext(request, ".json"), - "url_jsono": path_with_ext(request, ".jsono"), + "url_json": path_with_format(request, "json"), + "url_csv": path_with_format(request, "csv", { + "_size": "max" + }), + "url_csv_dl": path_with_format(request, "csv", { + "_dl": "1", + "_size": "max" + }), "extra_css_urls": self.ds.extra_css_urls(), "extra_js_urls": self.ds.extra_js_urls(), "datasette_version": __version__, diff --git a/datasette/views/index.py b/datasette/views/index.py index c4ed3bef..66776e1c 100644 --- a/datasette/views/index.py +++ b/datasette/views/index.py @@ -16,7 +16,7 @@ class IndexView(RenderMixin): self.jinja_env = datasette.jinja_env self.executor = datasette.executor - async def get(self, request, as_json): + async def get(self, request, as_format): databases = [] for key, info in sorted(self.ds.inspect().items()): tables = [t for t in info["tables"].values() if not t["hidden"]] @@ -38,7 +38,7 @@ class IndexView(RenderMixin): "views_count": len(info["views"]), } databases.append(database) - if as_json: + if as_format: headers = {} if self.ds.cors: headers["Access-Control-Allow-Origin"] = "*" diff --git a/datasette/views/special.py b/datasette/views/special.py index 986630fd..7fde5ee9 100644 --- a/datasette/views/special.py +++ b/datasette/views/special.py @@ -10,9 +10,9 @@ class JsonDataView(RenderMixin): self.filename = filename self.data_callback = data_callback - async def get(self, request, as_json): + async def get(self, request, as_format): data = self.data_callback() - if as_json: + if as_format: headers = {} if self.ds.cors: headers["Access-Control-Allow-Origin"] = "*" diff --git a/datasette/views/table.py b/datasette/views/table.py index 12837acc..272902fb 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -232,7 +232,6 @@ class RowTableShared(BaseView): class TableView(RowTableShared): async def data(self, request, name, hash, table): - table = urllib.parse.unquote_plus(table) canned_query = self.ds.get_canned_query(name, table) if canned_query is not None: return await self.custom_sql( @@ -780,7 +779,6 @@ class TableView(RowTableShared): class RowView(RowTableShared): async def data(self, request, name, hash, table, pk_path): - table = urllib.parse.unquote_plus(table) pk_values = urlsafe_components(pk_path) info = self.ds.inspect()[name] table_info = info["tables"].get(table) or {} diff --git a/tests/test_api.py b/tests/test_api.py index 221cb0ca..af1fe4c5 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -507,7 +507,7 @@ def test_table_shape_object_compound_primary_Key(app_client): def test_table_with_slashes_in_name(app_client): - response = app_client.get('/test_tables/table%2Fwith%2Fslashes.csv.json?_shape=objects') + response = app_client.get('/test_tables/table%2Fwith%2Fslashes.csv?_shape=objects&_format=json') assert response.status == 200 data = response.json assert data['rows'] == [{ diff --git a/tests/test_csv.py b/tests/test_csv.py new file mode 100644 index 00000000..b6e2269f --- /dev/null +++ b/tests/test_csv.py @@ -0,0 +1,37 @@ +from .fixtures import app_client # noqa + +EXPECTED_TABLE_CSV = '''id,content +1,hello +2,world +3, +'''.replace('\n', '\r\n') + +EXPECTED_CUSTOM_CSV = '''content +hello +world +"" +'''.replace('\n', '\r\n') + + +def test_table_csv(app_client): + response = app_client.get('/test_tables/simple_primary_key.csv') + assert response.status == 200 + assert 'text/plain; charset=utf-8' == response.headers['Content-Type'] + assert EXPECTED_TABLE_CSV == response.text + + +def test_custom_sql_csv(app_client): + response = app_client.get( + '/test_tables.csv?sql=select+content+from+simple_primary_key' + ) + assert response.status == 200 + assert 'text/plain; charset=utf-8' == response.headers['Content-Type'] + assert EXPECTED_CUSTOM_CSV == response.text + + +def test_table_csv_download(app_client): + response = app_client.get('/test_tables/simple_primary_key.csv?_dl=1') + assert response.status == 200 + assert 'text/csv; charset=utf-8' == response.headers['Content-Type'] + expected_disposition = 'attachment; filename="simple_primary_key.csv"' + assert expected_disposition == response.headers['Content-Disposition'] diff --git a/tests/test_html.py b/tests/test_html.py index 116945e1..c6ac54a5 100644 --- a/tests/test_html.py +++ b/tests/test_html.py @@ -274,6 +274,21 @@ def test_table_html_simple_primary_key(app_client): ] == [[str(td) for td in tr.select('td')] for tr in table.select('tbody tr')] +def test_table_csv_json_export_links(app_client): + response = app_client.get('/test_tables/simple_primary_key') + assert response.status == 200 + links = Soup(response.body, "html.parser").find("p", { + "class": "export-links" + }).findAll("a") + actual = [l["href"].split("/")[-1] for l in links] + expected = [ + "simple_primary_key.json", + "simple_primary_key.csv?_size=max", + "simple_primary_key.csv?_dl=1&_size=max" + ] + assert expected == actual + + def test_row_html_simple_primary_key(app_client): response = app_client.get('/test_tables/simple_primary_key/1') assert response.status == 200 diff --git a/tests/test_utils.py b/tests/test_utils.py index 0572b5e5..d12bf927 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -299,3 +299,54 @@ def test_compound_keys_after_sql(): or (a = :p0 and b = :p1 and c > :p2)) '''.strip() == utils.compound_keys_after_sql(['a', 'b', 'c']) + + +def table_exists(table): + return table == "exists.csv" + + +@pytest.mark.parametrize( + "table_and_format,expected_table,expected_format", + [ + ("blah", "blah", None), + ("blah.csv", "blah", "csv"), + ("blah.json", "blah", "json"), + ("blah.baz", "blah.baz", None), + ("exists.csv", "exists.csv", None), + ], +) +def test_resolve_table_and_format( + table_and_format, expected_table, expected_format +): + actual_table, actual_format = utils.resolve_table_and_format( + table_and_format, table_exists + ) + assert expected_table == actual_table + assert expected_format == actual_format + + +@pytest.mark.parametrize( + "path,format,extra_qs,expected", + [ + ("/foo?sql=select+1", "csv", {}, "/foo.csv?sql=select+1"), + ("/foo?sql=select+1", "json", {}, "/foo.json?sql=select+1"), + ("/foo/bar", "json", {}, "/foo/bar.json"), + ("/foo/bar", "csv", {}, "/foo/bar.csv"), + ("/foo/bar.csv", "json", {}, "/foo/bar.csv?_format=json"), + ("/foo/bar", "csv", {"_dl": 1}, "/foo/bar.csv?_dl=1"), + ("/foo/b.csv", "json", {"_dl": 1}, "/foo/b.csv?_dl=1&_format=json"), + ( + "/sf-trees/Street_Tree_List?_search=cherry&_size=1000", + "csv", + {"_dl": 1}, + "/sf-trees/Street_Tree_List.csv?_search=cherry&_size=1000&_dl=1", + ), + ], +) +def test_path_with_format(path, format, extra_qs, expected): + request = Request( + path.encode('utf8'), + {}, '1.1', 'GET', None + ) + actual = utils.path_with_format(request, format, extra_qs) + assert expected == actual