diff --git a/datasette/views/table.py b/datasette/views/table.py index 1e3d566e..7692a4e3 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -1107,6 +1107,7 @@ class TableInsertView(BaseView): if not isinstance(data, dict): return _errors(["JSON must be a dictionary"]) keys = data.keys() + # keys must contain "row" or "rows" if "row" not in keys and "rows" not in keys: return _errors(['JSON must have one or other of "row" or "rows"']) @@ -1126,19 +1127,31 @@ class TableInsertView(BaseView): for row in rows: if not isinstance(row, dict): return _errors(['"rows" must be a list of dictionaries']) + # Does this exceed max_insert_rows? max_insert_rows = self.ds.setting("max_insert_rows") if len(rows) > max_insert_rows: return _errors( ["Too many rows, maximum allowed is {}".format(max_insert_rows)] ) + + # Validate other parameters + extras = { + key: value for key, value in data.items() if key not in ("row", "rows") + } + valid_extras = {"return_rows", "ignore", "replace"} + invalid_extras = extras.keys() - valid_extras + if invalid_extras: + return _errors( + ['Invalid parameter: "{}"'.format('", "'.join(sorted(invalid_extras)))] + ) + if extras.get("ignore") and extras.get("replace"): + return _errors(['Cannot use "ignore" and "replace" at the same time']) + # Validate columns of each row - columns = await db.table_columns(table_name) - # TODO: There are cases where pks are OK, if not using auto-incrementing pk - pks = await db.primary_keys(table_name) - allowed_columns = set(columns) - set(pks) + columns = set(await db.table_columns(table_name)) for i, row in enumerate(rows): - invalid_columns = set(row.keys()) - allowed_columns + invalid_columns = set(row.keys()) - columns if invalid_columns: errors.append( "Row {} has invalid columns: {}".format( @@ -1147,8 +1160,7 @@ class TableInsertView(BaseView): ) if errors: return _errors(errors) - extra = {key: data[key] for key in data if key not in ("rows", "row")} - return rows, errors, extra + return rows, errors, extras async def post(self, request): database_route = tilde_decode(request.url_vars["database"]) @@ -1168,18 +1180,23 @@ class TableInsertView(BaseView): request.actor, "insert-row", resource=(database_name, table_name) ): return _error(["Permission denied"], 403) - rows, errors, extra = await self._validate_data(request, db, table_name) + rows, errors, extras = await self._validate_data(request, db, table_name) if errors: return _error(errors, 400) - should_return = bool(extra.get("return_rows", False)) + ignore = extras.get("ignore") + replace = extras.get("replace") + + should_return = bool(extras.get("return_rows", False)) # Insert rows def insert_rows(conn): table = sqlite_utils.Database(conn)[table_name] if should_return: rowids = [] for row in rows: - rowids.append(table.insert(row).last_rowid) + rowids.append( + table.insert(row, ignore=ignore, replace=replace).last_rowid + ) return list( table.rows_where( "rowid in ({})".format(",".join("?" for _ in rowids)), @@ -1187,12 +1204,12 @@ class TableInsertView(BaseView): ) ) else: - table.insert_all(rows) + table.insert_all(rows, ignore=ignore, replace=replace) rows = await db.execute_write_fn(insert_rows) result = {"ok": True} if should_return: - result["inserted"] = rows + result["rows"] = rows return Response.json(result, status=201) diff --git a/docs/json_api.rst b/docs/json_api.rst index da4500ab..34c13211 100644 --- a/docs/json_api.rst +++ b/docs/json_api.rst @@ -489,7 +489,7 @@ If successful, this will return a ``201`` status code and the newly inserted row .. code-block:: json { - "inserted": [ + "rows": [ { "id": 1, "column1": "value1", @@ -538,7 +538,7 @@ To return the newly inserted rows, add the ``"return_rows": true`` key to the re "return_rows": true } -This will return the same ``"inserted"`` key as the single row example above. There is a small performance penalty for using this option. +This will return the same ``"rows"`` key as the single row example above. There is a small performance penalty for using this option. .. _RowDeleteView: diff --git a/tests/test_api_write.py b/tests/test_api_write.py index 1cfba104..d0b0f324 100644 --- a/tests/test_api_write.py +++ b/tests/test_api_write.py @@ -37,7 +37,7 @@ async def test_write_row(ds_write): ) expected_row = {"id": 1, "title": "Test", "score": 1.0} assert response.status_code == 201 - assert response.json()["inserted"] == [expected_row] + assert response.json()["rows"] == [expected_row] rows = (await ds_write.get_database("data").execute("select * from docs")).rows assert dict(rows[0]) == expected_row @@ -70,7 +70,7 @@ async def test_write_rows(ds_write, return_rows): ] assert response.json()["ok"] is True if return_rows: - assert response.json()["inserted"] == actual_rows + assert response.json()["rows"] == actual_rows @pytest.mark.asyncio @@ -156,6 +156,27 @@ async def test_write_rows(ds_write, return_rows): 400, ["Too many rows, maximum allowed is 100"], ), + ( + "/data/docs/-/insert", + {"rows": [{"title": "Test"}], "ignore": True, "replace": True}, + None, + 400, + ['Cannot use "ignore" and "replace" at the same time'], + ), + ( + "/data/docs/-/insert", + {"rows": [{"title": "Test"}], "invalid_param": True}, + None, + 400, + ['Invalid parameter: "invalid_param"'], + ), + ( + "/data/docs/-/insert", + {"rows": [{"title": "Test"}], "one": True, "two": True}, + None, + 400, + ['Invalid parameter: "one", "two"'], + ), # Validate columns of each row ( "/data/docs/-/insert", @@ -196,6 +217,62 @@ async def test_write_row_errors( assert response.json()["errors"] == expected_errors +@pytest.mark.asyncio +@pytest.mark.parametrize( + "ignore,replace,expected_rows", + ( + ( + True, + False, + [ + {"id": 1, "title": "Exists", "score": None}, + ], + ), + ( + False, + True, + [ + {"id": 1, "title": "One", "score": None}, + ], + ), + ), +) +@pytest.mark.parametrize("should_return", (True, False)) +async def test_insert_ignore_replace( + ds_write, ignore, replace, expected_rows, should_return +): + await ds_write.get_database("data").execute_write( + "insert into docs (id, title) values (1, 'Exists')" + ) + token = write_token(ds_write) + data = {"rows": [{"id": 1, "title": "One"}]} + if ignore: + data["ignore"] = True + if replace: + data["replace"] = True + if should_return: + data["return_rows"] = True + response = await ds_write.client.post( + "/data/docs/-/insert", + json=data, + headers={ + "Authorization": "Bearer {}".format(token), + "Content-Type": "application/json", + }, + ) + assert response.status_code == 201 + actual_rows = [ + dict(r) + for r in ( + await ds_write.get_database("data").execute("select * from docs") + ).rows + ] + assert actual_rows == expected_rows + assert response.json()["ok"] is True + if should_return: + assert response.json()["rows"] == expected_rows + + @pytest.mark.asyncio @pytest.mark.parametrize("scenario", ("no_token", "no_perm", "bad_table", "has_perm")) async def test_delete_row(ds_write, scenario): @@ -217,7 +294,7 @@ async def test_delete_row(ds_write, scenario): }, ) assert insert_response.status_code == 201 - pk = insert_response.json()["inserted"][0]["id"] + pk = insert_response.json()["rows"][0]["id"] path = "/data/{}/{}/-/delete".format( "docs" if scenario != "bad_table" else "bad_table", pk