From 31fb006a9b05067a8eb2f774ad3a3b15b4565924 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sat, 30 May 2020 07:28:29 -0700 Subject: [PATCH] Added datasette.get_database() method Refs #576 --- datasette/app.py | 5 +++++ docs/internals.rst | 10 ++++++++++ docs/plugins.rst | 2 +- tests/test_database.py | 3 +++ tests/test_internals_datasette.py | 23 +++++++++++++++++++++++ 5 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 tests/test_internals_datasette.py diff --git a/datasette/app.py b/datasette/app.py index 07190c16..30eb3dba 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -281,6 +281,11 @@ class Datasette: self.register_renderers() + def get_database(self, name=None): + if name is None: + return next(iter(self.databases.values())) + return self.databases[name] + def add_database(self, name, db): self.databases[name] = db diff --git a/docs/internals.rst b/docs/internals.rst index ea015dbc..886cb7e7 100644 --- a/docs/internals.rst +++ b/docs/internals.rst @@ -44,6 +44,16 @@ This method lets you read plugin configuration values that were set in ``metadat Renders a `Jinja template `__ using Datasette's preconfigured instance of Jinja and returns the resulting string. The template will have access to Datasette's default template functions and any functions that have been made available by other plugins. +.. _datasette_get_database: + +.get_database(name) +------------------- + +``name`` - string, optional + The name of the database - optional. + +Returns the specified database object. Raises a ``KeyError`` if the database does not exist. Call this method without an argument to return the first connected database. + .. _datasette_add_database: .add_database(name, db) diff --git a/docs/plugins.rst b/docs/plugins.rst index b27daf3f..f08f1217 100644 --- a/docs/plugins.rst +++ b/docs/plugins.rst @@ -811,7 +811,7 @@ Here is a more complex example: .. code-block:: python async def render_demo(datasette, columns, rows): - db = next(iter(datasette.databases.values())) + db = datasette.get_database() result = await db.execute("select sqlite_version()") first_row = " | ".join(columns) lines = [first_row] diff --git a/tests/test_database.py b/tests/test_database.py index 1f1a3a7e..bd7e7666 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,3 +1,6 @@ +""" +Tests for the datasette.database.Database class +""" from datasette.database import Results, MultipleValues from datasette.utils import sqlite3 from .fixtures import app_client diff --git a/tests/test_internals_datasette.py b/tests/test_internals_datasette.py new file mode 100644 index 00000000..4993250d --- /dev/null +++ b/tests/test_internals_datasette.py @@ -0,0 +1,23 @@ +""" +Tests for the datasette.app.Datasette class +""" +from .fixtures import app_client +import pytest + + +@pytest.fixture +def datasette(app_client): + return app_client.ds + + +def test_get_database(datasette): + db = datasette.get_database("fixtures") + assert "fixtures" == db.name + with pytest.raises(KeyError): + datasette.get_database("missing") + + +def test_get_database_no_argument(datasette): + # Returns the first available database: + db = datasette.get_database() + assert "fixtures" == db.name