ignore and replace options for bulk inserts, refs #1873

Also removed the rule that you cannot include primary keys in the rows you insert.

And added validation that catches invalid parameters in the incoming JSON.

And renamed "inserted" to "rows" in the returned JSON for return_rows: true
This commit is contained in:
Simon Willison 2022-11-01 11:07:59 -07:00
commit 9bec7c38eb
3 changed files with 111 additions and 17 deletions

View file

@ -1107,6 +1107,7 @@ class TableInsertView(BaseView):
if not isinstance(data, dict):
return _errors(["JSON must be a dictionary"])
keys = data.keys()
# keys must contain "row" or "rows"
if "row" not in keys and "rows" not in keys:
return _errors(['JSON must have one or other of "row" or "rows"'])
@ -1126,19 +1127,31 @@ class TableInsertView(BaseView):
for row in rows:
if not isinstance(row, dict):
return _errors(['"rows" must be a list of dictionaries'])
# Does this exceed max_insert_rows?
max_insert_rows = self.ds.setting("max_insert_rows")
if len(rows) > max_insert_rows:
return _errors(
["Too many rows, maximum allowed is {}".format(max_insert_rows)]
)
# Validate other parameters
extras = {
key: value for key, value in data.items() if key not in ("row", "rows")
}
valid_extras = {"return_rows", "ignore", "replace"}
invalid_extras = extras.keys() - valid_extras
if invalid_extras:
return _errors(
['Invalid parameter: "{}"'.format('", "'.join(sorted(invalid_extras)))]
)
if extras.get("ignore") and extras.get("replace"):
return _errors(['Cannot use "ignore" and "replace" at the same time'])
# Validate columns of each row
columns = await db.table_columns(table_name)
# TODO: There are cases where pks are OK, if not using auto-incrementing pk
pks = await db.primary_keys(table_name)
allowed_columns = set(columns) - set(pks)
columns = set(await db.table_columns(table_name))
for i, row in enumerate(rows):
invalid_columns = set(row.keys()) - allowed_columns
invalid_columns = set(row.keys()) - columns
if invalid_columns:
errors.append(
"Row {} has invalid columns: {}".format(
@ -1147,8 +1160,7 @@ class TableInsertView(BaseView):
)
if errors:
return _errors(errors)
extra = {key: data[key] for key in data if key not in ("rows", "row")}
return rows, errors, extra
return rows, errors, extras
async def post(self, request):
database_route = tilde_decode(request.url_vars["database"])
@ -1168,18 +1180,23 @@ class TableInsertView(BaseView):
request.actor, "insert-row", resource=(database_name, table_name)
):
return _error(["Permission denied"], 403)
rows, errors, extra = await self._validate_data(request, db, table_name)
rows, errors, extras = await self._validate_data(request, db, table_name)
if errors:
return _error(errors, 400)
should_return = bool(extra.get("return_rows", False))
ignore = extras.get("ignore")
replace = extras.get("replace")
should_return = bool(extras.get("return_rows", False))
# Insert rows
def insert_rows(conn):
table = sqlite_utils.Database(conn)[table_name]
if should_return:
rowids = []
for row in rows:
rowids.append(table.insert(row).last_rowid)
rowids.append(
table.insert(row, ignore=ignore, replace=replace).last_rowid
)
return list(
table.rows_where(
"rowid in ({})".format(",".join("?" for _ in rowids)),
@ -1187,12 +1204,12 @@ class TableInsertView(BaseView):
)
)
else:
table.insert_all(rows)
table.insert_all(rows, ignore=ignore, replace=replace)
rows = await db.execute_write_fn(insert_rows)
result = {"ok": True}
if should_return:
result["inserted"] = rows
result["rows"] = rows
return Response.json(result, status=201)