mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
Databases can now have a .route separate from their .name, refs #1668
This commit is contained in:
parent
798f075ef9
commit
7a6654a253
8 changed files with 111 additions and 26 deletions
|
|
@ -388,13 +388,18 @@ class Datasette:
|
|||
def unsign(self, signed, namespace="default"):
|
||||
return URLSafeSerializer(self._secret, namespace).loads(signed)
|
||||
|
||||
def get_database(self, name=None):
|
||||
def get_database(self, name=None, route=None):
|
||||
if route is not None:
|
||||
matches = [db for db in self.databases.values() if db.route == route]
|
||||
if not matches:
|
||||
raise KeyError
|
||||
return matches[0]
|
||||
if name is None:
|
||||
# Return first no-_schemas database
|
||||
# Return first database that isn't "_internal"
|
||||
name = [key for key in self.databases.keys() if key != "_internal"][0]
|
||||
return self.databases[name]
|
||||
|
||||
def add_database(self, db, name=None):
|
||||
def add_database(self, db, name=None, route=None):
|
||||
new_databases = self.databases.copy()
|
||||
if name is None:
|
||||
# Pick a unique name for this database
|
||||
|
|
@ -407,6 +412,7 @@ class Datasette:
|
|||
name = "{}_{}".format(suggestion, i)
|
||||
i += 1
|
||||
db.name = name
|
||||
db.route = route or name
|
||||
new_databases[name] = db
|
||||
# don't mutate! that causes race conditions with live import
|
||||
self.databases = new_databases
|
||||
|
|
@ -693,6 +699,7 @@ class Datasette:
|
|||
return [
|
||||
{
|
||||
"name": d.name,
|
||||
"route": d.route,
|
||||
"path": d.path,
|
||||
"size": d.size,
|
||||
"is_mutable": d.is_mutable,
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ class Database:
|
|||
self, ds, path=None, is_mutable=False, is_memory=False, memory_name=None
|
||||
):
|
||||
self.name = None
|
||||
self.route = None
|
||||
self.ds = ds
|
||||
self.path = path
|
||||
self.is_mutable = is_mutable
|
||||
|
|
|
|||
|
|
@ -371,13 +371,19 @@ class DataView(BaseView):
|
|||
return AsgiStream(stream_fn, headers=headers, content_type=content_type)
|
||||
|
||||
async def get(self, request):
|
||||
db_name = request.url_vars["database"]
|
||||
database = tilde_decode(db_name)
|
||||
database_route = tilde_decode(request.url_vars["database"])
|
||||
|
||||
try:
|
||||
db = self.ds.get_database(route=database_route)
|
||||
except KeyError:
|
||||
raise NotFound("Database not found: {}".format(database_route))
|
||||
database = db.name
|
||||
|
||||
_format = request.url_vars["format"]
|
||||
data_kwargs = {}
|
||||
|
||||
if _format == "csv":
|
||||
return await self.as_csv(request, database)
|
||||
return await self.as_csv(request, database_route)
|
||||
|
||||
if _format is None:
|
||||
# HTML views default to expanding all foreign key labels
|
||||
|
|
|
|||
|
|
@ -32,7 +32,13 @@ class DatabaseView(DataView):
|
|||
name = "database"
|
||||
|
||||
async def data(self, request, default_labels=False, _size=None):
|
||||
database = tilde_decode(request.url_vars["database"])
|
||||
database_route = tilde_decode(request.url_vars["database"])
|
||||
try:
|
||||
db = self.ds.get_database(route=database_route)
|
||||
except KeyError:
|
||||
raise NotFound("Database not found: {}".format(database_route))
|
||||
database = db.name
|
||||
|
||||
await self.check_permissions(
|
||||
request,
|
||||
[
|
||||
|
|
@ -50,11 +56,6 @@ class DatabaseView(DataView):
|
|||
request, sql, _size=_size, metadata=metadata
|
||||
)
|
||||
|
||||
try:
|
||||
db = self.ds.databases[database]
|
||||
except KeyError:
|
||||
raise NotFound("Database not found: {}".format(database))
|
||||
|
||||
table_counts = await db.table_counts(5)
|
||||
hidden_table_names = set(await db.hidden_table_names())
|
||||
all_foreign_keys = await db.get_all_foreign_keys()
|
||||
|
|
@ -171,9 +172,10 @@ class DatabaseDownload(DataView):
|
|||
"view-instance",
|
||||
],
|
||||
)
|
||||
if database not in self.ds.databases:
|
||||
try:
|
||||
db = self.ds.get_database(route=database)
|
||||
except KeyError:
|
||||
raise DatasetteError("Invalid database", status=404)
|
||||
db = self.ds.databases[database]
|
||||
if db.is_memory:
|
||||
raise DatasetteError("Cannot download in-memory databases", status=404)
|
||||
if not self.ds.setting("allow_download") or db.is_mutable:
|
||||
|
|
|
|||
|
|
@ -272,10 +272,15 @@ class TableView(RowTableShared):
|
|||
name = "table"
|
||||
|
||||
async def post(self, request):
|
||||
db_name = tilde_decode(request.url_vars["database"])
|
||||
database_route = tilde_decode(request.url_vars["database"])
|
||||
try:
|
||||
db = self.ds.get_database(route=database_route)
|
||||
except KeyError:
|
||||
raise NotFound("Database not found: {}".format(database_route))
|
||||
database = db.name
|
||||
table = tilde_decode(request.url_vars["table"])
|
||||
# Handle POST to a canned query
|
||||
canned_query = await self.ds.get_canned_query(db_name, table, request.actor)
|
||||
canned_query = await self.ds.get_canned_query(database, table, request.actor)
|
||||
assert canned_query, "You may only POST to a canned query"
|
||||
return await QueryView(self.ds).data(
|
||||
request,
|
||||
|
|
@ -327,12 +332,13 @@ class TableView(RowTableShared):
|
|||
_next=None,
|
||||
_size=None,
|
||||
):
|
||||
database = tilde_decode(request.url_vars["database"])
|
||||
database_route = tilde_decode(request.url_vars["database"])
|
||||
table = tilde_decode(request.url_vars["table"])
|
||||
try:
|
||||
db = self.ds.databases[database]
|
||||
db = self.ds.get_database(route=database_route)
|
||||
except KeyError:
|
||||
raise NotFound("Database not found: {}".format(database))
|
||||
raise NotFound("Database not found: {}".format(database_route))
|
||||
database = db.name
|
||||
|
||||
# If this is a canned query, not a table, then dispatch to QueryView instead
|
||||
canned_query = await self.ds.get_canned_query(database, table, request.actor)
|
||||
|
|
@ -938,8 +944,13 @@ class RowView(RowTableShared):
|
|||
name = "row"
|
||||
|
||||
async def data(self, request, default_labels=False):
|
||||
database = tilde_decode(request.url_vars["database"])
|
||||
database_route = tilde_decode(request.url_vars["database"])
|
||||
table = tilde_decode(request.url_vars["table"])
|
||||
try:
|
||||
db = self.ds.get_database(route=database_route)
|
||||
except KeyError:
|
||||
raise NotFound("Database not found: {}".format(database_route))
|
||||
database = db.name
|
||||
await self.check_permissions(
|
||||
request,
|
||||
[
|
||||
|
|
@ -949,7 +960,11 @@ class RowView(RowTableShared):
|
|||
],
|
||||
)
|
||||
pk_values = urlsafe_components(request.url_vars["pks"])
|
||||
db = self.ds.databases[database]
|
||||
try:
|
||||
db = self.ds.get_database(route=database_route)
|
||||
except KeyError:
|
||||
raise NotFound("Database not found: {}".format(database_route))
|
||||
database = db.name
|
||||
sql, params, pks = await _sql_params_pks(db, table, pk_values)
|
||||
results = await db.execute(sql, params, truncate=True)
|
||||
columns = [r[0] for r in results.description]
|
||||
|
|
|
|||
|
|
@ -307,14 +307,17 @@ Returns the specified database object. Raises a ``KeyError`` if the database doe
|
|||
|
||||
.. _datasette_add_database:
|
||||
|
||||
.add_database(db, name=None)
|
||||
----------------------------
|
||||
.add_database(db, name=None, route=None)
|
||||
----------------------------------------
|
||||
|
||||
``db`` - datasette.database.Database instance
|
||||
The database to be attached.
|
||||
|
||||
``name`` - string, optional
|
||||
The name to be used for this database - this will be used in the URL path, e.g. ``/dbname``. If not specified Datasette will pick one based on the filename or memory name.
|
||||
The name to be used for this database . If not specified Datasette will pick one based on the filename or memory name.
|
||||
|
||||
``route`` - string, optional
|
||||
This will be used in the URL path. If not specified, it will default to the same thing as the ``name``.
|
||||
|
||||
The ``datasette.add_database(db)`` method lets you add a new database to the current Datasette instance.
|
||||
|
||||
|
|
@ -371,7 +374,7 @@ Using either of these pattern will result in the in-memory database being served
|
|||
``name`` - string
|
||||
The name of the database to be removed.
|
||||
|
||||
This removes a database that has been previously added. ``name=`` is the unique name of that database, used in its URL path.
|
||||
This removes a database that has been previously added. ``name=`` is the unique name of that database.
|
||||
|
||||
.. _datasette_sign:
|
||||
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ async def test_datasette_constructor():
|
|||
assert databases == [
|
||||
{
|
||||
"name": "_memory",
|
||||
"route": "_memory",
|
||||
"path": None,
|
||||
"size": 0,
|
||||
"is_mutable": False,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from datasette.app import Datasette
|
||||
from datasette.app import Datasette, Database
|
||||
from datasette.utils import resolve_routes
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
@ -53,3 +54,52 @@ def test_routes(routes, path, expected_class, expected_matches):
|
|||
else:
|
||||
assert view.view_class.__name__ == expected_class
|
||||
assert match.groupdict() == expected_matches
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def ds_with_route():
|
||||
ds = Datasette()
|
||||
ds.remove_database("_memory")
|
||||
db = Database(ds, is_memory=True, memory_name="route-name-db")
|
||||
ds.add_database(db, name="name", route="route-name")
|
||||
await db.execute_write_script(
|
||||
"""
|
||||
create table if not exists t (id integer primary key);
|
||||
insert or replace into t (id) values (1);
|
||||
"""
|
||||
)
|
||||
return ds
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_db_with_route_databases(ds_with_route):
|
||||
response = await ds_with_route.client.get("/-/databases.json")
|
||||
assert response.json()[0] == {
|
||||
"name": "name",
|
||||
"route": "route-name",
|
||||
"path": None,
|
||||
"size": 0,
|
||||
"is_mutable": True,
|
||||
"is_memory": True,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"path,expected_status",
|
||||
(
|
||||
("/", 200),
|
||||
("/name", 404),
|
||||
("/name/t", 404),
|
||||
("/name/t/1", 404),
|
||||
("/route-name", 200),
|
||||
("/route-name/t", 200),
|
||||
("/route-name/t/1", 200),
|
||||
),
|
||||
)
|
||||
async def test_db_with_route_that_does_not_match_name(
|
||||
ds_with_route, path, expected_status
|
||||
):
|
||||
response = await ds_with_route.client.get(path)
|
||||
assert response.status_code == expected_status
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue