Databases can now have a .route separate from their .name, refs #1668

This commit is contained in:
Simon Willison 2022-03-19 17:11:17 -07:00
commit 7a6654a253
8 changed files with 111 additions and 26 deletions

View file

@ -272,10 +272,15 @@ class TableView(RowTableShared):
name = "table"
async def post(self, request):
db_name = tilde_decode(request.url_vars["database"])
database_route = tilde_decode(request.url_vars["database"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
database = db.name
table = tilde_decode(request.url_vars["table"])
# Handle POST to a canned query
canned_query = await self.ds.get_canned_query(db_name, table, request.actor)
canned_query = await self.ds.get_canned_query(database, table, request.actor)
assert canned_query, "You may only POST to a canned query"
return await QueryView(self.ds).data(
request,
@ -327,12 +332,13 @@ class TableView(RowTableShared):
_next=None,
_size=None,
):
database = tilde_decode(request.url_vars["database"])
database_route = tilde_decode(request.url_vars["database"])
table = tilde_decode(request.url_vars["table"])
try:
db = self.ds.databases[database]
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database))
raise NotFound("Database not found: {}".format(database_route))
database = db.name
# If this is a canned query, not a table, then dispatch to QueryView instead
canned_query = await self.ds.get_canned_query(database, table, request.actor)
@ -938,8 +944,13 @@ class RowView(RowTableShared):
name = "row"
async def data(self, request, default_labels=False):
database = tilde_decode(request.url_vars["database"])
database_route = tilde_decode(request.url_vars["database"])
table = tilde_decode(request.url_vars["table"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
database = db.name
await self.check_permissions(
request,
[
@ -949,7 +960,11 @@ class RowView(RowTableShared):
],
)
pk_values = urlsafe_components(request.url_vars["pks"])
db = self.ds.databases[database]
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
database = db.name
sql, params, pks = await _sql_params_pks(db, table, pk_values)
results = await db.execute(sql, params, truncate=True)
columns = [r[0] for r in results.description]