Databases can now have a .route separate from their .name, refs #1668

This commit is contained in:
Simon Willison 2022-03-19 17:11:17 -07:00
commit 7a6654a253
8 changed files with 111 additions and 26 deletions

View file

@ -388,13 +388,18 @@ class Datasette:
def unsign(self, signed, namespace="default"):
return URLSafeSerializer(self._secret, namespace).loads(signed)
def get_database(self, name=None):
def get_database(self, name=None, route=None):
if route is not None:
matches = [db for db in self.databases.values() if db.route == route]
if not matches:
raise KeyError
return matches[0]
if name is None:
# Return first no-_schemas database
# Return first database that isn't "_internal"
name = [key for key in self.databases.keys() if key != "_internal"][0]
return self.databases[name]
def add_database(self, db, name=None):
def add_database(self, db, name=None, route=None):
new_databases = self.databases.copy()
if name is None:
# Pick a unique name for this database
@ -407,6 +412,7 @@ class Datasette:
name = "{}_{}".format(suggestion, i)
i += 1
db.name = name
db.route = route or name
new_databases[name] = db
# don't mutate! that causes race conditions with live import
self.databases = new_databases
@ -693,6 +699,7 @@ class Datasette:
return [
{
"name": d.name,
"route": d.route,
"path": d.path,
"size": d.size,
"is_mutable": d.is_mutable,

View file

@ -31,6 +31,7 @@ class Database:
self, ds, path=None, is_mutable=False, is_memory=False, memory_name=None
):
self.name = None
self.route = None
self.ds = ds
self.path = path
self.is_mutable = is_mutable

View file

@ -371,13 +371,19 @@ class DataView(BaseView):
return AsgiStream(stream_fn, headers=headers, content_type=content_type)
async def get(self, request):
db_name = request.url_vars["database"]
database = tilde_decode(db_name)
database_route = tilde_decode(request.url_vars["database"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
database = db.name
_format = request.url_vars["format"]
data_kwargs = {}
if _format == "csv":
return await self.as_csv(request, database)
return await self.as_csv(request, database_route)
if _format is None:
# HTML views default to expanding all foreign key labels

View file

@ -32,7 +32,13 @@ class DatabaseView(DataView):
name = "database"
async def data(self, request, default_labels=False, _size=None):
database = tilde_decode(request.url_vars["database"])
database_route = tilde_decode(request.url_vars["database"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
database = db.name
await self.check_permissions(
request,
[
@ -50,11 +56,6 @@ class DatabaseView(DataView):
request, sql, _size=_size, metadata=metadata
)
try:
db = self.ds.databases[database]
except KeyError:
raise NotFound("Database not found: {}".format(database))
table_counts = await db.table_counts(5)
hidden_table_names = set(await db.hidden_table_names())
all_foreign_keys = await db.get_all_foreign_keys()
@ -171,9 +172,10 @@ class DatabaseDownload(DataView):
"view-instance",
],
)
if database not in self.ds.databases:
try:
db = self.ds.get_database(route=database)
except KeyError:
raise DatasetteError("Invalid database", status=404)
db = self.ds.databases[database]
if db.is_memory:
raise DatasetteError("Cannot download in-memory databases", status=404)
if not self.ds.setting("allow_download") or db.is_mutable:

View file

@ -272,10 +272,15 @@ class TableView(RowTableShared):
name = "table"
async def post(self, request):
db_name = tilde_decode(request.url_vars["database"])
database_route = tilde_decode(request.url_vars["database"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
database = db.name
table = tilde_decode(request.url_vars["table"])
# Handle POST to a canned query
canned_query = await self.ds.get_canned_query(db_name, table, request.actor)
canned_query = await self.ds.get_canned_query(database, table, request.actor)
assert canned_query, "You may only POST to a canned query"
return await QueryView(self.ds).data(
request,
@ -327,12 +332,13 @@ class TableView(RowTableShared):
_next=None,
_size=None,
):
database = tilde_decode(request.url_vars["database"])
database_route = tilde_decode(request.url_vars["database"])
table = tilde_decode(request.url_vars["table"])
try:
db = self.ds.databases[database]
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database))
raise NotFound("Database not found: {}".format(database_route))
database = db.name
# If this is a canned query, not a table, then dispatch to QueryView instead
canned_query = await self.ds.get_canned_query(database, table, request.actor)
@ -938,8 +944,13 @@ class RowView(RowTableShared):
name = "row"
async def data(self, request, default_labels=False):
database = tilde_decode(request.url_vars["database"])
database_route = tilde_decode(request.url_vars["database"])
table = tilde_decode(request.url_vars["table"])
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
database = db.name
await self.check_permissions(
request,
[
@ -949,7 +960,11 @@ class RowView(RowTableShared):
],
)
pk_values = urlsafe_components(request.url_vars["pks"])
db = self.ds.databases[database]
try:
db = self.ds.get_database(route=database_route)
except KeyError:
raise NotFound("Database not found: {}".format(database_route))
database = db.name
sql, params, pks = await _sql_params_pks(db, table, pk_values)
results = await db.execute(sql, params, truncate=True)
columns = [r[0] for r in results.description]

View file

@ -307,14 +307,17 @@ Returns the specified database object. Raises a ``KeyError`` if the database doe
.. _datasette_add_database:
.add_database(db, name=None)
----------------------------
.add_database(db, name=None, route=None)
----------------------------------------
``db`` - datasette.database.Database instance
The database to be attached.
``name`` - string, optional
The name to be used for this database - this will be used in the URL path, e.g. ``/dbname``. If not specified Datasette will pick one based on the filename or memory name.
The name to be used for this database . If not specified Datasette will pick one based on the filename or memory name.
``route`` - string, optional
This will be used in the URL path. If not specified, it will default to the same thing as the ``name``.
The ``datasette.add_database(db)`` method lets you add a new database to the current Datasette instance.
@ -371,7 +374,7 @@ Using either of these pattern will result in the in-memory database being served
``name`` - string
The name of the database to be removed.
This removes a database that has been previously added. ``name=`` is the unique name of that database, used in its URL path.
This removes a database that has been previously added. ``name=`` is the unique name of that database.
.. _datasette_sign:

View file

@ -55,6 +55,7 @@ async def test_datasette_constructor():
assert databases == [
{
"name": "_memory",
"route": "_memory",
"path": None,
"size": 0,
"is_mutable": False,

View file

@ -1,6 +1,7 @@
from datasette.app import Datasette
from datasette.app import Datasette, Database
from datasette.utils import resolve_routes
import pytest
import pytest_asyncio
@pytest.fixture(scope="session")
@ -53,3 +54,52 @@ def test_routes(routes, path, expected_class, expected_matches):
else:
assert view.view_class.__name__ == expected_class
assert match.groupdict() == expected_matches
@pytest_asyncio.fixture
async def ds_with_route():
ds = Datasette()
ds.remove_database("_memory")
db = Database(ds, is_memory=True, memory_name="route-name-db")
ds.add_database(db, name="name", route="route-name")
await db.execute_write_script(
"""
create table if not exists t (id integer primary key);
insert or replace into t (id) values (1);
"""
)
return ds
@pytest.mark.asyncio
async def test_db_with_route_databases(ds_with_route):
response = await ds_with_route.client.get("/-/databases.json")
assert response.json()[0] == {
"name": "name",
"route": "route-name",
"path": None,
"size": 0,
"is_mutable": True,
"is_memory": True,
"hash": None,
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"path,expected_status",
(
("/", 200),
("/name", 404),
("/name/t", 404),
("/name/t/1", 404),
("/route-name", 200),
("/route-name/t", 200),
("/route-name/t/1", 200),
),
)
async def test_db_with_route_that_does_not_match_name(
ds_with_route, path, expected_status
):
response = await ds_with_route.client.get(path)
assert response.status_code == expected_status