named_parameters(sql) sync function, refs #2354

Also refs #2353 and #2352
This commit is contained in:
Simon Willison 2024-06-12 16:51:07 -07:00
commit d118d5c5bb
4 changed files with 37 additions and 20 deletions

View file

@ -1131,23 +1131,38 @@ class StartupError(Exception):
pass pass
_re_named_parameter = re.compile(":([a-zA-Z0-9_]+)") _single_line_comment_re = re.compile(r"--.*")
_multi_line_comment_re = re.compile(r"/\*.*?\*/", re.DOTALL)
_single_quote_re = re.compile(r"'(?:''|[^'])*'")
_double_quote_re = re.compile(r'"(?:\"\"|[^"])*"')
_named_param_re = re.compile(r":(\w+)")
@documented @documented
async def derive_named_parameters(db: "Database", sql: str) -> List[str]: def named_parameters(sql: str) -> List[str]:
""" """
Given a SQL statement, return a list of named parameters that are used in the statement Given a SQL statement, return a list of named parameters that are used in the statement
e.g. for ``select * from foo where id=:id`` this would return ``["id"]`` e.g. for ``select * from foo where id=:id`` this would return ``["id"]``
""" """
explain = "explain {}".format(sql.strip().rstrip(";")) # Remove single-line comments
possible_params = _re_named_parameter.findall(sql) sql = _single_line_comment_re.sub("", sql)
try: # Remove multi-line comments
results = await db.execute(explain, {p: None for p in possible_params}) sql = _multi_line_comment_re.sub("", sql)
return [row["p4"].lstrip(":") for row in results if row["opcode"] == "Variable"] # Remove single-quoted strings
except (sqlite3.DatabaseError, AttributeError): sql = _single_quote_re.sub("", sql)
return possible_params # Remove double-quoted strings
sql = _double_quote_re.sub("", sql)
# Extract parameters from what is left
return _named_param_re.findall(sql)
async def derive_named_parameters(db: "Database", sql: str) -> List[str]:
"""
This undocumented but stable method exists for backwards compatibility
with plugins that were using it before it switched to named_parameters()
"""
return named_parameters(sql)
def add_cors_headers(headers): def add_cors_headers(headers):

View file

@ -17,7 +17,7 @@ from datasette.utils import (
add_cors_headers, add_cors_headers,
await_me_maybe, await_me_maybe,
call_with_supported_arguments, call_with_supported_arguments,
derive_named_parameters, named_parameters as derive_named_parameters,
format_bytes, format_bytes,
make_slot_function, make_slot_function,
tilde_decode, tilde_decode,
@ -484,9 +484,7 @@ class QueryView(View):
if canned_query and canned_query.get("params"): if canned_query and canned_query.get("params"):
named_parameters = canned_query["params"] named_parameters = canned_query["params"]
if not named_parameters: if not named_parameters:
named_parameters = await derive_named_parameters( named_parameters = derive_named_parameters(sql)
datasette.get_database(database), sql
)
named_parameter_values = { named_parameter_values = {
named_parameter: params.get(named_parameter) or "" named_parameter: params.get(named_parameter) or ""
for named_parameter in named_parameters for named_parameter in named_parameters

View file

@ -1256,14 +1256,14 @@ Utility function for calling ``await`` on a return value if it is awaitable, oth
.. autofunction:: datasette.utils.await_me_maybe .. autofunction:: datasette.utils.await_me_maybe
.. _internals_utils_derive_named_parameters: .. _internals_utils_named_parameters:
derive_named_parameters(db, sql) named_parameters(sql)
-------------------------------- ---------------------
Derive the list of named parameters referenced in a SQL query, using an ``explain`` query executed against the provided database. Derive the list of ``:named`` parameters referenced in a SQL query.
.. autofunction:: datasette.utils.derive_named_parameters .. autofunction:: datasette.utils.named_parameters
.. _internals_tilde_encoding: .. _internals_tilde_encoding:

View file

@ -612,10 +612,14 @@ def test_parse_metadata(content, expected):
("select this is invalid :one, :two, :three", ["one", "two", "three"]), ("select this is invalid :one, :two, :three", ["one", "two", "three"]),
), ),
) )
async def test_derive_named_parameters(sql, expected): @pytest.mark.parametrize("use_async_version", (False, True))
async def test_named_parameters(sql, expected, use_async_version):
ds = Datasette([], memory=True) ds = Datasette([], memory=True)
db = ds.get_database("_memory") db = ds.get_database("_memory")
params = await utils.derive_named_parameters(db, sql) if use_async_version:
params = await utils.derive_named_parameters(db, sql)
else:
params = utils.named_parameters(sql)
assert params == expected assert params == expected