named_parameters(sql) sync function, refs #2354

Also refs #2353 and #2352
This commit is contained in:
Simon Willison 2024-06-12 16:51:07 -07:00
commit d118d5c5bb
4 changed files with 37 additions and 20 deletions

View file

@ -1131,23 +1131,38 @@ class StartupError(Exception):
pass
_re_named_parameter = re.compile(":([a-zA-Z0-9_]+)")
_single_line_comment_re = re.compile(r"--.*")
_multi_line_comment_re = re.compile(r"/\*.*?\*/", re.DOTALL)
_single_quote_re = re.compile(r"'(?:''|[^'])*'")
_double_quote_re = re.compile(r'"(?:\"\"|[^"])*"')
_named_param_re = re.compile(r":(\w+)")
@documented
async def derive_named_parameters(db: "Database", sql: str) -> List[str]:
def named_parameters(sql: str) -> List[str]:
"""
Given a SQL statement, return a list of named parameters that are used in the statement
e.g. for ``select * from foo where id=:id`` this would return ``["id"]``
"""
explain = "explain {}".format(sql.strip().rstrip(";"))
possible_params = _re_named_parameter.findall(sql)
try:
results = await db.execute(explain, {p: None for p in possible_params})
return [row["p4"].lstrip(":") for row in results if row["opcode"] == "Variable"]
except (sqlite3.DatabaseError, AttributeError):
return possible_params
# Remove single-line comments
sql = _single_line_comment_re.sub("", sql)
# Remove multi-line comments
sql = _multi_line_comment_re.sub("", sql)
# Remove single-quoted strings
sql = _single_quote_re.sub("", sql)
# Remove double-quoted strings
sql = _double_quote_re.sub("", sql)
# Extract parameters from what is left
return _named_param_re.findall(sql)
async def derive_named_parameters(db: "Database", sql: str) -> List[str]:
"""
This undocumented but stable method exists for backwards compatibility
with plugins that were using it before it switched to named_parameters()
"""
return named_parameters(sql)
def add_cors_headers(headers):

View file

@ -17,7 +17,7 @@ from datasette.utils import (
add_cors_headers,
await_me_maybe,
call_with_supported_arguments,
derive_named_parameters,
named_parameters as derive_named_parameters,
format_bytes,
make_slot_function,
tilde_decode,
@ -484,9 +484,7 @@ class QueryView(View):
if canned_query and canned_query.get("params"):
named_parameters = canned_query["params"]
if not named_parameters:
named_parameters = await derive_named_parameters(
datasette.get_database(database), sql
)
named_parameters = derive_named_parameters(sql)
named_parameter_values = {
named_parameter: params.get(named_parameter) or ""
for named_parameter in named_parameters

View file

@ -1256,14 +1256,14 @@ Utility function for calling ``await`` on a return value if it is awaitable, oth
.. autofunction:: datasette.utils.await_me_maybe
.. _internals_utils_derive_named_parameters:
.. _internals_utils_named_parameters:
derive_named_parameters(db, sql)
--------------------------------
named_parameters(sql)
---------------------
Derive the list of named parameters referenced in a SQL query, using an ``explain`` query executed against the provided database.
Derive the list of ``:named`` parameters referenced in a SQL query.
.. autofunction:: datasette.utils.derive_named_parameters
.. autofunction:: datasette.utils.named_parameters
.. _internals_tilde_encoding:

View file

@ -612,10 +612,14 @@ def test_parse_metadata(content, expected):
("select this is invalid :one, :two, :three", ["one", "two", "three"]),
),
)
async def test_derive_named_parameters(sql, expected):
@pytest.mark.parametrize("use_async_version", (False, True))
async def test_named_parameters(sql, expected, use_async_version):
ds = Datasette([], memory=True)
db = ds.get_database("_memory")
params = await utils.derive_named_parameters(db, sql)
if use_async_version:
params = await utils.derive_named_parameters(db, sql)
else:
params = utils.named_parameters(sql)
assert params == expected