diff --git a/datasette/utils/__init__.py b/datasette/utils/__init__.py index d326c773..d467383d 100644 --- a/datasette/utils/__init__.py +++ b/datasette/utils/__init__.py @@ -1,7 +1,7 @@ import asyncio from contextlib import contextmanager import click -from collections import OrderedDict, namedtuple +from collections import OrderedDict, namedtuple, Counter import base64 import hashlib import inspect @@ -474,9 +474,25 @@ def get_outbound_foreign_keys(conn, table): if info is not None: id, seq, table_name, from_, to_, on_update, on_delete, match = info fks.append( - {"column": from_, "other_table": table_name, "other_column": to_} + { + "column": from_, + "other_table": table_name, + "other_column": to_, + "id": id, + "seq": seq, + } ) - return fks + # Filter out compound foreign keys by removing any where "id" is not unique + id_counts = Counter(fk["id"] for fk in fks) + return [ + { + "column": fk["column"], + "other_table": fk["other_table"], + "other_column": fk["other_column"], + } + for fk in fks + if id_counts[fk["id"]] == 1 + ] def get_all_foreign_keys(conn): @@ -487,20 +503,21 @@ def get_all_foreign_keys(conn): for table in tables: table_to_foreign_keys[table] = {"incoming": [], "outgoing": []} for table in tables: - infos = conn.execute(f"PRAGMA foreign_key_list([{table}])").fetchall() - for info in infos: - if info is not None: - id, seq, table_name, from_, to_, on_update, on_delete, match = info - if table_name not in table_to_foreign_keys: - # Weird edge case where something refers to a table that does - # not actually exist - continue - table_to_foreign_keys[table_name]["incoming"].append( - {"other_table": table, "column": to_, "other_column": from_} - ) - table_to_foreign_keys[table]["outgoing"].append( - {"other_table": table_name, "column": from_, "other_column": to_} - ) + fks = get_outbound_foreign_keys(conn, table) + for fk in fks: + table_name = fk["other_table"] + from_ = fk["column"] + to_ = fk["other_column"] + if table_name not in table_to_foreign_keys: + # Weird edge case where something refers to a table that does + # not actually exist + continue + table_to_foreign_keys[table_name]["incoming"].append( + {"other_table": table, "column": to_, "other_column": from_} + ) + table_to_foreign_keys[table]["outgoing"].append( + {"other_table": table_name, "column": from_, "other_column": to_} + ) return table_to_foreign_keys diff --git a/tests/fixtures.py b/tests/fixtures.py index 3abca821..f95a2d6b 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -388,9 +388,12 @@ CREATE TABLE foreign_key_references ( foreign_key_with_label varchar(30), foreign_key_with_blank_label varchar(30), foreign_key_with_no_label varchar(30), + foreign_key_compound_pk1 varchar(30), + foreign_key_compound_pk2 varchar(30), FOREIGN KEY (foreign_key_with_label) REFERENCES simple_primary_key(id), FOREIGN KEY (foreign_key_with_blank_label) REFERENCES simple_primary_key(id), FOREIGN KEY (foreign_key_with_no_label) REFERENCES primary_key_multiple_columns(id) + FOREIGN KEY (foreign_key_compound_pk1, foreign_key_compound_pk2) REFERENCES compound_primary_key(pk1, pk2) ); CREATE TABLE sortable ( @@ -624,8 +627,8 @@ INSERT INTO simple_primary_key VALUES (4, 'RENDER_CELL_DEMO'); INSERT INTO primary_key_multiple_columns VALUES (1, 'hey', 'world'); INSERT INTO primary_key_multiple_columns_explicit_label VALUES (1, 'hey', 'world2'); -INSERT INTO foreign_key_references VALUES (1, 1, 3, 1); -INSERT INTO foreign_key_references VALUES (2, null, null, null); +INSERT INTO foreign_key_references VALUES (1, 1, 3, 1, 'a', 'b'); +INSERT INTO foreign_key_references VALUES (2, null, null, null, null, null); INSERT INTO complex_foreign_keys VALUES (1, 1, 2, 1); INSERT INTO custom_foreign_key_label VALUES (1, 1); diff --git a/tests/test_api.py b/tests/test_api.py index 2bab6c30..848daf9c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -237,6 +237,8 @@ def test_database_page(app_client): "foreign_key_with_label", "foreign_key_with_blank_label", "foreign_key_with_no_label", + "foreign_key_compound_pk1", + "foreign_key_compound_pk2", ], "primary_keys": ["pk"], "count": 2, @@ -1637,6 +1639,8 @@ def test_expand_label(app_client): "foreign_key_with_label": {"value": "1", "label": "hello"}, "foreign_key_with_blank_label": "3", "foreign_key_with_no_label": "1", + "foreign_key_compound_pk1": "a", + "foreign_key_compound_pk2": "b", } } @@ -1821,24 +1825,28 @@ def test_common_prefix_database_names(app_client_conflicting_database_names): assert db_name == data["database"] -def test_null_foreign_keys_are_not_expanded(app_client): +def test_null_and_compound_foreign_keys_are_not_expanded(app_client): response = app_client.get( "/fixtures/foreign_key_references.json?_shape=array&_labels=on" ) - assert [ + assert response.json == [ { "pk": "1", "foreign_key_with_label": {"value": "1", "label": "hello"}, "foreign_key_with_blank_label": {"value": "3", "label": ""}, "foreign_key_with_no_label": {"value": "1", "label": "1"}, + "foreign_key_compound_pk1": "a", + "foreign_key_compound_pk2": "b", }, { "pk": "2", "foreign_key_with_label": None, "foreign_key_with_blank_label": None, "foreign_key_with_no_label": None, + "foreign_key_compound_pk1": None, + "foreign_key_compound_pk2": None, }, - ] == response.json + ] def test_inspect_file_used_for_count(app_client_immutable_and_inspect_file): diff --git a/tests/test_csv.py b/tests/test_csv.py index 209bce2b..0fd665a9 100644 --- a/tests/test_csv.py +++ b/tests/test_csv.py @@ -42,9 +42,9 @@ pk,created,planet_int,on_earth,state,city_id,city_id_label,neighborhood,tags,com ) EXPECTED_TABLE_WITH_NULLABLE_LABELS_CSV = """ -pk,foreign_key_with_label,foreign_key_with_label_label,foreign_key_with_blank_label,foreign_key_with_blank_label_label,foreign_key_with_no_label,foreign_key_with_no_label_label -1,1,hello,3,,1,1 -2,,,,,, +pk,foreign_key_with_label,foreign_key_with_label_label,foreign_key_with_blank_label,foreign_key_with_blank_label_label,foreign_key_with_no_label,foreign_key_with_no_label_label,foreign_key_compound_pk1,foreign_key_compound_pk2 +1,1,hello,3,,1,1,a,b +2,,,,,,,, """.lstrip().replace( "\n", "\r\n" ) diff --git a/tests/test_html.py b/tests/test_html.py index d53dbabc..ecbf89b4 100644 --- a/tests/test_html.py +++ b/tests/test_html.py @@ -804,12 +804,16 @@ def test_table_html_foreign_key_links(app_client): '