sortable_columns_for_table() no longer uses inspect()

Refs #420
This commit is contained in:
Simon Willison 2019-04-06 18:58:51 -07:00
commit 97331f3435
5 changed files with 31 additions and 11 deletions

View file

@ -33,6 +33,7 @@ from .utils import (
module_from_path, module_from_path,
sqlite3, sqlite3,
sqlite_timelimit, sqlite_timelimit,
table_columns,
to_css_class to_css_class
) )
from .inspect import inspect_hash, inspect_views, inspect_tables from .inspect import inspect_hash, inspect_views, inspect_tables
@ -463,6 +464,11 @@ class Datasette:
for p in ps for p in ps
] ]
async def table_columns(self, db_name, table):
return await self.execute_against_connection_in_thread(
db_name, lambda conn: table_columns(conn, table)
)
async def execute_against_connection_in_thread(self, db_name, fn): async def execute_against_connection_in_thread(self, db_name, fn):
def in_thread(): def in_thread():
conn = getattr(connections, db_name, None) conn = getattr(connections, db_name, None)

View file

@ -5,6 +5,7 @@ from .utils import (
detect_fts, detect_fts,
escape_sqlite, escape_sqlite,
get_all_foreign_keys, get_all_foreign_keys,
table_columns,
sqlite3 sqlite3
) )
@ -78,12 +79,7 @@ def inspect_tables(conn, database_metadata):
# e.g. "select count(*) from some_fts;" # e.g. "select count(*) from some_fts;"
count = 0 count = 0
column_names = [ column_names = table_columns(conn, table)
r[1]
for r in conn.execute(
"PRAGMA table_info({});".format(escape_sqlite(table))
).fetchall()
]
tables[table] = { tables[table] = {
"name": table, "name": table,

View file

@ -536,6 +536,15 @@ def detect_fts_sql(table):
'''.format(table=table) '''.format(table=table)
def table_columns(conn, table):
return [
r[1]
for r in conn.execute(
"PRAGMA table_info({});".format(escape_sqlite(table))
).fetchall()
]
class Filter: class Filter:
def __init__(self, key, display, sql_template, human_template, format='{}', numeric=False, no_argument=False): def __init__(self, key, display, sql_template, human_template, format='{}', numeric=False, no_argument=False):
self.key = key self.key = key

View file

@ -19,6 +19,7 @@ from datasette.utils import (
path_with_removed_args, path_with_removed_args,
path_with_replaced_args, path_with_replaced_args,
sqlite3, sqlite3,
table_columns,
to_css_class, to_css_class,
urlsafe_components, urlsafe_components,
value_as_boolean, value_as_boolean,
@ -31,13 +32,12 @@ LINK_WITH_VALUE = '<a href="/{database}/{table}/{link_id}">{id}</a>'
class RowTableShared(BaseView): class RowTableShared(BaseView):
def sortable_columns_for_table(self, database, table, use_rowid): async def sortable_columns_for_table(self, database, table, use_rowid):
table_metadata = self.table_metadata(database, table) table_metadata = self.table_metadata(database, table)
if "sortable_columns" in table_metadata: if "sortable_columns" in table_metadata:
sortable_columns = set(table_metadata["sortable_columns"]) sortable_columns = set(table_metadata["sortable_columns"])
else: else:
table_info = self.ds.inspect()[database]["tables"].get(table) or {} sortable_columns = set(await self.ds.table_columns(database, table))
sortable_columns = set(table_info.get("columns", []))
if use_rowid: if use_rowid:
sortable_columns.add("rowid") sortable_columns.add("rowid")
return sortable_columns return sortable_columns
@ -121,7 +121,7 @@ class RowTableShared(BaseView):
"Returns columns, rows for specified table - including fancy foreign key treatment" "Returns columns, rows for specified table - including fancy foreign key treatment"
table_metadata = self.table_metadata(database, table) table_metadata = self.table_metadata(database, table)
info = self.ds.inspect()[database] info = self.ds.inspect()[database]
sortable_columns = self.sortable_columns_for_table(database, table, True) sortable_columns = await self.sortable_columns_for_table(database, table, True)
columns = [ columns = [
{"name": r[0], "sortable": r[0] in sortable_columns} for r in description {"name": r[0], "sortable": r[0] in sortable_columns} for r in description
] ]
@ -363,7 +363,7 @@ class TableView(RowTableShared):
if not is_view: if not is_view:
table_rows_count = table_info["count"] table_rows_count = table_info["count"]
sortable_columns = self.sortable_columns_for_table(database, table, use_rowid) sortable_columns = await self.sortable_columns_for_table(database, table, use_rowid)
# Allow for custom sort order # Allow for custom sort order
sort = special_args.get("_sort") sort = special_args.get("_sort")

View file

@ -7,6 +7,7 @@ import json
import os import os
import pytest import pytest
from sanic.request import Request from sanic.request import Request
import sqlite3
import tempfile import tempfile
from unittest.mock import patch from unittest.mock import patch
@ -357,6 +358,14 @@ async def test_resolve_table_and_format(
assert expected_format == actual_format assert expected_format == actual_format
def test_table_columns():
conn = sqlite3.connect(":memory:")
conn.executescript("""
create table places (id integer primary key, name text, bob integer)
""")
assert ["id", "name", "bob"] == utils.table_columns(conn, "places")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"path,format,extra_qs,expected", "path,format,extra_qs,expected",
[ [