diff --git a/datasette/database.py b/datasette/database.py index 1de1d5ec..0e41ff32 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -114,11 +114,22 @@ class Database: async def execute_write_many(self, sql, params_seq, block=False): def _inner(conn): - with conn: - return conn.executemany(sql, params_seq) + count = 0 - with trace("sql", database=self.name, sql=sql.strip(), executemany=True): - results = await self.execute_write_fn(_inner, block=block) + def count_params(params): + nonlocal count + for param in params: + count += 1 + yield param + + with conn: + return conn.executemany(sql, count_params(params_seq)), count + + with trace( + "sql", database=self.name, sql=sql.strip(), executemany=True + ) as kwargs: + results, count = await self.execute_write_fn(_inner, block=block) + kwargs["count"] = count return results async def execute_write_fn(self, fn, block=False): diff --git a/datasette/tracer.py b/datasette/tracer.py index 62c3c90c..6703f060 100644 --- a/datasette/tracer.py +++ b/datasette/tracer.py @@ -32,14 +32,14 @@ def trace(type, **kwargs): ), f".trace() keyword parameters cannot include {TRACE_RESERVED_KEYS}" task_id = get_task_id() if task_id is None: - yield + yield kwargs return tracer = tracers.get(task_id) if tracer is None: - yield + yield kwargs return start = time.perf_counter() - yield + yield kwargs end = time.perf_counter() trace_info = { "type": type, diff --git a/tests/test_api.py b/tests/test_api.py index f198c1f9..8ecaef43 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -928,8 +928,9 @@ def test_trace(trace_debug): assert isinstance(trace_info["sum_trace_duration_ms"], float) assert isinstance(trace_info["num_traces"], int) assert isinstance(trace_info["traces"], list) - assert len(trace_info["traces"]) == trace_info["num_traces"] - for trace in trace_info["traces"]: + traces = trace_info["traces"] + assert len(traces) == trace_info["num_traces"] + for trace in traces: assert isinstance(trace["type"], str) assert isinstance(trace["start"], float) assert isinstance(trace["end"], float) @@ -939,7 +940,7 @@ def test_trace(trace_debug): assert isinstance(trace["sql"], str) assert isinstance(trace.get("params"), (list, dict, None.__class__)) - sqls = [trace["sql"] for trace in trace_info["traces"] if "sql" in trace] + sqls = [trace["sql"] for trace in traces if "sql" in trace] # There should be a mix of different types of SQL statement expected = ( "CREATE TABLE ", @@ -954,6 +955,13 @@ def test_trace(trace_debug): sql.startswith(prefix) for sql in sqls ), "No trace beginning with: {}".format(prefix) + # Should be at least one executescript + assert any(trace for trace in traces if trace.get("executescript")) + # And at least one executemany + execute_manys = [trace for trace in traces if trace.get("executemany")] + assert execute_manys + assert all(isinstance(trace["count"], int) for trace in execute_manys) + @pytest.mark.parametrize( "path,status_code",