escape_sqlite_table_name => escape_sqlite, handles reserved words

It can be used for column names as well as table names.

Reserved word list from https://www.sqlite.org/lang_keywords.html
This commit is contained in:
Simon Willison 2018-04-03 06:39:50 -07:00
commit 8f0d44d646
No known key found for this signature in database
GPG key ID: 17E2DEA2588B7F52
7 changed files with 68 additions and 25 deletions

View file

@ -23,7 +23,7 @@ from .utils import (
compound_keys_after_sql, compound_keys_after_sql,
detect_fts_sql, detect_fts_sql,
escape_css_string, escape_css_string,
escape_sqlite_table_name, escape_sqlite,
filters_should_redirect, filters_should_redirect,
get_all_foreign_keys, get_all_foreign_keys,
is_url, is_url,
@ -437,7 +437,7 @@ class RowTableShared(BaseView):
sql = 'select "{other_column}", "{label_column}" from {other_table} where "{other_column}" in ({placeholders})'.format( sql = 'select "{other_column}", "{label_column}" from {other_table} where "{other_column}" in ({placeholders})'.format(
other_column=fk['other_column'], other_column=fk['other_column'],
label_column=label_column, label_column=label_column,
other_table=escape_sqlite_table_name(fk['other_table']), other_table=escape_sqlite(fk['other_table']),
placeholders=', '.join(['?'] * len(ids_to_lookup)), placeholders=', '.join(['?'] * len(ids_to_lookup)),
) )
try: try:
@ -611,18 +611,18 @@ class TableView(RowTableShared):
count_sql = None count_sql = None
sql = 'select {group_cols}, count(*) as "count" from {table_name} {where} group by {group_cols} order by "count" desc limit 100'.format( sql = 'select {group_cols}, count(*) as "count" from {table_name} {where} group by {group_cols} order by "count" desc limit 100'.format(
group_cols=', '.join('"{}"'.format(group_count_col) for group_count_col in group_count), group_cols=', '.join('"{}"'.format(group_count_col) for group_count_col in group_count),
table_name=escape_sqlite_table_name(table), table_name=escape_sqlite(table),
where=where_clause, where=where_clause,
) )
is_view = True is_view = True
else: else:
count_sql = 'select count(*) from {table_name} {where}'.format( count_sql = 'select count(*) from {table_name} {where}'.format(
table_name=escape_sqlite_table_name(table), table_name=escape_sqlite(table),
where=where_clause, where=where_clause,
) )
sql = 'select {select} from {table_name} {where}{order_by}limit {limit}{offset}'.format( sql = 'select {select} from {table_name} {where}{order_by}limit {limit}{offset}'.format(
select=select, select=select,
table_name=escape_sqlite_table_name(table), table_name=escape_sqlite(table),
where=where_clause, where=where_clause,
order_by=order_by, order_by=order_by,
limit=self.page_size + 1, limit=self.page_size + 1,
@ -804,7 +804,7 @@ class RowView(RowTableShared):
foreign_keys = table_info['foreign_keys']['incoming'] foreign_keys = table_info['foreign_keys']['incoming']
sql = 'select ' + ', '.join([ sql = 'select ' + ', '.join([
'(select count(*) from {table} where "{column}"=:id)'.format( '(select count(*) from {table} where "{column}"=:id)'.format(
table=escape_sqlite_table_name(fk['other_table']), table=escape_sqlite(fk['other_table']),
column=fk['other_column'], column=fk['other_column'],
) )
for fk in foreign_keys for fk in foreign_keys
@ -937,7 +937,7 @@ class Datasette:
for table in table_names: for table in table_names:
try: try:
count = conn.execute( count = conn.execute(
'select count(*) from {}'.format(escape_sqlite_table_name(table)) 'select count(*) from {}'.format(escape_sqlite(table))
).fetchone()[0] ).fetchone()[0]
except sqlite3.OperationalError: except sqlite3.OperationalError:
# This can happen when running against a FTS virtual tables # This can happen when running against a FTS virtual tables
@ -946,7 +946,7 @@ class Datasette:
label_column = None label_column = None
# If table has two columns, one of which is ID, then label_column is the other one # If table has two columns, one of which is ID, then label_column is the other one
column_names = [r[1] for r in conn.execute( column_names = [r[1] for r in conn.execute(
'PRAGMA table_info({});'.format(escape_sqlite_table_name(table)) 'PRAGMA table_info({});'.format(escape_sqlite(table))
).fetchall()] ).fetchall()]
if column_names and len(column_names) == 2 and 'id' in column_names: if column_names and len(column_names) == 2 and 'id' in column_names:
label_column = [c for c in column_names if c != 'id'][0] label_column = [c for c in column_names if c != 'id'][0]
@ -1007,7 +1007,7 @@ class Datasette:
) )
self.jinja_env.filters['escape_css_string'] = escape_css_string self.jinja_env.filters['escape_css_string'] = escape_css_string
self.jinja_env.filters['quote_plus'] = lambda u: urllib.parse.quote_plus(u) self.jinja_env.filters['quote_plus'] = lambda u: urllib.parse.quote_plus(u)
self.jinja_env.filters['escape_table_name'] = escape_sqlite_table_name self.jinja_env.filters['escape_sqlite'] = escape_sqlite
self.jinja_env.filters['to_css_class'] = to_css_class self.jinja_env.filters['to_css_class'] = to_css_class
app.add_route(IndexView.as_view(self), '/<as_json:(\.jsono?)?$>') app.add_route(IndexView.as_view(self), '/<as_json:(\.jsono?)?$>')
# TODO: /favicon.ico and /-/static/ deserve far-future cache expires # TODO: /favicon.ico and /-/static/ deserve far-future cache expires

View file

@ -18,7 +18,7 @@
<form class="sql" action="/{{ database }}-{{ database_hash }}" method="get"> <form class="sql" action="/{{ database }}-{{ database_hash }}" method="get">
<h3>Custom SQL query</h3> <h3>Custom SQL query</h3>
<p><textarea name="sql">select * from {{ tables[0].name|escape_table_name }}</textarea></p> <p><textarea name="sql">select * from {{ tables[0].name|escape_sqlite }}</textarea></p>
<p><input type="submit" value="Run SQL"></p> <p><input type="submit" value="Run SQL"></p>
</form> </form>

View file

@ -26,7 +26,7 @@
<form class="sql" action="/{{ database }}-{{ database_hash }}{% if canned_query %}/{{ canned_query }}{% endif %}" method="get"> <form class="sql" action="/{{ database }}-{{ database_hash }}{% if canned_query %}/{{ canned_query }}{% endif %}" method="get">
<h3>Custom SQL query{% if rows %} returning {% if truncated %}more than {% endif %}{{ "{:,}".format(rows|length) }} row{% if rows|length == 1 %}{% else %}s{% endif %}{% endif %}</h3> <h3>Custom SQL query{% if rows %} returning {% if truncated %}more than {% endif %}{{ "{:,}".format(rows|length) }} row{% if rows|length == 1 %}{% else %}s{% endif %}{% endif %}</h3>
{% if editable %} {% if editable %}
<p><textarea name="sql">{% if query and query.sql %}{{ query.sql }}{% else %}select * from {{ tables[0].name|escape_table_name }}{% endif %}</textarea></p> <p><textarea name="sql">{% if query and query.sql %}{{ query.sql }}{% else %}select * from {{ tables[0].name|escape_sqlite }}{% endif %}</textarea></p>
{% else %} {% else %}
<pre>{% if query %}{{ query.sql }}{% endif %}</pre> <pre>{% if query %}{{ query.sql }}{% endif %}</pre>
{% endif %} {% endif %}

View file

@ -12,6 +12,23 @@ import shutil
import urllib import urllib
# From https://www.sqlite.org/lang_keywords.html
reserved_words = set((
'abort action add after all alter analyze and as asc attach autoincrement '
'before begin between by cascade case cast check collate column commit '
'conflict constraint create cross current_date current_time '
'current_timestamp database default deferrable deferred delete desc detach '
'distinct drop each else end escape except exclusive exists explain fail '
'for foreign from full glob group having if ignore immediate in index '
'indexed initially inner insert instead intersect into is isnull join key '
'left like limit match natural no not notnull null of offset on or order '
'outer plan pragma primary query raise recursive references regexp reindex '
'release rename replace restrict right rollback row savepoint select set '
'table temp temporary then to transaction trigger union unique update using '
'vacuum values view virtual when where with without'
).split())
def compound_pks_from_path(path): def compound_pks_from_path(path):
return [ return [
urllib.parse.unquote_plus(b) for b in path.split(',') urllib.parse.unquote_plus(b) for b in path.split(',')
@ -45,11 +62,11 @@ def compound_keys_after_sql(pks, start_index=0):
and_clauses = [] and_clauses = []
last = pks_left[-1] last = pks_left[-1]
rest = pks_left[:-1] rest = pks_left[:-1]
and_clauses = ['[{}] = :p{}'.format( and_clauses = ['{} = :p{}'.format(
pk, (i + start_index) escape_sqlite(pk), (i + start_index)
) for i, pk in enumerate(rest)] ) for i, pk in enumerate(rest)]
and_clauses.append('[{}] > :p{}'.format( and_clauses.append('{} > :p{}'.format(
last, (len(rest) + start_index) escape_sqlite(last), (len(rest) + start_index)
)) ))
or_clauses.append('({})'.format(' and '.join(and_clauses))) or_clauses.append('({})'.format(' and '.join(and_clauses)))
pks_left.pop() pks_left.pop()
@ -146,15 +163,15 @@ def path_with_ext(request, ext):
_css_re = re.compile(r'''['"\n\\]''') _css_re = re.compile(r'''['"\n\\]''')
_boring_table_name_re = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') _boring_keyword_re = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
def escape_css_string(s): def escape_css_string(s):
return _css_re.sub(lambda m: '\\{:X}'.format(ord(m.group())), s) return _css_re.sub(lambda m: '\\{:X}'.format(ord(m.group())), s)
def escape_sqlite_table_name(s): def escape_sqlite(s):
if _boring_table_name_re.match(s): if _boring_keyword_re.match(s) and (s.lower() not in reserved_words):
return s return s
else: else:
return '[{}]'.format(s) return '[{}]'.format(s)

View file

@ -116,6 +116,13 @@ CREATE TABLE "complex_foreign_keys" (
FOREIGN KEY ("f3") REFERENCES [simple_primary_key](id) FOREIGN KEY ("f3") REFERENCES [simple_primary_key](id)
); );
CREATE TABLE [select] (
[group] text,
[having] text,
[and] text
);
INSERT INTO [select] VALUES ('group', 'having', 'and');
INSERT INTO simple_primary_key VALUES (1, 'hello'); INSERT INTO simple_primary_key VALUES (1, 'hello');
INSERT INTO simple_primary_key VALUES (2, 'world'); INSERT INTO simple_primary_key VALUES (2, 'world');
INSERT INTO simple_primary_key VALUES (3, ''); INSERT INTO simple_primary_key VALUES (3, '');

View file

@ -13,7 +13,7 @@ def test_homepage(app_client):
assert response.json.keys() == {'test_tables': 0}.keys() assert response.json.keys() == {'test_tables': 0}.keys()
d = response.json['test_tables'] d = response.json['test_tables']
assert d['name'] == 'test_tables' assert d['name'] == 'test_tables'
assert d['tables_count'] == 8 assert d['tables_count'] == 9
def test_database_page(app_client): def test_database_page(app_client):
@ -77,6 +77,13 @@ def test_database_page(app_client):
'hidden': False, 'hidden': False,
'foreign_keys': {'incoming': [], 'outgoing': []}, 'foreign_keys': {'incoming': [], 'outgoing': []},
'label_column': None, 'label_column': None,
}, {
'columns': ['group', 'having', 'and'],
'name': 'select',
'count': 1,
'hidden': False,
'foreign_keys': {'incoming': [], 'outgoing': []},
'label_column': None,
}, { }, {
'columns': ['pk', 'content'], 'columns': ['pk', 'content'],
'name': 'simple_primary_key', 'name': 'simple_primary_key',
@ -190,6 +197,18 @@ def test_table_with_slashes_in_name(app_client):
}] }]
def test_table_with_reserved_word_name(app_client):
response = app_client.get('/test_tables/select.jsono', gather_request=False)
assert response.status == 200
data = response.json
assert data['rows'] == [{
'rowid': 1,
'group': 'group',
'having': 'having',
'and': 'and',
}]
@pytest.mark.parametrize('path,expected_rows,expected_pages', [ @pytest.mark.parametrize('path,expected_rows,expected_pages', [
('/test_tables/no_primary_key.jsono', 201, 5), ('/test_tables/no_primary_key.jsono', 201, 5),
('/test_tables/paginated_view.jsono', 201, 5), ('/test_tables/paginated_view.jsono', 201, 5),

View file

@ -227,16 +227,16 @@ def test_temporary_docker_directory_uses_copy_if_hard_link_fails(mock_link):
def test_compound_keys_after_sql(): def test_compound_keys_after_sql():
assert '(([a] > :p0))' == utils.compound_keys_after_sql(['a']) assert '((a > :p0))' == utils.compound_keys_after_sql(['a'])
assert ''' assert '''
(([a] > :p0) ((a > :p0)
or or
([a] = :p0 and [b] > :p1)) (a = :p0 and b > :p1))
'''.strip() == utils.compound_keys_after_sql(['a', 'b']) '''.strip() == utils.compound_keys_after_sql(['a', 'b'])
assert ''' assert '''
(([a] > :p0) ((a > :p0)
or or
([a] = :p0 and [b] > :p1) (a = :p0 and b > :p1)
or or
([a] = :p0 and [b] = :p1 and [c] > :p2)) (a = :p0 and b = :p1 and c > :p2))
'''.strip() == utils.compound_keys_after_sql(['a', 'b', 'c']) '''.strip() == utils.compound_keys_after_sql(['a', 'b', 'c'])