mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
/db/table/-/upsert API
Close #1878 Also made a few tweaks to how _r works in tokens and actors, refs #1855 - I needed that mechanism for the tests.
This commit is contained in:
parent
93ababe6f7
commit
272982e8a6
6 changed files with 401 additions and 44 deletions
|
|
@ -40,7 +40,7 @@ from .views.special import (
|
|||
PermissionsDebugView,
|
||||
MessagesDebugView,
|
||||
)
|
||||
from .views.table import TableView, TableInsertView, TableDropView
|
||||
from .views.table import TableView, TableInsertView, TableUpsertView, TableDropView
|
||||
from .views.row import RowView, RowDeleteView, RowUpdateView
|
||||
from .renderer import json_renderer
|
||||
from .url_builder import Urls
|
||||
|
|
@ -1292,6 +1292,10 @@ class Datasette:
|
|||
TableInsertView.as_view(self),
|
||||
r"/(?P<database>[^\/\.]+)/(?P<table>[^\/\.]+)/-/insert$",
|
||||
)
|
||||
add_route(
|
||||
TableUpsertView.as_view(self),
|
||||
r"/(?P<database>[^\/\.]+)/(?P<table>[^\/\.]+)/-/upsert$",
|
||||
)
|
||||
add_route(
|
||||
TableDropView.as_view(self),
|
||||
r"/(?P<database>[^\/\.]+)/(?P<table>[^\/\.]+)/-/drop$",
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ def permission_allowed_actor_restrictions(actor, action, resource):
|
|||
if action_initials in database_allowed:
|
||||
return None
|
||||
# Or the current table? That's any time the resource is (database, table)
|
||||
if not isinstance(resource, str) and len(resource) == 2:
|
||||
if resource is not None and not isinstance(resource, str) and len(resource) == 2:
|
||||
database, table = resource
|
||||
table_allowed = _r.get("t", {}).get(database, {}).get(table)
|
||||
# TODO: What should this do for canned queries?
|
||||
|
|
@ -138,6 +138,8 @@ def actor_from_request(datasette, request):
|
|||
# Expired
|
||||
return None
|
||||
actor = {"id": decoded["a"], "token": "dstok"}
|
||||
if "_r" in decoded:
|
||||
actor["_r"] = decoded["_r"]
|
||||
if duration:
|
||||
actor["token_expires"] = created + duration
|
||||
return actor
|
||||
|
|
|
|||
|
|
@ -316,21 +316,37 @@ class ApiExplorerView(BaseView):
|
|||
request.actor, "insert-row", (name, table)
|
||||
):
|
||||
pks = await db.primary_keys(table)
|
||||
table_links.append(
|
||||
{
|
||||
"path": self.ds.urls.table(name, table) + "/-/insert",
|
||||
"method": "POST",
|
||||
"label": "Insert rows into {}".format(table),
|
||||
"json": {
|
||||
"rows": [
|
||||
{
|
||||
column: None
|
||||
for column in await db.table_columns(table)
|
||||
if column not in pks
|
||||
}
|
||||
]
|
||||
table_links.extend(
|
||||
[
|
||||
{
|
||||
"path": self.ds.urls.table(name, table) + "/-/insert",
|
||||
"method": "POST",
|
||||
"label": "Insert rows into {}".format(table),
|
||||
"json": {
|
||||
"rows": [
|
||||
{
|
||||
column: None
|
||||
for column in await db.table_columns(table)
|
||||
if column not in pks
|
||||
}
|
||||
]
|
||||
},
|
||||
},
|
||||
}
|
||||
{
|
||||
"path": self.ds.urls.table(name, table) + "/-/upsert",
|
||||
"method": "POST",
|
||||
"label": "Upsert rows into {}".format(table),
|
||||
"json": {
|
||||
"rows": [
|
||||
{
|
||||
column: None
|
||||
for column in await db.table_columns(table)
|
||||
if column not in pks
|
||||
}
|
||||
]
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
if await self.ds.permission_allowed(
|
||||
request.actor, "drop-table", (name, table)
|
||||
|
|
|
|||
|
|
@ -1074,9 +1074,18 @@ class TableInsertView(BaseView):
|
|||
def __init__(self, datasette):
|
||||
self.ds = datasette
|
||||
|
||||
async def _validate_data(self, request, db, table_name):
|
||||
async def _validate_data(self, request, db, table_name, pks, upsert):
|
||||
errors = []
|
||||
|
||||
pks_list = []
|
||||
if isinstance(pks, str):
|
||||
pks_list = [pks]
|
||||
else:
|
||||
pks_list = list(pks)
|
||||
|
||||
if not pks_list:
|
||||
pks_list = ["rowid"]
|
||||
|
||||
def _errors(errors):
|
||||
return None, errors, {}
|
||||
|
||||
|
|
@ -1134,7 +1143,18 @@ class TableInsertView(BaseView):
|
|||
|
||||
# Validate columns of each row
|
||||
columns = set(await db.table_columns(table_name))
|
||||
columns.update(pks_list)
|
||||
|
||||
for i, row in enumerate(rows):
|
||||
if upsert:
|
||||
# It MUST have the primary key
|
||||
missing_pks = [pk for pk in pks_list if pk not in row]
|
||||
if missing_pks:
|
||||
errors.append(
|
||||
'Row {} is missing primary key column(s): "{}"'.format(
|
||||
i, '", "'.join(missing_pks)
|
||||
)
|
||||
)
|
||||
invalid_columns = set(row.keys()) - columns
|
||||
if invalid_columns:
|
||||
errors.append(
|
||||
|
|
@ -1146,7 +1166,7 @@ class TableInsertView(BaseView):
|
|||
return _errors(errors)
|
||||
return rows, errors, extras
|
||||
|
||||
async def post(self, request):
|
||||
async def post(self, request, upsert=False):
|
||||
try:
|
||||
resolved = await self.ds.resolve_table(request)
|
||||
except NotFound as e:
|
||||
|
|
@ -1159,28 +1179,66 @@ class TableInsertView(BaseView):
|
|||
db = self.ds.get_database(database_name)
|
||||
if not await db.table_exists(table_name):
|
||||
return _error(["Table not found: {}".format(table_name)], 404)
|
||||
# Must have insert-row permission
|
||||
if not await self.ds.permission_allowed(
|
||||
request.actor, "insert-row", resource=(database_name, table_name)
|
||||
):
|
||||
return _error(["Permission denied"], 403)
|
||||
rows, errors, extras = await self._validate_data(request, db, table_name)
|
||||
|
||||
if upsert:
|
||||
# Must have insert-row AND upsert-row permissions
|
||||
if not (
|
||||
await self.ds.permission_allowed(
|
||||
request.actor, "insert-row", database_name, table_name
|
||||
)
|
||||
and await self.ds.permission_allowed(
|
||||
request.actor, "update-row", database_name, table_name
|
||||
)
|
||||
):
|
||||
return _error(
|
||||
["Permission denied: need both insert-row and update-row"], 403
|
||||
)
|
||||
else:
|
||||
# Must have insert-row permission
|
||||
if not await self.ds.permission_allowed(
|
||||
request.actor, "insert-row", resource=(database_name, table_name)
|
||||
):
|
||||
return _error(["Permission denied"], 403)
|
||||
|
||||
if not db.is_mutable:
|
||||
return _error(["Database is immutable"], 403)
|
||||
|
||||
pks = await db.primary_keys(table_name)
|
||||
|
||||
rows, errors, extras = await self._validate_data(
|
||||
request, db, table_name, pks, upsert
|
||||
)
|
||||
if errors:
|
||||
return _error(errors, 400)
|
||||
|
||||
# No that we've passed pks to _validate_data it's safe to
|
||||
# fix the rowids case:
|
||||
if not pks:
|
||||
pks = ["rowid"]
|
||||
|
||||
ignore = extras.get("ignore")
|
||||
replace = extras.get("replace")
|
||||
|
||||
if upsert and (ignore or replace):
|
||||
return _error(["Upsert does not support ignore or replace"], 400)
|
||||
|
||||
should_return = bool(extras.get("return", False))
|
||||
# Insert rows
|
||||
def insert_rows(conn):
|
||||
row_pk_values_for_later = []
|
||||
if should_return and upsert:
|
||||
row_pk_values_for_later = [tuple(row[pk] for pk in pks) for row in rows]
|
||||
|
||||
def insert_or_upsert_rows(conn):
|
||||
table = sqlite_utils.Database(conn)[table_name]
|
||||
if should_return:
|
||||
kwargs = {}
|
||||
if upsert:
|
||||
kwargs["pk"] = pks[0] if len(pks) == 1 else pks
|
||||
else:
|
||||
kwargs = {"ignore": ignore, "replace": replace}
|
||||
if should_return and not upsert:
|
||||
rowids = []
|
||||
method = table.upsert if upsert else table.insert
|
||||
for row in rows:
|
||||
rowids.append(
|
||||
table.insert(row, ignore=ignore, replace=replace).last_rowid
|
||||
)
|
||||
rowids.append(method(row, **kwargs).last_rowid)
|
||||
return list(
|
||||
table.rows_where(
|
||||
"rowid in ({})".format(",".join("?" for _ in rowids)),
|
||||
|
|
@ -1188,16 +1246,39 @@ class TableInsertView(BaseView):
|
|||
)
|
||||
)
|
||||
else:
|
||||
table.insert_all(rows, ignore=ignore, replace=replace)
|
||||
method_all = table.upsert_all if upsert else table.insert_all
|
||||
method_all(rows, **kwargs)
|
||||
|
||||
try:
|
||||
rows = await db.execute_write_fn(insert_rows)
|
||||
rows = await db.execute_write_fn(insert_or_upsert_rows)
|
||||
except Exception as e:
|
||||
return _error([str(e)])
|
||||
result = {"ok": True}
|
||||
if should_return:
|
||||
result["rows"] = rows
|
||||
return Response.json(result, status=201)
|
||||
if upsert:
|
||||
# Fetch based on initial input IDs
|
||||
where_clause = " OR ".join(
|
||||
["({})".format(" AND ".join("{} = ?".format(pk) for pk in pks))]
|
||||
* len(row_pk_values_for_later)
|
||||
)
|
||||
args = list(itertools.chain.from_iterable(row_pk_values_for_later))
|
||||
fetched_rows = await db.execute(
|
||||
"select {}* from [{}] where {}".format(
|
||||
"rowid, " if pks == ["rowid"] else "", table_name, where_clause
|
||||
),
|
||||
args,
|
||||
)
|
||||
result["rows"] = [dict(r) for r in fetched_rows.rows]
|
||||
else:
|
||||
result["rows"] = rows
|
||||
return Response.json(result, status=200 if upsert else 201)
|
||||
|
||||
|
||||
class TableUpsertView(TableInsertView):
|
||||
name = "table-upsert"
|
||||
|
||||
async def post(self, request):
|
||||
return await super().post(request, upsert=True)
|
||||
|
||||
|
||||
class TableDropView(BaseView):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue