Include count in execute_write_many traces, closes #1571

This commit is contained in:
Simon Willison 2021-12-19 12:30:34 -08:00
commit f65817000f
3 changed files with 28 additions and 9 deletions

View file

@ -114,11 +114,22 @@ class Database:
async def execute_write_many(self, sql, params_seq, block=False):
def _inner(conn):
with conn:
return conn.executemany(sql, params_seq)
count = 0
with trace("sql", database=self.name, sql=sql.strip(), executemany=True):
results = await self.execute_write_fn(_inner, block=block)
def count_params(params):
nonlocal count
for param in params:
count += 1
yield param
with conn:
return conn.executemany(sql, count_params(params_seq)), count
with trace(
"sql", database=self.name, sql=sql.strip(), executemany=True
) as kwargs:
results, count = await self.execute_write_fn(_inner, block=block)
kwargs["count"] = count
return results
async def execute_write_fn(self, fn, block=False):

View file

@ -32,14 +32,14 @@ def trace(type, **kwargs):
), f".trace() keyword parameters cannot include {TRACE_RESERVED_KEYS}"
task_id = get_task_id()
if task_id is None:
yield
yield kwargs
return
tracer = tracers.get(task_id)
if tracer is None:
yield
yield kwargs
return
start = time.perf_counter()
yield
yield kwargs
end = time.perf_counter()
trace_info = {
"type": type,

View file

@ -928,8 +928,9 @@ def test_trace(trace_debug):
assert isinstance(trace_info["sum_trace_duration_ms"], float)
assert isinstance(trace_info["num_traces"], int)
assert isinstance(trace_info["traces"], list)
assert len(trace_info["traces"]) == trace_info["num_traces"]
for trace in trace_info["traces"]:
traces = trace_info["traces"]
assert len(traces) == trace_info["num_traces"]
for trace in traces:
assert isinstance(trace["type"], str)
assert isinstance(trace["start"], float)
assert isinstance(trace["end"], float)
@ -939,7 +940,7 @@ def test_trace(trace_debug):
assert isinstance(trace["sql"], str)
assert isinstance(trace.get("params"), (list, dict, None.__class__))
sqls = [trace["sql"] for trace in trace_info["traces"] if "sql" in trace]
sqls = [trace["sql"] for trace in traces if "sql" in trace]
# There should be a mix of different types of SQL statement
expected = (
"CREATE TABLE ",
@ -954,6 +955,13 @@ def test_trace(trace_debug):
sql.startswith(prefix) for sql in sqls
), "No trace beginning with: {}".format(prefix)
# Should be at least one executescript
assert any(trace for trace in traces if trace.get("executescript"))
# And at least one executemany
execute_manys = [trace for trace in traces if trace.get("executemany")]
assert execute_manys
assert all(isinstance(trace["count"], int) for trace in execute_manys)
@pytest.mark.parametrize(
"path,status_code",