diff --git a/datasette/app.py b/datasette/app.py index ef3fde93..8e7cfa07 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -23,7 +23,7 @@ from .utils import ( compound_keys_after_sql, detect_fts_sql, escape_css_string, - escape_sqlite_table_name, + escape_sqlite, filters_should_redirect, get_all_foreign_keys, is_url, @@ -437,7 +437,7 @@ class RowTableShared(BaseView): sql = 'select "{other_column}", "{label_column}" from {other_table} where "{other_column}" in ({placeholders})'.format( other_column=fk['other_column'], label_column=label_column, - other_table=escape_sqlite_table_name(fk['other_table']), + other_table=escape_sqlite(fk['other_table']), placeholders=', '.join(['?'] * len(ids_to_lookup)), ) try: @@ -611,18 +611,18 @@ class TableView(RowTableShared): count_sql = None sql = 'select {group_cols}, count(*) as "count" from {table_name} {where} group by {group_cols} order by "count" desc limit 100'.format( group_cols=', '.join('"{}"'.format(group_count_col) for group_count_col in group_count), - table_name=escape_sqlite_table_name(table), + table_name=escape_sqlite(table), where=where_clause, ) is_view = True else: count_sql = 'select count(*) from {table_name} {where}'.format( - table_name=escape_sqlite_table_name(table), + table_name=escape_sqlite(table), where=where_clause, ) sql = 'select {select} from {table_name} {where}{order_by}limit {limit}{offset}'.format( select=select, - table_name=escape_sqlite_table_name(table), + table_name=escape_sqlite(table), where=where_clause, order_by=order_by, limit=self.page_size + 1, @@ -804,7 +804,7 @@ class RowView(RowTableShared): foreign_keys = table_info['foreign_keys']['incoming'] sql = 'select ' + ', '.join([ '(select count(*) from {table} where "{column}"=:id)'.format( - table=escape_sqlite_table_name(fk['other_table']), + table=escape_sqlite(fk['other_table']), column=fk['other_column'], ) for fk in foreign_keys @@ -937,7 +937,7 @@ class Datasette: for table in table_names: try: count = conn.execute( - 'select count(*) from {}'.format(escape_sqlite_table_name(table)) + 'select count(*) from {}'.format(escape_sqlite(table)) ).fetchone()[0] except sqlite3.OperationalError: # This can happen when running against a FTS virtual tables @@ -946,7 +946,7 @@ class Datasette: label_column = None # If table has two columns, one of which is ID, then label_column is the other one column_names = [r[1] for r in conn.execute( - 'PRAGMA table_info({});'.format(escape_sqlite_table_name(table)) + 'PRAGMA table_info({});'.format(escape_sqlite(table)) ).fetchall()] if column_names and len(column_names) == 2 and 'id' in column_names: label_column = [c for c in column_names if c != 'id'][0] @@ -1007,7 +1007,7 @@ class Datasette: ) self.jinja_env.filters['escape_css_string'] = escape_css_string self.jinja_env.filters['quote_plus'] = lambda u: urllib.parse.quote_plus(u) - self.jinja_env.filters['escape_table_name'] = escape_sqlite_table_name + self.jinja_env.filters['escape_sqlite'] = escape_sqlite self.jinja_env.filters['to_css_class'] = to_css_class app.add_route(IndexView.as_view(self), '/') # TODO: /favicon.ico and /-/static/ deserve far-future cache expires diff --git a/datasette/templates/database.html b/datasette/templates/database.html index 8867578a..1d404552 100644 --- a/datasette/templates/database.html +++ b/datasette/templates/database.html @@ -18,7 +18,7 @@

Custom SQL query

-

+

diff --git a/datasette/templates/query.html b/datasette/templates/query.html index 2eee5b12..fa850736 100644 --- a/datasette/templates/query.html +++ b/datasette/templates/query.html @@ -26,7 +26,7 @@

Custom SQL query{% if rows %} returning {% if truncated %}more than {% endif %}{{ "{:,}".format(rows|length) }} row{% if rows|length == 1 %}{% else %}s{% endif %}{% endif %}

{% if editable %} -

+

{% else %}
{% if query %}{{ query.sql }}{% endif %}
{% endif %} diff --git a/datasette/utils.py b/datasette/utils.py index fb132f08..8a185957 100644 --- a/datasette/utils.py +++ b/datasette/utils.py @@ -12,6 +12,23 @@ import shutil import urllib +# From https://www.sqlite.org/lang_keywords.html +reserved_words = set(( + 'abort action add after all alter analyze and as asc attach autoincrement ' + 'before begin between by cascade case cast check collate column commit ' + 'conflict constraint create cross current_date current_time ' + 'current_timestamp database default deferrable deferred delete desc detach ' + 'distinct drop each else end escape except exclusive exists explain fail ' + 'for foreign from full glob group having if ignore immediate in index ' + 'indexed initially inner insert instead intersect into is isnull join key ' + 'left like limit match natural no not notnull null of offset on or order ' + 'outer plan pragma primary query raise recursive references regexp reindex ' + 'release rename replace restrict right rollback row savepoint select set ' + 'table temp temporary then to transaction trigger union unique update using ' + 'vacuum values view virtual when where with without' +).split()) + + def compound_pks_from_path(path): return [ urllib.parse.unquote_plus(b) for b in path.split(',') @@ -45,11 +62,11 @@ def compound_keys_after_sql(pks, start_index=0): and_clauses = [] last = pks_left[-1] rest = pks_left[:-1] - and_clauses = ['[{}] = :p{}'.format( - pk, (i + start_index) + and_clauses = ['{} = :p{}'.format( + escape_sqlite(pk), (i + start_index) ) for i, pk in enumerate(rest)] - and_clauses.append('[{}] > :p{}'.format( - last, (len(rest) + start_index) + and_clauses.append('{} > :p{}'.format( + escape_sqlite(last), (len(rest) + start_index) )) or_clauses.append('({})'.format(' and '.join(and_clauses))) pks_left.pop() @@ -146,15 +163,15 @@ def path_with_ext(request, ext): _css_re = re.compile(r'''['"\n\\]''') -_boring_table_name_re = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') +_boring_keyword_re = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') def escape_css_string(s): return _css_re.sub(lambda m: '\\{:X}'.format(ord(m.group())), s) -def escape_sqlite_table_name(s): - if _boring_table_name_re.match(s): +def escape_sqlite(s): + if _boring_keyword_re.match(s) and (s.lower() not in reserved_words): return s else: return '[{}]'.format(s) diff --git a/tests/fixtures.py b/tests/fixtures.py index 3d22c289..0c79480a 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -116,6 +116,13 @@ CREATE TABLE "complex_foreign_keys" ( FOREIGN KEY ("f3") REFERENCES [simple_primary_key](id) ); +CREATE TABLE [select] ( + [group] text, + [having] text, + [and] text +); +INSERT INTO [select] VALUES ('group', 'having', 'and'); + INSERT INTO simple_primary_key VALUES (1, 'hello'); INSERT INTO simple_primary_key VALUES (2, 'world'); INSERT INTO simple_primary_key VALUES (3, ''); diff --git a/tests/test_api.py b/tests/test_api.py index cff7dadc..de3cc0c6 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -13,7 +13,7 @@ def test_homepage(app_client): assert response.json.keys() == {'test_tables': 0}.keys() d = response.json['test_tables'] assert d['name'] == 'test_tables' - assert d['tables_count'] == 8 + assert d['tables_count'] == 9 def test_database_page(app_client): @@ -77,6 +77,13 @@ def test_database_page(app_client): 'hidden': False, 'foreign_keys': {'incoming': [], 'outgoing': []}, 'label_column': None, + }, { + 'columns': ['group', 'having', 'and'], + 'name': 'select', + 'count': 1, + 'hidden': False, + 'foreign_keys': {'incoming': [], 'outgoing': []}, + 'label_column': None, }, { 'columns': ['pk', 'content'], 'name': 'simple_primary_key', @@ -190,6 +197,18 @@ def test_table_with_slashes_in_name(app_client): }] +def test_table_with_reserved_word_name(app_client): + response = app_client.get('/test_tables/select.jsono', gather_request=False) + assert response.status == 200 + data = response.json + assert data['rows'] == [{ + 'rowid': 1, + 'group': 'group', + 'having': 'having', + 'and': 'and', + }] + + @pytest.mark.parametrize('path,expected_rows,expected_pages', [ ('/test_tables/no_primary_key.jsono', 201, 5), ('/test_tables/paginated_view.jsono', 201, 5), diff --git a/tests/test_utils.py b/tests/test_utils.py index 73476d7c..a5204256 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -227,16 +227,16 @@ def test_temporary_docker_directory_uses_copy_if_hard_link_fails(mock_link): def test_compound_keys_after_sql(): - assert '(([a] > :p0))' == utils.compound_keys_after_sql(['a']) + assert '((a > :p0))' == utils.compound_keys_after_sql(['a']) assert ''' -(([a] > :p0) +((a > :p0) or -([a] = :p0 and [b] > :p1)) +(a = :p0 and b > :p1)) '''.strip() == utils.compound_keys_after_sql(['a', 'b']) assert ''' -(([a] > :p0) +((a > :p0) or -([a] = :p0 and [b] > :p1) +(a = :p0 and b > :p1) or -([a] = :p0 and [b] = :p1 and [c] > :p2)) +(a = :p0 and b = :p1 and c > :p2)) '''.strip() == utils.compound_keys_after_sql(['a', 'b', 'c'])