API for bulk inserts, closes #1866

This commit is contained in:
Simon Willison 2022-10-29 23:03:45 -07:00
commit c35859ae3d
7 changed files with 320 additions and 51 deletions

View file

@ -99,6 +99,11 @@ SETTINGS = (
1000,
"Maximum rows that can be returned from a table or custom query",
),
Setting(
"max_insert_rows",
100,
"Maximum rows that can be inserted at a time using the bulk insert API",
),
Setting(
"num_sql_threads",
3,

View file

@ -30,6 +30,7 @@ from datasette.utils import (
)
from datasette.utils.asgi import BadRequest, Forbidden, NotFound, Response
from datasette.filters import Filters
import sqlite_utils
from .base import BaseView, DataView, DatasetteError, ureg
from .database import QueryView
@ -1085,62 +1086,109 @@ class TableInsertView(BaseView):
def __init__(self, datasette):
    """Keep a reference to the Datasette instance as ``self.ds``."""
    self.ds = datasette
async def _validate_data(self, request, db, table_name):
    """Parse and validate the JSON body of an insert request.

    The body must be a JSON object containing either ``"row"`` (a single
    dictionary) or ``"rows"`` (a list of dictionaries), but not both.
    Sending a single ``"row"`` forces ``"return_rows"`` to ``True``.

    Returns a ``(rows, errors, extra)`` tuple:

    - on success ``rows`` is the list of row dictionaries to insert,
      ``errors`` is an empty list and ``extra`` holds every top-level key
      of the body other than ``"row"``/``"rows"``;
    - on failure ``rows`` is ``None``, ``errors`` is a non-empty list of
      messages and ``extra`` is an empty dictionary.
    """
    errors = []

    def _errors(errors):
        return None, errors, {}

    # Compare only the media type, case-insensitively, so that headers such
    # as "application/json; charset=utf-8" are accepted.
    content_type = request.headers.get("content-type") or ""
    media_type = content_type.split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        # TODO: handle form-encoded data
        return _errors(["Invalid content-type, must be application/json"])

    body = await request.post_body()
    try:
        data = json.loads(body)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # json.loads() raises UnicodeDecodeError (not JSONDecodeError) when
        # a bytes body cannot be decoded, so both are reported the same way.
        return _errors(["Invalid JSON: {}".format(e)])
    if not isinstance(data, dict):
        return _errors(["JSON must be a dictionary"])

    # The body must contain exactly one of "row" or "rows"
    has_row = "row" in data
    has_rows = "rows" in data
    if not has_row and not has_rows:
        return _errors(['JSON must have one or other of "row" or "rows"'])
    if has_row:
        if has_rows:
            return _errors(['Cannot use "row" and "rows" at the same time'])
        row = data["row"]
        if not isinstance(row, dict):
            return _errors(['"row" must be a dictionary'])
        rows = [row]
        # A single-row insert always returns the inserted row
        data["return_rows"] = True
    else:
        rows = data["rows"]
        if not isinstance(rows, list):
            return _errors(['"rows" must be a list'])
        if not all(isinstance(row, dict) for row in rows):
            return _errors(['"rows" must be a list of dictionaries'])

    # Does this exceed max_insert_rows?
    max_insert_rows = self.ds.setting("max_insert_rows")
    if len(rows) > max_insert_rows:
        return _errors(
            ["Too many rows, maximum allowed is {}".format(max_insert_rows)]
        )

    # Validate columns of each row
    columns = await db.table_columns(table_name)
    # TODO: There are cases where pks are OK, if not using auto-incrementing pk
    pks = await db.primary_keys(table_name)
    allowed_columns = set(columns) - set(pks)
    for i, row in enumerate(rows):
        invalid_columns = set(row) - allowed_columns
        if invalid_columns:
            errors.append(
                "Row {} has invalid columns: {}".format(
                    i, ", ".join(sorted(invalid_columns))
                )
            )
    if errors:
        return _errors(errors)

    extra = {key: value for key, value in data.items() if key not in ("rows", "row")}
    return rows, errors, extra
async def post(self, request):
    """Insert one or more rows into an existing table.

    The database and table are taken from the URL variables. The request
    body is checked by ``self._validate_data()``.

    Responses are JSON:

    - 404 ``{"ok": false, "errors": [...]}`` if the database or table
      does not exist;
    - 403 if the actor lacks the ``insert-row`` permission;
    - 400 with the validation messages if the body is invalid;
    - 201 ``{"ok": true}`` on success, with an ``"inserted"`` list of the
      new rows when ``return_rows`` is truthy.
    """

    def _error(messages, status=400):
        return Response.json({"ok": False, "errors": messages}, status=status)

    database_route = tilde_decode(request.url_vars["database"])
    try:
        db = self.ds.get_database(route=database_route)
    except KeyError:
        return _error(["Database not found: {}".format(database_route)], 404)
    database_name = db.name
    table_name = tilde_decode(request.url_vars["table"])

    # Table must exist (may handle table creation in the future)
    if not await db.table_exists(table_name):
        return _error(["Table not found: {}".format(table_name)], 404)

    # Must have insert-row permission
    if not await self.ds.permission_allowed(
        request.actor, "insert-row", resource=(database_name, table_name)
    ):
        return _error(["Permission denied"], 403)

    rows, errors, extra = await self._validate_data(request, db, table_name)
    if errors:
        return _error(errors, 400)

    should_return = bool(extra.get("return_rows", False))

    def insert_rows(conn):
        # Handed to db.execute_write_fn(); presumably called with a SQLite
        # connection on the write thread (confirm against the Database API).
        table = sqlite_utils.Database(conn)[table_name]
        if not should_return:
            table.insert_all(rows)
            return None
        # Insert one at a time so each new rowid can be captured, then
        # read the stored rows back to return them to the client.
        rowids = [table.insert(row).last_rowid for row in rows]
        return list(
            table.rows_where(
                "rowid in ({})".format(",".join("?" for _ in rowids)), rowids
            )
        )

    # Use a separate name: insert_rows() reads ``rows`` from this scope
    inserted = await db.execute_write_fn(insert_rows)
    result = {"ok": True}
    if should_return:
        result["inserted"] = inserted
    return Response.json(result, status=201)