mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
named_parameters(sql) sync function, refs #2354
Also refs #2353 and #2352
This commit is contained in:
parent
b39b01a890
commit
d118d5c5bb
4 changed files with 37 additions and 20 deletions
|
|
@ -1131,23 +1131,38 @@ class StartupError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
_re_named_parameter = re.compile(":([a-zA-Z0-9_]+)")
|
||||
_single_line_comment_re = re.compile(r"--.*")
|
||||
_multi_line_comment_re = re.compile(r"/\*.*?\*/", re.DOTALL)
|
||||
_single_quote_re = re.compile(r"'(?:''|[^'])*'")
|
||||
_double_quote_re = re.compile(r'"(?:\"\"|[^"])*"')
|
||||
_named_param_re = re.compile(r":(\w+)")
|
||||
|
||||
|
||||
@documented
|
||||
async def derive_named_parameters(db: "Database", sql: str) -> List[str]:
|
||||
def named_parameters(sql: str) -> List[str]:
|
||||
"""
|
||||
Given a SQL statement, return a list of named parameters that are used in the statement
|
||||
|
||||
e.g. for ``select * from foo where id=:id`` this would return ``["id"]``
|
||||
"""
|
||||
explain = "explain {}".format(sql.strip().rstrip(";"))
|
||||
possible_params = _re_named_parameter.findall(sql)
|
||||
try:
|
||||
results = await db.execute(explain, {p: None for p in possible_params})
|
||||
return [row["p4"].lstrip(":") for row in results if row["opcode"] == "Variable"]
|
||||
except (sqlite3.DatabaseError, AttributeError):
|
||||
return possible_params
|
||||
# Remove single-line comments
|
||||
sql = _single_line_comment_re.sub("", sql)
|
||||
# Remove multi-line comments
|
||||
sql = _multi_line_comment_re.sub("", sql)
|
||||
# Remove single-quoted strings
|
||||
sql = _single_quote_re.sub("", sql)
|
||||
# Remove double-quoted strings
|
||||
sql = _double_quote_re.sub("", sql)
|
||||
# Extract parameters from what is left
|
||||
return _named_param_re.findall(sql)
|
||||
|
||||
|
||||
async def derive_named_parameters(db: "Database", sql: str) -> List[str]:
|
||||
"""
|
||||
This undocumented but stable method exists for backwards compatibility
|
||||
with plugins that were using it before it switched to named_parameters()
|
||||
"""
|
||||
return named_parameters(sql)
|
||||
|
||||
|
||||
def add_cors_headers(headers):
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from datasette.utils import (
|
|||
add_cors_headers,
|
||||
await_me_maybe,
|
||||
call_with_supported_arguments,
|
||||
derive_named_parameters,
|
||||
named_parameters as derive_named_parameters,
|
||||
format_bytes,
|
||||
make_slot_function,
|
||||
tilde_decode,
|
||||
|
|
@ -484,9 +484,7 @@ class QueryView(View):
|
|||
if canned_query and canned_query.get("params"):
|
||||
named_parameters = canned_query["params"]
|
||||
if not named_parameters:
|
||||
named_parameters = await derive_named_parameters(
|
||||
datasette.get_database(database), sql
|
||||
)
|
||||
named_parameters = derive_named_parameters(sql)
|
||||
named_parameter_values = {
|
||||
named_parameter: params.get(named_parameter) or ""
|
||||
for named_parameter in named_parameters
|
||||
|
|
|
|||
|
|
@ -1256,14 +1256,14 @@ Utility function for calling ``await`` on a return value if it is awaitable, oth
|
|||
|
||||
.. autofunction:: datasette.utils.await_me_maybe
|
||||
|
||||
.. _internals_utils_derive_named_parameters:
|
||||
.. _internals_utils_named_parameters:
|
||||
|
||||
derive_named_parameters(db, sql)
|
||||
--------------------------------
|
||||
named_parameters(sql)
|
||||
---------------------
|
||||
|
||||
Derive the list of named parameters referenced in a SQL query, using an ``explain`` query executed against the provided database.
|
||||
Derive the list of ``:named`` parameters referenced in a SQL query.
|
||||
|
||||
.. autofunction:: datasette.utils.derive_named_parameters
|
||||
.. autofunction:: datasette.utils.named_parameters
|
||||
|
||||
.. _internals_tilde_encoding:
|
||||
|
||||
|
|
|
|||
|
|
@ -612,10 +612,14 @@ def test_parse_metadata(content, expected):
|
|||
("select this is invalid :one, :two, :three", ["one", "two", "three"]),
|
||||
),
|
||||
)
|
||||
async def test_derive_named_parameters(sql, expected):
|
||||
@pytest.mark.parametrize("use_async_version", (False, True))
|
||||
async def test_named_parameters(sql, expected, use_async_version):
|
||||
ds = Datasette([], memory=True)
|
||||
db = ds.get_database("_memory")
|
||||
params = await utils.derive_named_parameters(db, sql)
|
||||
if use_async_version:
|
||||
params = await utils.derive_named_parameters(db, sql)
|
||||
else:
|
||||
params = utils.named_parameters(sql)
|
||||
assert params == expected
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue