diff --git a/datasette/app.py b/datasette/app.py index 21d17194..29dd91af 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -33,6 +33,7 @@ from .utils import ( module_from_path, sqlite3, sqlite_timelimit, + table_columns, to_css_class ) from .inspect import inspect_hash, inspect_views, inspect_tables @@ -463,6 +464,11 @@ class Datasette: for p in ps ] + async def table_columns(self, db_name, table): + return await self.execute_against_connection_in_thread( + db_name, lambda conn: table_columns(conn, table) + ) + async def execute_against_connection_in_thread(self, db_name, fn): def in_thread(): conn = getattr(connections, db_name, None) diff --git a/datasette/inspect.py b/datasette/inspect.py index 42e7aeff..12f67f0a 100644 --- a/datasette/inspect.py +++ b/datasette/inspect.py @@ -5,6 +5,7 @@ from .utils import ( detect_fts, escape_sqlite, get_all_foreign_keys, + table_columns, sqlite3 ) @@ -78,12 +79,7 @@ def inspect_tables(conn, database_metadata): # e.g. "select count(*) from some_fts;" count = 0 - column_names = [ - r[1] - for r in conn.execute( - "PRAGMA table_info({});".format(escape_sqlite(table)) - ).fetchall() - ] + column_names = table_columns(conn, table) tables[table] = { "name": table, diff --git a/datasette/utils.py b/datasette/utils.py index 3a7e90c4..98f70592 100644 --- a/datasette/utils.py +++ b/datasette/utils.py @@ -536,6 +536,15 @@ def detect_fts_sql(table): '''.format(table=table) +def table_columns(conn, table): + return [ + r[1] + for r in conn.execute( + "PRAGMA table_info({});".format(escape_sqlite(table)) + ).fetchall() + ] + + class Filter: def __init__(self, key, display, sql_template, human_template, format='{}', numeric=False, no_argument=False): self.key = key diff --git a/datasette/views/table.py b/datasette/views/table.py index 14f3be6f..84ebec05 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -19,6 +19,7 @@ from datasette.utils import ( path_with_removed_args, path_with_replaced_args, sqlite3, + table_columns, to_css_class, urlsafe_components, value_as_boolean, @@ -31,13 +32,12 @@ LINK_WITH_VALUE = '{id}' class RowTableShared(BaseView): - def sortable_columns_for_table(self, database, table, use_rowid): + async def sortable_columns_for_table(self, database, table, use_rowid): table_metadata = self.table_metadata(database, table) if "sortable_columns" in table_metadata: sortable_columns = set(table_metadata["sortable_columns"]) else: - table_info = self.ds.inspect()[database]["tables"].get(table) or {} - sortable_columns = set(table_info.get("columns", [])) + sortable_columns = set(await self.ds.table_columns(database, table)) if use_rowid: sortable_columns.add("rowid") return sortable_columns @@ -121,7 +121,7 @@ class RowTableShared(BaseView): "Returns columns, rows for specified table - including fancy foreign key treatment" table_metadata = self.table_metadata(database, table) info = self.ds.inspect()[database] - sortable_columns = self.sortable_columns_for_table(database, table, True) + sortable_columns = await self.sortable_columns_for_table(database, table, True) columns = [ {"name": r[0], "sortable": r[0] in sortable_columns} for r in description ] @@ -363,7 +363,7 @@ class TableView(RowTableShared): if not is_view: table_rows_count = table_info["count"] - sortable_columns = self.sortable_columns_for_table(database, table, use_rowid) + sortable_columns = await self.sortable_columns_for_table(database, table, use_rowid) # Allow for custom sort order sort = special_args.get("_sort") diff --git a/tests/test_utils.py b/tests/test_utils.py index fad1ac84..9a00b4b4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,6 +7,7 @@ import json import os import pytest from sanic.request import Request +import sqlite3 import tempfile from unittest.mock import patch @@ -357,6 +358,14 @@ async def test_resolve_table_and_format( assert expected_format == actual_format +def test_table_columns(): + conn = sqlite3.connect(":memory:") + conn.executescript(""" + create table places (id integer primary key, name text, bob integer) + """) + assert ["id", "name", "bob"] == utils.table_columns(conn, "places") + + @pytest.mark.parametrize( "path,format,extra_qs,expected", [