mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
API for bulk inserts, closes #1866
This commit is contained in:
parent
c9b5f5d598
commit
c35859ae3d
7 changed files with 320 additions and 51 deletions
|
|
@ -99,6 +99,11 @@ SETTINGS = (
|
|||
1000,
|
||||
"Maximum rows that can be returned from a table or custom query",
|
||||
),
|
||||
Setting(
|
||||
"max_insert_rows",
|
||||
100,
|
||||
"Maximum rows that can be inserted at a time using the bulk insert API",
|
||||
),
|
||||
Setting(
|
||||
"num_sql_threads",
|
||||
3,
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from datasette.utils import (
|
|||
)
|
||||
from datasette.utils.asgi import BadRequest, Forbidden, NotFound, Response
|
||||
from datasette.filters import Filters
|
||||
import sqlite_utils
|
||||
from .base import BaseView, DataView, DatasetteError, ureg
|
||||
from .database import QueryView
|
||||
|
||||
|
|
@ -1085,62 +1086,109 @@ class TableInsertView(BaseView):
|
|||
def __init__(self, datasette):
    "Store a reference to the Datasette application object."
    self.ds = datasette
|
||||
|
||||
async def _validate_data(self, request, db, table_name):
    """Validate the JSON body of an insert request.

    Returns a ``(rows, errors, extra)`` tuple:

    - ``rows``: list of row dictionaries to insert, or ``None`` on error
    - ``errors``: list of validation error message strings
    - ``extra``: any additional keys from the request body, e.g. ``return_rows``
    """
    errors = []

    def _errors(errors):
        return None, errors, {}

    # Fix: accept "application/json" with optional parameters such as
    # "; charset=utf-8" — an exact equality check rejected those requests.
    content_type = request.headers.get("content-type", "")
    if content_type.split(";")[0].strip() != "application/json":
        # TODO: handle form-encoded data
        return _errors(["Invalid content-type, must be application/json"])
    body = await request.post_body()
    try:
        data = json.loads(body)
    except json.JSONDecodeError as e:
        return _errors(["Invalid JSON: {}".format(e)])
    if not isinstance(data, dict):
        return _errors(["JSON must be a dictionary"])
    keys = data.keys()
    # keys must contain "row" or "rows"
    if "row" not in keys and "rows" not in keys:
        return _errors(['JSON must have one or other of "row" or "rows"'])
    rows = []
    if "row" in keys:
        if "rows" in keys:
            return _errors(['Cannot use "row" and "rows" at the same time'])
        row = data["row"]
        if not isinstance(row, dict):
            return _errors(['"row" must be a dictionary'])
        rows = [row]
        # A single-row insert always returns the newly inserted row
        data["return_rows"] = True
    else:
        rows = data["rows"]
    if not isinstance(rows, list):
        return _errors(['"rows" must be a list'])
    for row in rows:
        if not isinstance(row, dict):
            return _errors(['"rows" must be a list of dictionaries'])
    # Does this exceed max_insert_rows?
    max_insert_rows = self.ds.setting("max_insert_rows")
    if len(rows) > max_insert_rows:
        return _errors(
            ["Too many rows, maximum allowed is {}".format(max_insert_rows)]
        )
    # Validate columns of each row
    columns = await db.table_columns(table_name)
    # TODO: There are cases where pks are OK, if not using auto-incrementing pk
    pks = await db.primary_keys(table_name)
    allowed_columns = set(columns) - set(pks)
    for i, row in enumerate(rows):
        invalid_columns = set(row.keys()) - allowed_columns
        if invalid_columns:
            errors.append(
                "Row {} has invalid columns: {}".format(
                    i, ", ".join(sorted(invalid_columns))
                )
            )
    if errors:
        return _errors(errors)
    extra = {key: data[key] for key in data if key not in ("rows", "row")}
    return rows, errors, extra
|
||||
|
||||
async def post(self, request):
    """Insert one or more rows into a table via the write API.

    Accepts a JSON body with either a ``"row"`` dictionary or a ``"rows"``
    list of dictionaries (validated by ``_validate_data``). Errors are
    returned as JSON ``{"ok": False, "errors": [...]}`` with an appropriate
    HTTP status; success returns 201.

    NOTE(review): this span of the diff contained both the removed
    raise-based single-row implementation and the new bulk implementation
    fused together; this is the coherent post-commit version.
    """

    def _error(messages, status=400):
        return Response.json({"ok": False, "errors": messages}, status=status)

    database_route = tilde_decode(request.url_vars["database"])
    try:
        db = self.ds.get_database(route=database_route)
    except KeyError:
        return _error(["Database not found: {}".format(database_route)], 404)
    database_name = db.name
    table_name = tilde_decode(request.url_vars["table"])

    # Table must exist (may handle table creation in the future)
    db = self.ds.get_database(database_name)
    if not await db.table_exists(table_name):
        return _error(["Table not found: {}".format(table_name)], 404)
    # Must have insert-row permission
    if not await self.ds.permission_allowed(
        request.actor, "insert-row", resource=(database_name, table_name)
    ):
        return _error(["Permission denied"], 403)
    rows, errors, extra = await self._validate_data(request, db, table_name)
    if errors:
        return _error(errors, 400)

    should_return = bool(extra.get("return_rows", False))

    # Insert rows
    def insert_rows(conn):
        table = sqlite_utils.Database(conn)[table_name]
        if should_return:
            rowids = []
            for row in rows:
                rowids.append(table.insert(row).last_rowid)
            # Fetch the newly inserted rows so they can be returned
            return list(
                table.rows_where(
                    "rowid in ({})".format(",".join("?" for _ in rowids)), rowids
                )
            )
        else:
            table.insert_all(rows)

    # Renamed from "rows" to avoid shadowing the validated input rows
    inserted_rows = await db.execute_write_fn(insert_rows)
    result = {"ok": True}
    if should_return:
        result["inserted"] = inserted_rows
    return Response.json(result, status=201)
|
||||
|
|
|
|||
|
|
@ -213,6 +213,8 @@ These can be passed to ``datasette serve`` using ``datasette serve --setting nam
|
|||
(default=100)
|
||||
max_returned_rows Maximum rows that can be returned from a table or
|
||||
custom query (default=1000)
|
||||
max_insert_rows Maximum rows that can be inserted at a time using
|
||||
the bulk insert API (default=100)
|
||||
num_sql_threads Number of threads in the thread pool for
|
||||
executing SQLite queries (default=3)
|
||||
sql_time_limit_ms Time limit for a SQL query in milliseconds
|
||||
|
|
|
|||
|
|
@ -465,11 +465,13 @@ Datasette provides a write API for JSON data. This is a POST-only API that requi
|
|||
|
||||
.. _TableInsertView:
|
||||
|
||||
Inserting a single row
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
Inserting rows
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
This requires the :ref:`permissions_insert_row` permission.
|
||||
|
||||
A single row can be inserted using the ``"row"`` key:
|
||||
|
||||
::
|
||||
|
||||
POST /<database>/<table>/-/insert
|
||||
|
|
@ -495,3 +497,45 @@ If successful, this will return a ``201`` status code and the newly inserted row
|
|||
}
|
||||
]
|
||||
}
|
||||
|
||||
To insert multiple rows at a time, use the same API method but send a list of dictionaries as the ``"rows"`` key:
|
||||
|
||||
::
|
||||
|
||||
POST /<database>/<table>/-/insert
|
||||
Content-Type: application/json
|
||||
Authorization: Bearer dstok_<rest-of-token>
|
||||
{
|
||||
"rows": [
|
||||
{
|
||||
"column1": "value1",
|
||||
"column2": "value2"
|
||||
},
|
||||
{
|
||||
"column1": "value3",
|
||||
"column2": "value4"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
If successful, this will return a ``201`` status code and a ``{"ok": true}`` response body.
|
||||
|
||||
To return the newly inserted rows, add the ``"return_rows": true`` key to the request body:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"rows": [
|
||||
{
|
||||
"column1": "value1",
|
||||
"column2": "value2"
|
||||
},
|
||||
{
|
||||
"column1": "value3",
|
||||
"column2": "value4"
|
||||
}
|
||||
],
|
||||
"return_rows": true
|
||||
}
|
||||
|
||||
This will return the same ``"inserted"`` key as the single row example above. There is a small performance penalty for using this option.
|
||||
|
|
|
|||
|
|
@ -96,6 +96,17 @@ You can increase or decrease this limit like so::
|
|||
|
||||
datasette mydatabase.db --setting max_returned_rows 2000
|
||||
|
||||
.. _setting_max_insert_rows:
|
||||
|
||||
max_insert_rows
|
||||
~~~~~~~~~~~~~~~
|
||||
|
||||
Maximum rows that can be inserted at a time using the bulk insert API, see :ref:`TableInsertView`. Defaults to 100.
|
||||
|
||||
You can increase or decrease this limit like so::
|
||||
|
||||
datasette mydatabase.db --setting max_insert_rows 1000
|
||||
|
||||
.. _setting_num_sql_threads:
|
||||
|
||||
num_sql_threads
|
||||
|
|
|
|||
|
|
@ -804,6 +804,7 @@ def test_settings_json(app_client):
|
|||
"facet_suggest_time_limit_ms": 50,
|
||||
"facet_time_limit_ms": 200,
|
||||
"max_returned_rows": 100,
|
||||
"max_insert_rows": 100,
|
||||
"sql_time_limit_ms": 200,
|
||||
"allow_download": True,
|
||||
"allow_signed_tokens": True,
|
||||
|
|
|
|||
|
|
@ -18,11 +18,7 @@ def ds_write(tmp_path_factory):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_row(ds_write):
|
||||
token = "dstok_{}".format(
|
||||
ds_write.sign(
|
||||
{"a": "root", "token": "dstok", "t": int(time.time())}, namespace="token"
|
||||
)
|
||||
)
|
||||
token = write_token(ds_write)
|
||||
response = await ds_write.client.post(
|
||||
"/data/docs/-/insert",
|
||||
json={"row": {"title": "Test", "score": 1.0}},
|
||||
|
|
@ -36,3 +32,165 @@ async def test_write_row(ds_write):
|
|||
assert response.json()["inserted"] == [expected_row]
|
||||
rows = (await ds_write.get_database("data").execute("select * from docs")).rows
|
||||
assert dict(rows[0]) == expected_row
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("return_rows", (True, False))
async def test_write_rows(ds_write, return_rows):
    "Bulk-insert 20 rows, optionally asking for the inserted rows back."
    token = write_token(ds_write)
    payload = {
        "rows": [{"title": "Test {}".format(i), "score": 1.0} for i in range(20)]
    }
    if return_rows:
        payload["return_rows"] = True
    response = await ds_write.client.post(
        "/data/docs/-/insert",
        json=payload,
        headers={
            "Authorization": "Bearer {}".format(token),
            "Content-Type": "application/json",
        },
    )
    assert response.status_code == 201
    db = ds_write.get_database("data")
    actual_rows = [dict(r) for r in (await db.execute("select * from docs")).rows]
    expected_rows = [
        {"id": i + 1, "title": "Test {}".format(i), "score": 1.0} for i in range(20)
    ]
    assert len(actual_rows) == 20
    assert actual_rows == expected_rows
    body = response.json()
    assert body["ok"] is True
    if return_rows:
        assert body["inserted"] == actual_rows
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "path,input,special_case,expected_status,expected_errors",
    (
        (
            "/data2/docs/-/insert",
            {},
            None,
            404,
            ["Database not found: data2"],
        ),
        (
            "/data/docs2/-/insert",
            {},
            None,
            404,
            ["Table not found: docs2"],
        ),
        (
            "/data/docs/-/insert",
            {"rows": [{"title": "Test"} for i in range(10)]},
            "bad_token",
            403,
            ["Permission denied"],
        ),
        (
            "/data/docs/-/insert",
            {},
            "invalid_json",
            400,
            [
                "Invalid JSON: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)"
            ],
        ),
        (
            "/data/docs/-/insert",
            {},
            "invalid_content_type",
            400,
            ["Invalid content-type, must be application/json"],
        ),
        (
            "/data/docs/-/insert",
            [],
            None,
            400,
            ["JSON must be a dictionary"],
        ),
        (
            "/data/docs/-/insert",
            {"row": "blah"},
            None,
            400,
            ['"row" must be a dictionary'],
        ),
        (
            "/data/docs/-/insert",
            {"blah": "blah"},
            None,
            400,
            ['JSON must have one or other of "row" or "rows"'],
        ),
        (
            "/data/docs/-/insert",
            {"rows": "blah"},
            None,
            400,
            ['"rows" must be a list'],
        ),
        (
            "/data/docs/-/insert",
            {"rows": ["blah"]},
            None,
            400,
            ['"rows" must be a list of dictionaries'],
        ),
        (
            "/data/docs/-/insert",
            {"rows": [{"title": "Test"} for i in range(101)]},
            None,
            400,
            ["Too many rows, maximum allowed is 100"],
        ),
        # Validate columns of each row
        (
            "/data/docs/-/insert",
            {"rows": [{"title": "Test", "bad": 1, "worse": 2} for i in range(2)]},
            None,
            400,
            [
                "Row 0 has invalid columns: bad, worse",
                "Row 1 has invalid columns: bad, worse",
            ],
        ),
    ),
)
async def test_write_row_errors(
    ds_write, path, input, special_case, expected_status, expected_errors
):
    "Each invalid insert request returns the expected status and error list."
    token = write_token(ds_write)
    if special_case == "bad_token":
        token += "bad"
    headers = {
        "Authorization": "Bearer {}".format(token),
        "Content-Type": "text/plain"
        if special_case == "invalid_content_type"
        else "application/json",
    }
    kwargs = {"json": input, "headers": headers}
    if special_case == "invalid_json":
        # Send a raw, unparseable body instead of JSON
        kwargs.pop("json")
        kwargs["content"] = "{bad json"
    response = await ds_write.client.post(path, **kwargs)
    assert response.status_code == expected_status
    body = response.json()
    assert body["ok"] is False
    assert body["errors"] == expected_errors
|
||||
|
||||
|
||||
def write_token(ds):
    "Return a signed API write token for the root actor."
    payload = {"a": "root", "token": "dstok", "t": int(time.time())}
    return "dstok_{}".format(ds.sign(payload, namespace="token"))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue