mirror of
https://github.com/simonw/datasette.git
synced 2025-12-10 16:51:24 +01:00
New way of deriving named parameters using explain, refs #1421
This commit is contained in:
parent
ad90a72afa
commit
fc4846850f
4 changed files with 31 additions and 2 deletions
|
|
@ -1076,3 +1076,15 @@ class PrefixedUrlString(str):
|
||||||
|
|
||||||
class StartupError(Exception):
|
class StartupError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
_re_named_parameter = re.compile(":([a-zA-Z0-9_]+)")
|
||||||
|
|
||||||
|
async def derive_named_parameters(db, sql):
|
||||||
|
explain = 'explain {}'.format(sql.strip().rstrip(";"))
|
||||||
|
possible_params = _re_named_parameter.findall(sql)
|
||||||
|
try:
|
||||||
|
results = await db.execute(explain, {p: None for p in possible_params})
|
||||||
|
return [row["p4"].lstrip(":") for row in results if row["opcode"] == "Variable"]
|
||||||
|
except sqlite3.DatabaseError:
|
||||||
|
return []
|
||||||
|
|
|
||||||
|
|
@ -159,7 +159,6 @@ class BaseView:
|
||||||
|
|
||||||
class DataView(BaseView):
|
class DataView(BaseView):
|
||||||
name = ""
|
name = ""
|
||||||
re_named_parameter = re.compile(":([a-zA-Z0-9_]+)")
|
|
||||||
|
|
||||||
async def options(self, request, *args, **kwargs):
|
async def options(self, request, *args, **kwargs):
|
||||||
r = Response.text("ok")
|
r = Response.text("ok")
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ import markupsafe
|
||||||
from datasette.utils import (
|
from datasette.utils import (
|
||||||
await_me_maybe,
|
await_me_maybe,
|
||||||
check_visibility,
|
check_visibility,
|
||||||
|
derive_named_parameters,
|
||||||
to_css_class,
|
to_css_class,
|
||||||
validate_sql_select,
|
validate_sql_select,
|
||||||
is_url,
|
is_url,
|
||||||
|
|
@ -223,7 +224,9 @@ class QueryView(DataView):
|
||||||
await self.check_permission(request, "execute-sql", database)
|
await self.check_permission(request, "execute-sql", database)
|
||||||
|
|
||||||
# Extract any :named parameters
|
# Extract any :named parameters
|
||||||
named_parameters = named_parameters or self.re_named_parameter.findall(sql)
|
named_parameters = named_parameters or await derive_named_parameters(
|
||||||
|
self.ds.get_database(database), sql
|
||||||
|
)
|
||||||
named_parameter_values = {
|
named_parameter_values = {
|
||||||
named_parameter: params.get(named_parameter) or ""
|
named_parameter: params.get(named_parameter) or ""
|
||||||
for named_parameter in named_parameters
|
for named_parameter in named_parameters
|
||||||
|
|
|
||||||
|
|
@ -626,3 +626,18 @@ def test_parse_metadata(content, expected):
|
||||||
utils.parse_metadata(content)
|
utils.parse_metadata(content)
|
||||||
else:
|
else:
|
||||||
assert utils.parse_metadata(content) == expected
|
assert utils.parse_metadata(content) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("sql,expected", (
|
||||||
|
("select 1", []),
|
||||||
|
("select 1 + :one", ["one"]),
|
||||||
|
("select 1 + :one + :two", ["one", "two"]),
|
||||||
|
("select 'bob' || '0:00' || :cat", ["cat"]),
|
||||||
|
("select this is invalid", []),
|
||||||
|
))
|
||||||
|
async def test_derive_named_parameters(sql, expected):
|
||||||
|
ds = Datasette([], memory=True)
|
||||||
|
db = ds.get_database("_memory")
|
||||||
|
params = await utils.derive_named_parameters(db, sql)
|
||||||
|
assert params == expected
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue