From 9766a9c0876351a5c60430c4582a7686cb24ad79 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 17 Jun 2026 12:38:51 -0700 Subject: [PATCH] Add foreign keys to create table API - Add fk_table and optional fk_column support to create-table columns. - Validate create-table requests with Pydantic while preserving existing errors. - Document the API and cover inferred primary-key and validation cases. Refs https://github.com/simonw/datasette/pull/2789#issuecomment-4733544452 --- datasette/views/table_create_alter.py | 287 ++++++++++++++++---------- docs/json_api.rst | 25 +++ tests/test_api_write.py | 115 +++++++++++ 3 files changed, 321 insertions(+), 106 deletions(-) diff --git a/datasette/views/table_create_alter.py b/datasette/views/table_create_alter.py index 7decfad2..e8264a6f 100644 --- a/datasette/views/table_create_alter.py +++ b/datasette/views/table_create_alter.py @@ -2,7 +2,15 @@ import json import re from typing import Annotated, Any, Literal, Union -from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationError, + field_validator, + model_validator, +) +from pydantic_core import PydanticCustomError import sqlite_utils from sqlite_utils.db import DEFAULT as SQLITE_UTILS_DEFAULT @@ -25,6 +33,7 @@ CREATE_TABLE_TYPE_FOR_SQLITE_TYPE = { sqlite_type: column_type for column_type, sqlite_type in CREATE_TABLE_SQLITE_TYPES.items() } +TABLE_NAME_RE = re.compile(r"^(?!sqlite_)[^\n]+$") ALTER_TABLE_COLUMN_TYPES = CREATE_TABLE_COLUMN_TYPES ALTER_TABLE_TYPE_FOR_SQLITE_TYPE = { SQLiteType.TEXT: "text", @@ -98,6 +107,137 @@ class _StrictPydanticModel(BaseModel): model_config = ConfigDict(extra="forbid") +class CreateTableColumn(BaseModel): + model_config = ConfigDict(extra="forbid") + + name: Any = None + type: Any = "text" + fk_table: str | None = None + fk_column: str | None = None + + @model_validator(mode="after") + def validate_column(self): + if not self.name or not isinstance(self.name, str): + raise PydanticCustomError("create_table", "Column name is required") + if not self.type: + self.type = "text" + elif self.type not in CREATE_TABLE_COLUMN_TYPES: + raise PydanticCustomError( + "create_table", "Unsupported column type: {type}", {"type": self.type} + ) + if self.fk_column and not self.fk_table: + raise PydanticCustomError( + "create_table_with_location", + "fk_column requires fk_table", + ) + return self + + +class CreateTableRequest(_StrictPydanticModel): + table: Any = None + rows: Any = None + row: Any = None + columns: list[CreateTableColumn] | None = None + pk: Any = None + pks: Any = None + ignore: bool | None = None + replace: bool | None = None + alter: bool | None = None + + @field_validator("columns", mode="before") + @classmethod + def validate_columns_list(cls, value): + if value is None: + return value + if not isinstance(value, list): + raise PydanticCustomError("create_table", "columns must be a list") + if not all(isinstance(column, dict) for column in value): + raise PydanticCustomError( + "create_table", "columns must be a list of objects" + ) + return value + + @model_validator(mode="after") + def validate_request(self): + if not self.table: + raise PydanticCustomError("create_table", "Table is required") + if not isinstance(self.table, str) or not TABLE_NAME_RE.match(self.table): + raise PydanticCustomError("create_table", "Invalid table name") + if not self.columns and not self.rows and not self.row: + raise PydanticCustomError( + "create_table", "columns, rows or row is required" + ) + if self.rows and self.row: + raise PydanticCustomError( + "create_table", "Cannot specify both rows and row" + ) + if self.columns and (self.rows or self.row): + raise PydanticCustomError( + "create_table", "Cannot specify columns with rows or row" + ) + if self.columns is not None: + seen = set() + duplicates = [] + for column in self.columns: + if column.name in seen and column.name not in duplicates: + duplicates.append(column.name) + seen.add(column.name) + if duplicates: + raise PydanticCustomError( + "create_table", + "Duplicate column name: {names}", + {"names": ", ".join(duplicates)}, + ) + if self.rows is not None: + if not isinstance(self.rows, list): + raise PydanticCustomError("create_table", "rows must be a list") + if not all(isinstance(row, dict) for row in self.rows): + raise PydanticCustomError( + "create_table", "rows must be a list of objects" + ) + if self.pk is not None and not isinstance(self.pk, str): + raise PydanticCustomError("create_table", "pk must be a string") + if self.pk and self.pks: + raise PydanticCustomError("create_table", "Cannot specify both pk and pks") + if self.pks is not None: + if not isinstance(self.pks, list): + raise PydanticCustomError("create_table", "pks must be a list") + if not all(isinstance(pk, str) for pk in self.pks): + raise PydanticCustomError( + "create_table", "pks must be a list of strings" + ) + if self.ignore and self.replace: + raise PydanticCustomError( + "create_table", "ignore and replace are mutually exclusive" + ) + if {"ignore", "replace"} & self.model_fields_set: + if not self.row and not self.rows: + raise PydanticCustomError( + "create_table", "ignore and replace require row or rows" + ) + if not self.pk and not self.pks: + raise PydanticCustomError( + "create_table", "ignore and replace require pk or pks" + ) + return self + + @property + def rows_list(self): + return [self.row] if self.row else self.rows + + @property + def foreign_keys(self): + if not self.columns: + return None + foreign_keys = [] + for column in self.columns: + if column.fk_table and column.fk_column: + foreign_keys.append((column.name, column.fk_table, column.fk_column)) + elif column.fk_table: + foreign_keys.append((column.name, column.fk_table)) + return foreign_keys or None + + class _DefaultArgsMixin(_StrictPydanticModel): default: Any | None = None default_expr: DefaultExpr | None = None @@ -209,6 +349,27 @@ def _pydantic_errors(validation_error): return errors +def _create_table_pydantic_errors(validation_error): + errors = validation_error.errors() + invalid_keys = sorted( + str(error["loc"][0]) + for error in errors + if error["type"] == "extra_forbidden" and len(error["loc"]) == 1 + ) + if invalid_keys: + return ["Invalid keys: {}".format(", ".join(invalid_keys))] + + output = [] + for error in errors: + message = error["msg"] + if error["type"] == "create_table": + output.append(message) + continue + location = ".".join(str(item) for item in error["loc"]) + output.append("{}: {}".format(location, message) if location else message) + return output + + def _table_schema_from_conn(conn, table_name): row = conn.execute( "select sql from sqlite_master where type = 'table' and name = ?", @@ -236,21 +397,6 @@ def _literal_default(db, value): class TableCreateView(BaseView): name = "table-create" - _valid_keys = { - "table", - "rows", - "row", - "columns", - "pk", - "pks", - "ignore", - "replace", - "alter", - } - _supported_column_types = set(CREATE_TABLE_COLUMN_TYPES) - # Any string that does not contain a newline or start with sqlite_ - _table_name_re = re.compile(r"^(?!sqlite_)[^\n]+$") - def __init__(self, datasette): self.ds = datasette @@ -274,26 +420,13 @@ class TableCreateView(BaseView): if not isinstance(data, dict): return _error(["JSON must be an object"]) - invalid_keys = set(data.keys()) - self._valid_keys - if invalid_keys: - return _error(["Invalid keys: {}".format(", ".join(invalid_keys))]) + try: + create_request = CreateTableRequest.model_validate(data) + except ValidationError as e: + return _error(_create_table_pydantic_errors(e)) - # ignore and replace are mutually exclusive - if data.get("ignore") and data.get("replace"): - return _error(["ignore and replace are mutually exclusive"]) - - # ignore and replace only allowed with row or rows - if "ignore" in data or "replace" in data: - if not data.get("row") and not data.get("rows"): - return _error(["ignore and replace require row or rows"]) - - # ignore and replace require pk or pks - if "ignore" in data or "replace" in data: - if not data.get("pk") and not data.get("pks"): - return _error(["ignore and replace require pk or pks"]) - - ignore = data.get("ignore") - replace = data.get("replace") + ignore = create_request.ignore + replace = create_request.replace if replace: # Must have update-row permission @@ -304,24 +437,12 @@ class TableCreateView(BaseView): ): return _error(["Permission denied: need update-row"], 403) - table_name = data.get("table") - if not table_name: - return _error(["Table is required"]) + table_name = create_request.table + table_exists = await db.table_exists(table_name) + columns = create_request.columns + rows = create_request.rows_list - if not self._table_name_re.match(table_name): - return _error(["Invalid table name"]) - - table_exists = await db.table_exists(data["table"]) - columns = data.get("columns") - rows = data.get("rows") - row = data.get("row") - if not columns and not rows and not row: - return _error(["columns, rows or row is required"]) - - if rows and row: - return _error(["Cannot specify both rows and row"]) - - if rows or row: + if rows: # Must have insert-row permission if not await self.ds.allowed( action="insert-row", @@ -331,13 +452,13 @@ class TableCreateView(BaseView): return _error(["Permission denied: need insert-row"], 403) alter = False - if rows or row: + if rows: if not table_exists: # if table is being created for the first time, alter=True alter = True else: # alter=True only if they request it AND they have permission - if data.get("alter"): + if create_request.alter: if not await self.ds.allowed( action="alter-table", resource=DatabaseResource(database=database_name), @@ -346,64 +467,17 @@ class TableCreateView(BaseView): return _error(["Permission denied: need alter-table"], 403) alter = True - if columns: - if rows or row: - return _error(["Cannot specify columns with rows or row"]) - if not isinstance(columns, list): - return _error(["columns must be a list"]) - for column in columns: - if not isinstance(column, dict): - return _error(["columns must be a list of objects"]) - if not column.get("name") or not isinstance(column.get("name"), str): - return _error(["Column name is required"]) - if not column.get("type"): - column["type"] = "text" - if column["type"] not in self._supported_column_types: - return _error( - ["Unsupported column type: {}".format(column["type"])] - ) - # No duplicate column names - dupes = {c["name"] for c in columns if columns.count(c) > 1} - if dupes: - return _error(["Duplicate column name: {}".format(", ".join(dupes))]) - - if row: - rows = [row] - - if rows: - if not isinstance(rows, list): - return _error(["rows must be a list"]) - for row in rows: - if not isinstance(row, dict): - return _error(["rows must be a list of objects"]) - - pk = data.get("pk") - pks = data.get("pks") - - if pk and pks: - return _error(["Cannot specify both pk and pks"]) - if pk: - if not isinstance(pk, str): - return _error(["pk must be a string"]) - if pks: - if not isinstance(pks, list): - return _error(["pks must be a list"]) - for pk in pks: - if not isinstance(pk, str): - return _error(["pks must be a list of strings"]) + pk = create_request.pk + pks = create_request.pks # If table exists already, read pks from that instead if table_exists: actual_pks = await db.primary_keys(table_name) # if pk passed and table already exists check it does not change bad_pks = False - if len(actual_pks) == 1 and data.get("pk") and data["pk"] != actual_pks[0]: + if len(actual_pks) == 1 and pk and pk != actual_pks[0]: bad_pks = True - elif ( - len(actual_pks) > 1 - and data.get("pks") - and set(data["pks"]) != set(actual_pks) - ): + elif len(actual_pks) > 1 and pks and set(pks) != set(actual_pks): bad_pks = True if bad_pks: return _error(["pk cannot be changed for existing table"]) @@ -423,8 +497,9 @@ class TableCreateView(BaseView): ) else: table.create( - {c["name"]: c["type"] for c in columns}, + {column.name: column.type for column in columns}, pk=pks or pk, + foreign_keys=create_request.foreign_keys, ) return table.schema diff --git a/docs/json_api.rst b/docs/json_api.rst index 4074b479..1b4a196e 100644 --- a/docs/json_api.rst +++ b/docs/json_api.rst @@ -1981,6 +1981,7 @@ The JSON here describes the table that will be created: - ``name`` is the name of the column. This is required. - ``type`` is the type of the column. This is optional - if not provided, ``text`` will be assumed. The valid types are ``text``, ``integer``, ``float`` and ``blob``. + - ``fk_table`` can be used to create a single-column foreign key constraint referencing another table. ``fk_column`` is optional and can be used to specify the referenced column - if omitted, Datasette will use the single primary key of ``fk_table``. * ``pk`` is the primary key for the table. This is optional - if not provided, Datasette will create a SQLite table with a hidden ``rowid`` column. @@ -1993,6 +1994,30 @@ The JSON here describes the table that will be created: * ``replace`` can be set to ``true`` to replace existing rows by primary key if the table already exists. This requires the :ref:`actions_update_row` permission. * ``alter`` can be set to ``true`` if you want to automatically add any missing columns to the table. This requires the :ref:`actions_alter_table` permission. +This example creates a foreign key from ``projects.owner_id`` to the single primary key of ``owners``: + +.. code-block:: json + + { + "table": "projects", + "columns": [ + { + "name": "id", + "type": "integer" + }, + { + "name": "owner_id", + "type": "integer", + "fk_table": "owners" + }, + { + "name": "title", + "type": "text" + } + ], + "pk": "id" + } + If the table is successfully created this will return a ``201`` status code and the following response: .. code-block:: json diff --git a/tests/test_api_write.py b/tests/test_api_write.py index f117c06e..627b1ac1 100644 --- a/tests/test_api_write.py +++ b/tests/test_api_write.py @@ -1614,6 +1614,121 @@ async def test_create_table( assert [e.name for e in events] == expected_events +@pytest.mark.asyncio +async def test_create_table_with_foreign_key(ds_write): + token = write_token(ds_write) + response = await ds_write.client.post( + "/data/-/create", + json={ + "table": "owners", + "columns": [ + {"name": "id", "type": "integer"}, + {"name": "name", "type": "text"}, + ], + "pk": "id", + }, + headers=_headers(token), + ) + assert response.status_code == 201 + + response = await ds_write.client.post( + "/data/-/create", + json={ + "table": "projects", + "columns": [ + {"name": "id", "type": "integer"}, + { + "name": "owner_id", + "type": "integer", + "fk_table": "owners", + }, + {"name": "title", "type": "text"}, + ], + "pk": "id", + }, + headers=_headers(token), + ) + assert response.status_code == 201 + data = response.json() + assert "[owner_id] INTEGER REFERENCES [owners]([id])" in data["schema"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "column,expected_error", + ( + ( + {"name": "owner_id", "type": "integer", "fk_table": "owners"}, + None, + ), + ( + {"name": "owner_id", "type": "integer", "fk_column": "id"}, + "columns.0: fk_column requires fk_table", + ), + ), +) +async def test_create_table_foreign_key_validation(ds_write, column, expected_error): + token = write_token(ds_write) + response = await ds_write.client.post( + "/data/-/create", + json={ + "table": "projects", + "columns": [column], + }, + headers=_headers(token), + ) + if expected_error: + assert response.status_code == 400 + assert response.json() == {"ok": False, "errors": [expected_error]} + else: + assert response.status_code == 400 + assert response.json() == { + "ok": False, + "errors": ["Could not detect single primary key for table 'owners'"], + } + + +@pytest.mark.asyncio +async def test_create_table_foreign_key_without_fk_column_requires_single_pk(ds_write): + token = write_token(ds_write) + response = await ds_write.client.post( + "/data/-/create", + json={ + "table": "accounts", + "columns": [ + {"name": "tenant_id", "type": "integer"}, + {"name": "id", "type": "integer"}, + {"name": "name", "type": "text"}, + ], + "pks": ["tenant_id", "id"], + }, + headers=_headers(token), + ) + assert response.status_code == 201 + + response = await ds_write.client.post( + "/data/-/create", + json={ + "table": "projects", + "columns": [ + {"name": "id", "type": "integer"}, + { + "name": "account_id", + "type": "integer", + "fk_table": "accounts", + }, + ], + "pk": "id", + }, + headers=_headers(token), + ) + assert response.status_code == 400 + assert response.json() == { + "ok": False, + "errors": ["Could not detect single primary key for table 'accounts'"], + } + + @pytest.mark.asyncio @pytest.mark.parametrize( "permissions,body,expected_status,expected_errors",