From 93ababe6f7150454d2cf278dae08569e505d2a5b Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Fri, 2 Dec 2022 22:57:57 -0800 Subject: [PATCH] Initial attempt at insert/replace for /-/create, refs #1927 --- datasette/views/database.py | 139 ++++-------------------------------- tests/test_api_write.py | 96 +++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 124 deletions(-) diff --git a/datasette/views/database.py b/datasette/views/database.py index 0d03d1f9..f5f329d8 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -560,7 +560,7 @@ class MagicParameters(dict): class TableCreateView(BaseView): name = "table-create" - _valid_keys = {"table", "rows", "row", "columns", "pk"} + _valid_keys = {"table", "rows", "row", "columns", "pk", "pks", "ignore", "replace"} _supported_column_types = { "text", "integer", @@ -596,130 +596,17 @@ class TableCreateView(BaseView): if invalid_keys: return _error(["Invalid keys: {}".format(", ".join(invalid_keys))]) - table_name = data.get("table") - if not table_name: - return _error(["Table is required"]) + # ignore and replace are mutually exclusive + if data.get("ignore") and data.get("replace"): + return _error(["ignore and replace are mutually exclusive"]) - if not self._table_name_re.match(table_name): - return _error(["Invalid table name"]) + # ignore and replace only allowed with row or rows + if "ignore" in data or "replace" in data: + if not data.get("row") and not data.get("rows"): + return _error(["ignore and replace require row or rows"]) - columns = data.get("columns") - rows = data.get("rows") - row = data.get("row") - if not columns and not rows and not row: - return _error(["columns, rows or row is required"]) - - if rows and row: - return _error(["Cannot specify both rows and row"]) - - if columns: - if rows or row: - return _error(["Cannot specify columns with rows or row"]) - if not isinstance(columns, list): - return _error(["columns must be a list"]) - for column in columns: - if not isinstance(column, dict): - return _error(["columns must be a list of objects"]) - if not column.get("name") or not isinstance(column.get("name"), str): - return _error(["Column name is required"]) - if not column.get("type"): - column["type"] = "text" - if column["type"] not in self._supported_column_types: - return _error( - ["Unsupported column type: {}".format(column["type"])] - ) - # No duplicate column names - dupes = {c["name"] for c in columns if columns.count(c) > 1} - if dupes: - return _error(["Duplicate column name: {}".format(", ".join(dupes))]) - - if row: - rows = [row] - - if rows: - if not isinstance(rows, list): - return _error(["rows must be a list"]) - for row in rows: - if not isinstance(row, dict): - return _error(["rows must be a list of objects"]) - - pk = data.get("pk") - if pk: - if not isinstance(pk, str): - return _error(["pk must be a string"]) - - def create_table(conn): - table = sqlite_utils.Database(conn)[table_name] - if rows: - table.insert_all(rows, pk=pk) - else: - table.create( - {c["name"]: c["type"] for c in columns}, - pk=pk, - ) - return table.schema - - try: - schema = await db.execute_write_fn(create_table) - except Exception as e: - return _error([str(e)]) - table_url = self.ds.absolute_url( - request, self.ds.urls.table(db.name, table_name) - ) - table_api_url = self.ds.absolute_url( - request, self.ds.urls.table(db.name, table_name, format="json") - ) - details = { - "ok": True, - "database": db.name, - "table": table_name, - "table_url": table_url, - "table_api_url": table_api_url, - "schema": schema, - } - if rows: - details["row_count"] = len(rows) - return Response.json(details, status=201) - - -class TableCreateView(BaseView): - name = "table-create" - - _valid_keys = {"table", "rows", "row", "columns", "pk", "pks"} - _supported_column_types = { - "text", - "integer", - "float", - "blob", - } - # Any string that does not contain a newline or start with sqlite_ - _table_name_re = re.compile(r"^(?!sqlite_)[^\n]+$") - - def __init__(self, datasette): - self.ds = datasette - - async def post(self, request): - db = await self.ds.resolve_database(request) - database_name = db.name - - # Must have create-table permission - if not await self.ds.permission_allowed( - request.actor, "create-table", resource=database_name - ): - return _error(["Permission denied"], 403) - - body = await request.post_body() - try: - data = json.loads(body) - except json.JSONDecodeError as e: - return _error(["Invalid JSON: {}".format(e)]) - - if not isinstance(data, dict): - return _error(["JSON must be an object"]) - - invalid_keys = set(data.keys()) - self._valid_keys - if invalid_keys: - return _error(["Invalid keys: {}".format(", ".join(invalid_keys))]) + ignore = data.get("ignore") + replace = data.get("replace") table_name = data.get("table") if not table_name: @@ -783,10 +670,14 @@ class TableCreateView(BaseView): if not isinstance(pk, str): return _error(["pks must be a list of strings"]) + # If table exists already, read pks from that instead + if await db.table_exists(table_name): + pks = await db.primary_keys(table_name) + def create_table(conn): table = sqlite_utils.Database(conn)[table_name] if rows: - table.insert_all(rows, pk=pks or pk) + table.insert_all(rows, pk=pks or pk, ignore=ignore, replace=replace) else: table.create( {c["name"]: c["type"] for c in columns}, diff --git a/tests/test_api_write.py b/tests/test_api_write.py index b3b1def2..cfcf9db0 100644 --- a/tests/test_api_write.py +++ b/tests/test_api_write.py @@ -911,6 +911,34 @@ async def test_drop_table(ds_write, scenario): 400, {"ok": False, "errors": ["pks must be a list of strings"]}, ), + # Error: ignore and replace are mutually exclusive + ( + { + "table": "bad", + "row": {"id": 1, "name": "Row 1"}, + "pk": "id", + "ignore": True, + "replace": True, + }, + 400, + { + "ok": False, + "errors": ["ignore and replace are mutually exclusive"], + }, + ), + # ignore and replace require row or rows + ( + { + "table": "bad", + "columns": [{"name": "id", "type": "integer"}], + "ignore": True, + }, + 400, + { + "ok": False, + "errors": ["ignore and replace require row or rows"], + }, + ), ), ) async def test_create_table(ds_write, input, expected_status, expected_response): @@ -932,6 +960,74 @@ async def test_create_table(ds_write, input, expected_status, expected_response) assert data == expected_response +@pytest.mark.asyncio +@pytest.mark.parametrize( + "input,expected_rows_after", + ( + ( + { + "table": "test_insert_replace", + "rows": [ + {"id": 1, "name": "Row 1 new"}, + {"id": 3, "name": "Row 3 new"}, + ], + "ignore": True, + }, + [ + {"id": 1, "name": "Row 1"}, + {"id": 2, "name": "Row 2"}, + {"id": 3, "name": "Row 3 new"}, + ], + ), + ( + { + "table": "test_insert_replace", + "rows": [ + {"id": 1, "name": "Row 1 new"}, + {"id": 3, "name": "Row 3 new"}, + ], + "replace": True, + }, + [ + {"id": 1, "name": "Row 1 new"}, + {"id": 2, "name": "Row 2"}, + {"id": 3, "name": "Row 3 new"}, + ], + ), + ), +) +async def test_create_table_ignore_replace(ds_write, input, expected_rows_after): + # Create table with two rows + token = write_token(ds_write) + first_response = await ds_write.client.post( + "/data/-/create", + json={ + "rows": [{"id": 1, "name": "Row 1"}, {"id": 2, "name": "Row 2"}], + "table": "test_insert_replace", + "pk": "id", + }, + headers={ + "Authorization": "Bearer {}".format(token), + "Content-Type": "application/json", + }, + ) + assert first_response.status_code == 201 + + # Try a second time + second_response = await ds_write.client.post( + "/data/-/create", + json=input, + headers={ + "Authorization": "Bearer {}".format(token), + "Content-Type": "application/json", + }, + ) + assert second_response.status_code == 201 + # Check that the rows are as expected + rows = await ds_write.client.get("/data/test_insert_replace.json?_shape=array") + assert rows.json() == expected_rows_after + + @pytest.mark.asyncio @pytest.mark.parametrize( "path",