diff --git a/datasette/database.py b/datasette/database.py index 8b824462..7364ff7f 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -1,6 +1,7 @@ import asyncio import atexit from collections import namedtuple +import inspect import os from pathlib import Path import janus @@ -263,15 +264,21 @@ class Database: def _wrap_fn_with_hooks(self, fn, request, transaction, track_event): from .plugins import pm - # Wrap fn so it receives track_event if its signature supports it + # Wrap fn so it receives track_event if its signature supports it. + # Historically fn was called positionally, so any single-parameter + # name (conn, connection, db, ...) worked. Preserve that by only + # switching to keyword dependency injection when the callback + # explicitly opts in by declaring a `track_event` parameter. original_fn = fn - def fn_with_track_event(conn): - return call_with_supported_arguments( - original_fn, conn=conn, track_event=track_event - ) + if "track_event" in inspect.signature(original_fn).parameters: - fn = fn_with_track_event + def fn_with_track_event(conn): + return call_with_supported_arguments( + original_fn, conn=conn, track_event=track_event + ) + + fn = fn_with_track_event wrappers = pm.hook.write_wrapper( datasette=self.ds, diff --git a/tests/test_internals_database.py b/tests/test_internals_database.py index e3d35f57..0d565d61 100644 --- a/tests/test_internals_database.py +++ b/tests/test_internals_database.py @@ -539,6 +539,37 @@ async def test_execute_write_fn_exception(db): await db.execute_write_fn(write_fn) +@pytest.mark.asyncio +@pytest.mark.parametrize("param_name", ["conn", "connection", "db", "c"]) +async def test_execute_write_fn_accepts_any_single_param_name(db, param_name): + # Plugins historically relied on the fact that the callback was invoked + # positionally, so any parameter name worked. Preserve that contract. + scope = {} + exec( + "def write_fn({0}):\n" + " return {0}.execute('select 1 + 1').fetchone()[0]".format(param_name), + scope, + ) + write_fn = scope["write_fn"] + result = await db.execute_write_fn(write_fn) + assert result == 2 + + +@pytest.mark.asyncio +async def test_execute_write_fn_with_track_event(db): + # When the callback declares track_event it still receives both args + # via dependency injection. + seen = [] + + def write_fn(conn, track_event): + seen.append(track_event) + return conn.execute("select 1 + 1").fetchone()[0] + + result = await db.execute_write_fn(write_fn) + assert result == 2 + assert len(seen) == 1 and callable(seen[0]) + + @pytest.mark.asyncio @pytest.mark.timeout(1) async def test_execute_write_fn_connection_exception(tmpdir, app_client):