diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 00000000..0cece53b --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,3 @@ +[settings] +multi_line_output=3 + diff --git a/datasette/app.py b/datasette/app.py index beb7e924..a37a4a45 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1,449 +1,52 @@ -from sanic import Sanic -from sanic import response -from sanic.exceptions import NotFound, InvalidUsage -from sanic.views import HTTPMethodView -from sanic.request import RequestParameters -from jinja2 import Environment, FileSystemLoader, ChoiceLoader, PrefixLoader -import re -import sqlite3 -from pathlib import Path -from concurrent import futures -import asyncio -import os -import threading -import urllib.parse +import hashlib import itertools import json -import jinja2 -import hashlib +import os +import sqlite3 import sys -import time -import pint -import pluggy import traceback +import urllib.parse +from concurrent import futures +from pathlib import Path + +import pluggy +from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader +from sanic import Sanic, response +from sanic.exceptions import InvalidUsage, NotFound + +from datasette.views.base import ( + HASH_BLOCK_SIZE, + DatasetteError, + RenderMixin, + ureg +) +from datasette.views.database import DatabaseDownload, DatabaseView +from datasette.views.index import IndexView +from datasette.views.table import RowView, TableView + +from . import hookspecs from .utils import ( - Filters, - CustomJSONEncoder, - compound_keys_after_sql, detect_fts, detect_spatialite, escape_css_string, escape_sqlite, - filters_should_redirect, get_all_foreign_keys, get_plugins, - is_url, - InvalidSql, module_from_path, - path_from_row_pks, - path_with_added_args, - path_with_ext, - sqlite_timelimit, - to_css_class, - urlsafe_components, - validate_sql_select, + to_css_class ) -from . import __version__ -from . import hookspecs from .version import __version__ app_root = Path(__file__).parent.parent -HASH_BLOCK_SIZE = 1024 * 1024 -HASH_LENGTH = 7 -connections = threading.local() -ureg = pint.UnitRegistry() - - -pm = pluggy.PluginManager('datasette') +pm = pluggy.PluginManager("datasette") pm.add_hookspecs(hookspecs) -pm.load_setuptools_entrypoints('datasette') - - -class DatasetteError(Exception): - def __init__(self, message, title=None, error_dict=None, status=500, template=None): - self.message = message - self.title = title - self.error_dict = error_dict or {} - self.status = status - - -class RenderMixin(HTTPMethodView): - def render(self, templates, **context): - template = self.jinja_env.select_template(templates) - select_templates = ['{}{}'.format( - '*' if template_name == template.name else '', - template_name - ) for template_name in templates] - return response.html( - template.render({ - **context, **{ - 'app_css_hash': self.ds.app_css_hash(), - 'select_templates': select_templates, - 'zip': zip, - } - }) - ) - - -class BaseView(RenderMixin): - re_named_parameter = re.compile(':([a-zA-Z0-9_]+)') - - def __init__(self, datasette): - self.ds = datasette - self.files = datasette.files - self.jinja_env = datasette.jinja_env - self.executor = datasette.executor - self.page_size = datasette.page_size - self.max_returned_rows = datasette.max_returned_rows - - def table_metadata(self, database, table): - "Fetch table-specific metadata." 
- return self.ds.metadata.get( - 'databases', {} - ).get(database, {}).get('tables', {}).get(table, {}) - - def options(self, request, *args, **kwargs): - r = response.text('ok') - if self.ds.cors: - r.headers['Access-Control-Allow-Origin'] = '*' - return r - - def redirect(self, request, path, forward_querystring=True): - if request.query_string and '?' not in path and forward_querystring: - path = '{}?{}'.format( - path, request.query_string - ) - r = response.redirect(path) - r.headers['Link'] = '<{}>; rel=preload'.format(path) - if self.ds.cors: - r.headers['Access-Control-Allow-Origin'] = '*' - return r - - def resolve_db_name(self, db_name, **kwargs): - databases = self.ds.inspect() - hash = None - name = None - if '-' in db_name: - # Might be name-and-hash, or might just be - # a name with a hyphen in it - name, hash = db_name.rsplit('-', 1) - if name not in databases: - # Try the whole name - name = db_name - hash = None - else: - name = db_name - # Verify the hash - try: - info = databases[name] - except KeyError: - raise NotFound('Database not found: {}'.format(name)) - expected = info['hash'][:HASH_LENGTH] - if expected != hash: - should_redirect = '/{}-{}'.format( - name, expected, - ) - if 'table' in kwargs: - should_redirect += '/' + kwargs['table'] - if 'pk_path' in kwargs: - should_redirect += '/' + kwargs['pk_path'] - if 'as_json' in kwargs: - should_redirect += kwargs['as_json'] - if 'as_db' in kwargs: - should_redirect += kwargs['as_db'] - return name, expected, should_redirect - return name, expected, None - - async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None, page_size=None): - """Executes sql against db_name in a thread""" - page_size = page_size or self.page_size - - def sql_operation_in_thread(): - conn = getattr(connections, db_name, None) - if not conn: - info = self.ds.inspect()[db_name] - conn = sqlite3.connect( - 'file:{}?immutable=1'.format(info['file']), - uri=True, - check_same_thread=False, - ) - self.ds.prepare_connection(conn) - setattr(connections, db_name, conn) - - time_limit_ms = self.ds.sql_time_limit_ms - if custom_time_limit and custom_time_limit < self.ds.sql_time_limit_ms: - time_limit_ms = custom_time_limit - - with sqlite_timelimit(conn, time_limit_ms): - try: - cursor = conn.cursor() - cursor.execute(sql, params or {}) - max_returned_rows = self.max_returned_rows - if max_returned_rows == page_size: - max_returned_rows += 1 - if max_returned_rows and truncate: - rows = cursor.fetchmany(max_returned_rows + 1) - truncated = len(rows) > max_returned_rows - rows = rows[:max_returned_rows] - else: - rows = cursor.fetchall() - truncated = False - except Exception as e: - print('ERROR: conn={}, sql = {}, params = {}: {}'.format( - conn, repr(sql), params, e - )) - raise - if truncate: - return rows, truncated, cursor.description - else: - return rows - - return await asyncio.get_event_loop().run_in_executor( - self.executor, sql_operation_in_thread - ) - - def get_templates(self, database, table=None): - assert NotImplemented - - async def get(self, request, db_name, **kwargs): - name, hash, should_redirect = self.resolve_db_name(db_name, **kwargs) - if should_redirect: - return self.redirect(request, should_redirect) - return await self.view_get(request, name, hash, **kwargs) - - async def view_get(self, request, name, hash, **kwargs): - try: - as_json = kwargs.pop('as_json') - except KeyError: - as_json = False - extra_template_data = {} - start = time.time() - status_code = 200 - templates = [] - try: - 
response_or_template_contexts = await self.data( - request, name, hash, **kwargs - ) - if isinstance(response_or_template_contexts, response.HTTPResponse): - return response_or_template_contexts - else: - data, extra_template_data, templates = response_or_template_contexts - except (sqlite3.OperationalError, InvalidSql) as e: - raise DatasetteError(str(e), title='Invalid SQL', status=400) - except (sqlite3.OperationalError) as e: - raise DatasetteError(str(e)) - except DatasetteError: - raise - end = time.time() - data['query_ms'] = (end - start) * 1000 - for key in ('source', 'source_url', 'license', 'license_url'): - value = self.ds.metadata.get(key) - if value: - data[key] = value - if as_json: - # Special case for .jsono extension - redirect to _shape=objects - if as_json == '.jsono': - return self.redirect( - request, - path_with_added_args( - request, - {'_shape': 'objects'}, - path=request.path.rsplit('.jsono', 1)[0] + '.json' - ), - forward_querystring=False - ) - # Deal with the _shape option - shape = request.args.get('_shape', 'arrays') - if shape in ('objects', 'object', 'array'): - columns = data.get('columns') - rows = data.get('rows') - if rows and columns: - data['rows'] = [ - dict(zip(columns, row)) - for row in rows - ] - if shape == 'object': - error = None - if 'primary_keys' not in data: - error = '_shape=object is only available on tables' - else: - pks = data['primary_keys'] - if not pks: - error = '_shape=object not available for tables with no primary keys' - else: - object_rows = {} - for row in data['rows']: - pk_string = path_from_row_pks(row, pks, not pks) - object_rows[pk_string] = row - data = object_rows - if error: - data = { - 'ok': False, - 'error': error, - 'database': name, - 'database_hash': hash, - } - elif shape == 'array': - data = data['rows'] - elif shape == 'arrays': - pass - else: - status_code = 400 - data = { - 'ok': False, - 'error': 'Invalid _shape: {}'.format(shape), - 'status': 400, - 'title': None, - } - headers = {} - if self.ds.cors: - headers['Access-Control-Allow-Origin'] = '*' - r = response.HTTPResponse( - json.dumps( - data, cls=CustomJSONEncoder - ), - status=status_code, - content_type='application/json', - headers=headers, - ) - else: - extras = {} - if callable(extra_template_data): - extras = extra_template_data() - if asyncio.iscoroutine(extras): - extras = await extras - else: - extras = extra_template_data - context = { - **data, - **extras, - **{ - 'url_json': path_with_ext(request, '.json'), - 'url_jsono': path_with_ext(request, '.jsono'), - 'extra_css_urls': self.ds.extra_css_urls(), - 'extra_js_urls': self.ds.extra_js_urls(), - 'datasette_version': __version__, - } - } - if 'metadata' not in context: - context['metadata'] = self.ds.metadata - r = self.render( - templates, - **context, - ) - r.status = status_code - # Set far-future cache expiry - if self.ds.cache_headers: - r.headers['Cache-Control'] = 'max-age={}'.format( - 365 * 24 * 60 * 60 - ) - return r - - async def custom_sql(self, request, name, hash, sql, editable=True, canned_query=None): - params = request.raw_args - if 'sql' in params: - params.pop('sql') - if '_shape' in params: - params.pop('_shape') - # Extract any :named parameters - named_parameters = self.re_named_parameter.findall(sql) - named_parameter_values = { - named_parameter: params.get(named_parameter) or '' - for named_parameter in named_parameters - } - - # Set to blank string if missing from params - for named_parameter in named_parameters: - if named_parameter not in params: - 
params[named_parameter] = '' - - extra_args = {} - if params.get('_timelimit'): - extra_args['custom_time_limit'] = int(params['_timelimit']) - rows, truncated, description = await self.execute( - name, sql, params, truncate=True, **extra_args - ) - columns = [r[0] for r in description] - - templates = ['query-{}.html'.format(to_css_class(name)), 'query.html'] - if canned_query: - templates.insert(0, 'query-{}-{}.html'.format( - to_css_class(name), to_css_class(canned_query) - )) - - return { - 'database': name, - 'rows': rows, - 'truncated': truncated, - 'columns': columns, - 'query': { - 'sql': sql, - 'params': params, - } - }, { - 'database_hash': hash, - 'custom_sql': True, - 'named_parameter_values': named_parameter_values, - 'editable': editable, - 'canned_query': canned_query, - }, templates - - -class IndexView(RenderMixin): - def __init__(self, datasette): - self.ds = datasette - self.files = datasette.files - self.jinja_env = datasette.jinja_env - self.executor = datasette.executor - - async def get(self, request, as_json): - databases = [] - for key, info in sorted(self.ds.inspect().items()): - tables = [t for t in info['tables'].values() if not t['hidden']] - hidden_tables = [t for t in info['tables'].values() if t['hidden']] - database = { - 'name': key, - 'hash': info['hash'], - 'path': '{}-{}'.format(key, info['hash'][:HASH_LENGTH]), - 'tables_truncated': sorted( - tables, - key=lambda t: t['count'], - reverse=True - )[:5], - 'tables_count': len(tables), - 'tables_more': len(tables) > 5, - 'table_rows_sum': sum(t['count'] for t in tables), - 'hidden_table_rows_sum': sum(t['count'] for t in hidden_tables), - 'hidden_tables_count': len(hidden_tables), - 'views_count': len(info['views']), - } - databases.append(database) - if as_json: - headers = {} - if self.ds.cors: - headers['Access-Control-Allow-Origin'] = '*' - return response.HTTPResponse( - json.dumps( - {db['name']: db for db in databases}, - cls=CustomJSONEncoder - ), - content_type='application/json', - headers=headers, - ) - else: - return self.render( - ['index.html'], - databases=databases, - metadata=self.ds.metadata, - datasette_version=__version__, - extra_css_urls=self.ds.extra_css_urls(), - extra_js_urls=self.ds.extra_js_urls(), - ) +pm.load_setuptools_entrypoints("datasette") class JsonDataView(RenderMixin): + def __init__(self, datasette, filename, data_callback): self.ds = datasette self.jinja_env = datasette.jinja_env @@ -455,751 +58,40 @@ class JsonDataView(RenderMixin): if as_json: headers = {} if self.ds.cors: - headers['Access-Control-Allow-Origin'] = '*' + headers["Access-Control-Allow-Origin"] = "*" return response.HTTPResponse( - json.dumps(data), - content_type='application/json', - headers=headers, + json.dumps(data), content_type="application/json", headers=headers ) + else: - return self.render( - ['show_json.html'], - filename=self.filename, - data=data, - ) + return self.render(["show_json.html"], filename=self.filename, data=data) async def favicon(request): - return response.text('') - - -class DatabaseView(BaseView): - async def data(self, request, name, hash): - if request.args.get('sql'): - sql = request.raw_args.pop('sql') - validate_sql_select(sql) - return await self.custom_sql(request, name, hash, sql) - info = self.ds.inspect()[name] - metadata = self.ds.metadata.get('databases', {}).get(name, {}) - self.ds.update_with_inherited_metadata(metadata) - tables = list(info['tables'].values()) - tables.sort(key=lambda t: (t['hidden'], t['name'])) - return { - 'database': name, - 'tables': 
tables, - 'hidden_count': len([t for t in tables if t['hidden']]), - 'views': info['views'], - 'queries': [{ - 'name': query_name, - 'sql': query_sql, - } for query_name, query_sql in (metadata.get('queries') or {}).items()], - }, { - 'database_hash': hash, - 'show_hidden': request.args.get('_show_hidden'), - 'editable': True, - 'metadata': metadata, - }, ('database-{}.html'.format(to_css_class(name)), 'database.html') - - -class DatabaseDownload(BaseView): - async def view_get(self, request, name, hash, **kwargs): - filepath = self.ds.inspect()[name]['file'] - return await response.file_stream( - filepath, - filename=os.path.basename(filepath), - mime_type='application/octet-stream', - ) - - -class RowTableShared(BaseView): - def sortable_columns_for_table(self, name, table, use_rowid): - table_metadata = self.table_metadata(name, table) - if 'sortable_columns' in table_metadata: - sortable_columns = set(table_metadata['sortable_columns']) - else: - table_info = self.ds.inspect()[name]['tables'].get(table) or {} - sortable_columns = set(table_info.get('columns', [])) - if use_rowid: - sortable_columns.add('rowid') - return sortable_columns - - async def display_columns_and_rows(self, database, table, description, rows, link_column=False, expand_foreign_keys=True): - "Returns columns, rows for specified table - including fancy foreign key treatment" - table_metadata = self.table_metadata(database, table) - info = self.ds.inspect()[database] - sortable_columns = self.sortable_columns_for_table(database, table, True) - columns = [{ - 'name': r[0], - 'sortable': r[0] in sortable_columns, - } for r in description] - tables = info['tables'] - table_info = tables.get(table) or {} - pks = table_info.get('primary_keys') or [] - - # Prefetch foreign key resolutions for later expansion: - fks = {} - labeled_fks = {} - if table_info and expand_foreign_keys: - foreign_keys = table_info['foreign_keys']['outgoing'] - for fk in foreign_keys: - label_column = ( - # First look in metadata.json definition for this foreign key table: - self.table_metadata(database, fk['other_table']).get('label_column') - # Fall back to label_column from .inspect() detection: - or tables.get(fk['other_table'], {}).get('label_column') - ) - if not label_column: - # No label for this FK - fks[fk['column']] = fk['other_table'] - continue - ids_to_lookup = set([row[fk['column']] for row in rows]) - sql = 'select "{other_column}", "{label_column}" from {other_table} where "{other_column}" in ({placeholders})'.format( - other_column=fk['other_column'], - label_column=label_column, - other_table=escape_sqlite(fk['other_table']), - placeholders=', '.join(['?'] * len(ids_to_lookup)), - ) - try: - results = await self.execute(database, sql, list(set(ids_to_lookup))) - except sqlite3.OperationalError: - # Probably hit the timelimit - pass - else: - for id, value in results: - labeled_fks[(fk['column'], id)] = (fk['other_table'], value) - - cell_rows = [] - for row in rows: - cells = [] - # Unless we are a view, the first column is a link - either to the rowid - # or to the simple or compound primary key - if link_column: - cells.append({ - 'column': pks[0] if len(pks) == 1 else 'Link', - 'value': jinja2.Markup( - '{flat_pks}'.format( - database=database, - table=urllib.parse.quote_plus(table), - flat_pks=str(jinja2.escape(path_from_row_pks(row, pks, not pks, False))), - flat_pks_quoted=path_from_row_pks(row, pks, not pks) - ) - ), - }) - - for value, column_dict in zip(row, columns): - column = column_dict['name'] - if link_column and 
len(pks) == 1 and column == pks[0]: - # If there's a simple primary key, don't repeat the value as it's - # already shown in the link column. - continue - if (column, value) in labeled_fks: - other_table, label = labeled_fks[(column, value)] - display_value = jinja2.Markup( - '{label} {id}'.format( - database=database, - table=urllib.parse.quote_plus(other_table), - link_id=urllib.parse.quote_plus(str(value)), - id=str(jinja2.escape(value)), - label=str(jinja2.escape(label)), - ) - ) - elif column in fks: - display_value = jinja2.Markup( - '{id}'.format( - database=database, - table=urllib.parse.quote_plus(fks[column]), - link_id=urllib.parse.quote_plus(str(value)), - id=str(jinja2.escape(value)))) - elif value is None: - display_value = jinja2.Markup(' ') - elif is_url(str(value).strip()): - display_value = jinja2.Markup( - '{url}'.format( - url=jinja2.escape(value.strip()) - ) - ) - elif column in table_metadata.get('units', {}) and value != '': - # Interpret units using pint - value = value * ureg(table_metadata['units'][column]) - # Pint uses floating point which sometimes introduces errors in the compact - # representation, which we have to round off to avoid ugliness. In the vast - # majority of cases this rounding will be inconsequential. I hope. - value = round(value.to_compact(), 6) - display_value = jinja2.Markup('{:~P}'.format(value).replace(' ', ' ')) - else: - display_value = str(value) - - cells.append({ - 'column': column, - 'value': display_value, - }) - cell_rows.append(cells) - - if link_column: - # Add the link column header. - # If it's a simple primary key, we have to remove and re-add that column name at - # the beginning of the header row. - if len(pks) == 1: - columns = [col for col in columns if col['name'] != pks[0]] - - columns = [{ - 'name': pks[0] if len(pks) == 1 else 'Link', - 'sortable': len(pks) == 1, - }] + columns - return columns, cell_rows - - -class TableView(RowTableShared): - async def data(self, request, name, hash, table): - table = urllib.parse.unquote_plus(table) - canned_query = self.ds.get_canned_query(name, table) - if canned_query is not None: - return await self.custom_sql(request, name, hash, canned_query['sql'], editable=False, canned_query=table) - is_view = bool(list(await self.execute( - name, - "SELECT count(*) from sqlite_master WHERE type = 'view' and name=:n", - {'n': table} - ))[0][0]) - view_definition = None - table_definition = None - if is_view: - view_definition = list(await self.execute( - name, - 'select sql from sqlite_master where name = :n and type="view"', - {'n': table} - ))[0][0] - else: - table_definition_rows = list(await self.execute( - name, - 'select sql from sqlite_master where name = :n and type="table"', - {'n': table} - )) - if not table_definition_rows: - raise NotFound('Table not found: {}'.format(table)) - table_definition = table_definition_rows[0][0] - info = self.ds.inspect() - table_info = info[name]['tables'].get(table) or {} - pks = table_info.get('primary_keys') or [] - use_rowid = not pks and not is_view - if use_rowid: - select = 'rowid, *' - order_by = 'rowid' - order_by_pks = 'rowid' - else: - select = '*' - order_by_pks = ', '.join([escape_sqlite(pk) for pk in pks]) - order_by = order_by_pks - - if is_view: - order_by = '' - - # We roll our own query_string decoder because by default Sanic - # drops anything with an empty value e.g. 
?name__exact= - args = RequestParameters( - urllib.parse.parse_qs(request.query_string, keep_blank_values=True) - ) - - # Special args start with _ and do not contain a __ - # That's so if there is a column that starts with _ - # it can still be queried using ?_col__exact=blah - special_args = {} - special_args_lists = {} - other_args = {} - for key, value in args.items(): - if key.startswith('_') and '__' not in key: - special_args[key] = value[0] - special_args_lists[key] = value - else: - other_args[key] = value[0] - - # Handle ?_filter_column and redirect, if present - redirect_params = filters_should_redirect(special_args) - if redirect_params: - return self.redirect( - request, - path_with_added_args(request, redirect_params), - forward_querystring=False - ) - - # Spot ?_sort_by_desc and redirect to _sort_desc=(_sort) - if '_sort_by_desc' in special_args: - return self.redirect( - request, - path_with_added_args(request, { - '_sort_desc': special_args.get('_sort'), - '_sort_by_desc': None, - '_sort': None, - }), - forward_querystring=False - ) - - table_metadata = self.table_metadata(name, table) - units = table_metadata.get('units', {}) - filters = Filters(sorted(other_args.items()), units, ureg) - where_clauses, params = filters.build_where_clauses() - - # _search support: - fts_table = info[name]['tables'].get(table, {}).get('fts_table') - search_args = dict( - pair for pair in special_args.items() - if pair[0].startswith('_search') - ) - search_descriptions = [] - search = '' - if fts_table and search_args: - if '_search' in search_args: - # Simple ?_search=xxx - search = search_args['_search'] - where_clauses.append( - 'rowid in (select rowid from [{fts_table}] where [{fts_table}] match :search)'.format( - fts_table=fts_table - ) - ) - search_descriptions.append('search matches "{}"'.format(search)) - params['search'] = search - else: - # More complex: search against specific columns - valid_columns = set(info[name]['tables'][fts_table]['columns']) - for i, (key, search_text) in enumerate(search_args.items()): - search_col = key.split('_search_', 1)[1] - if search_col not in valid_columns: - raise DatasetteError( - 'Cannot search by that column', - status=400 - ) - where_clauses.append( - 'rowid in (select rowid from [{fts_table}] where [{search_col}] match :search_{i})'.format( - fts_table=fts_table, - search_col=search_col, - i=i, - ) - ) - search_descriptions.append( - 'search column "{}" matches "{}"'.format(search_col, search_text) - ) - params['search_{}'.format(i)] = search_text - - table_rows_count = None - sortable_columns = set() - if not is_view: - table_rows_count = table_info['count'] - sortable_columns = self.sortable_columns_for_table(name, table, use_rowid) - - # Allow for custom sort order - sort = special_args.get('_sort') - if sort: - if sort not in sortable_columns: - raise DatasetteError('Cannot sort table by {}'.format(sort)) - order_by = escape_sqlite(sort) - sort_desc = special_args.get('_sort_desc') - if sort_desc: - if sort_desc not in sortable_columns: - raise DatasetteError('Cannot sort table by {}'.format(sort_desc)) - if sort: - raise DatasetteError('Cannot use _sort and _sort_desc at the same time') - order_by = '{} desc'.format(escape_sqlite(sort_desc)) - - from_sql = 'from {table_name} {where}'.format( - table_name=escape_sqlite(table), - where=( - 'where {} '.format(' and '.join(where_clauses)) - ) if where_clauses else '', - ) - count_sql = 'select count(*) {}'.format(from_sql) - - _next = special_args.get('_next') - offset = '' - if _next: - 
if is_view: - # _next is an offset - offset = ' offset {}'.format(int(_next)) - else: - components = urlsafe_components(_next) - # If a sort order is applied, the first of these is the sort value - if sort or sort_desc: - sort_value = components[0] - # Special case for if non-urlencoded first token was $null - if _next.split(',')[0] == '$null': - sort_value = None - components = components[1:] - - # Figure out the SQL for next-based-on-primary-key first - next_by_pk_clauses = [] - if use_rowid: - next_by_pk_clauses.append( - 'rowid > :p{}'.format( - len(params), - ) - ) - params['p{}'.format(len(params))] = components[0] - else: - # Apply the tie-breaker based on primary keys - if len(components) == len(pks): - param_len = len(params) - next_by_pk_clauses.append(compound_keys_after_sql(pks, param_len)) - for i, pk_value in enumerate(components): - params['p{}'.format(param_len + i)] = pk_value - - # Now add the sort SQL, which may incorporate next_by_pk_clauses - if sort or sort_desc: - if sort_value is None: - if sort_desc: - # Just items where column is null ordered by pk - where_clauses.append( - '({column} is null and {next_clauses})'.format( - column=escape_sqlite(sort_desc), - next_clauses=' and '.join(next_by_pk_clauses), - ) - ) - else: - where_clauses.append( - '({column} is not null or ({column} is null and {next_clauses}))'.format( - column=escape_sqlite(sort), - next_clauses=' and '.join(next_by_pk_clauses), - ) - ) - else: - where_clauses.append( - '({column} {op} :p{p}{extra_desc_only} or ({column} = :p{p} and {next_clauses}))'.format( - column=escape_sqlite(sort or sort_desc), - op='>' if sort else '<', - p=len(params), - extra_desc_only='' if sort else ' or {column2} is null'.format( - column2=escape_sqlite(sort or sort_desc), - ), - next_clauses=' and '.join(next_by_pk_clauses), - ) - ) - params['p{}'.format(len(params))] = sort_value - order_by = '{}, {}'.format( - order_by, order_by_pks - ) - else: - where_clauses.extend(next_by_pk_clauses) - - where_clause = '' - if where_clauses: - where_clause = 'where {} '.format(' and '.join(where_clauses)) - - if order_by: - order_by = 'order by {} '.format(order_by) - - # _group_count=col1&_group_count=col2 - group_count = special_args_lists.get('_group_count') or [] - if group_count: - sql = 'select {group_cols}, count(*) as "count" from {table_name} {where} group by {group_cols} order by "count" desc limit 100'.format( - group_cols=', '.join('"{}"'.format(group_count_col) for group_count_col in group_count), - table_name=escape_sqlite(table), - where=where_clause, - ) - return await self.custom_sql(request, name, hash, sql, editable=True) - - extra_args = {} - # Handle ?_page_size=500 - page_size = request.raw_args.get('_size') - if page_size: - if page_size == 'max': - page_size = self.max_returned_rows - try: - page_size = int(page_size) - if page_size < 0: - raise ValueError - except ValueError: - raise DatasetteError( - '_size must be a positive integer', - status=400 - ) - if page_size > self.max_returned_rows: - raise DatasetteError( - '_size must be <= {}'.format(self.max_returned_rows), - status=400 - ) - extra_args['page_size'] = page_size - else: - page_size = self.page_size - - sql = 'select {select} from {table_name} {where}{order_by}limit {limit}{offset}'.format( - select=select, - table_name=escape_sqlite(table), - where=where_clause, - order_by=order_by, - limit=page_size + 1, - offset=offset, - ) - - if request.raw_args.get('_timelimit'): - extra_args['custom_time_limit'] = int(request.raw_args['_timelimit']) - - 
rows, truncated, description = await self.execute( - name, sql, params, truncate=True, **extra_args - ) - - # facets support - try: - facets = request.args['_facet'] - except KeyError: - facets = table_metadata.get('facets', []) - facet_results = {} - for column in facets: - facet_sql = ''' - select {col} as value, count(*) as count - {from_sql} - group by {col} order by count desc limit 20 - '''.format(col=escape_sqlite(column), from_sql=from_sql) - try: - facet_rows = await self.execute( - name, - facet_sql, - params, - truncate=False, - custom_time_limit=200 - ) - facet_results[column] = [{ - 'value': row['value'], - 'count': row['count'], - 'toggle_url': urllib.parse.urljoin( - request.url, path_with_added_args( - request, {column: row['value']} - ) - ) - } for row in facet_rows] - except sqlite3.OperationalError: - # Hit time limit - pass - - columns = [r[0] for r in description] - rows = list(rows) - - filter_columns = columns[:] - if use_rowid and filter_columns[0] == 'rowid': - filter_columns = filter_columns[1:] - - # Pagination next link - next_value = None - next_url = None - if len(rows) > page_size and page_size > 0: - if is_view: - next_value = int(_next or 0) + page_size - else: - next_value = path_from_row_pks(rows[-2], pks, use_rowid) - # If there's a sort or sort_desc, add that value as a prefix - if (sort or sort_desc) and not is_view: - prefix = rows[-2][sort or sort_desc] - if prefix is None: - prefix = '$null' - else: - prefix = urllib.parse.quote_plus(str(prefix)) - next_value = '{},{}'.format(prefix, next_value) - added_args = { - '_next': next_value, - } - if sort: - added_args['_sort'] = sort - else: - added_args['_sort_desc'] = sort_desc - else: - added_args = { - '_next': next_value, - } - next_url = urllib.parse.urljoin(request.url, path_with_added_args( - request, added_args - )) - rows = rows[:page_size] - - # Number of filtered rows in whole set: - filtered_table_rows_count = None - if count_sql: - try: - count_rows = list(await self.execute(name, count_sql, params)) - filtered_table_rows_count = count_rows[0][0] - except sqlite3.OperationalError: - # Almost certainly hit the timeout - pass - - # human_description_en combines filters AND search, if provided - human_description_en = filters.human_description_en(extra=search_descriptions) - - if sort or sort_desc: - sorted_by = 'sorted by {}{}'.format( - (sort or sort_desc), - ' descending' if sort_desc else '', - ) - human_description_en = ' '.join([ - b for b in [human_description_en, sorted_by] if b - ]) - - async def extra_template(): - display_columns, display_rows = await self.display_columns_and_rows( - name, table, description, rows, link_column=not is_view, expand_foreign_keys=True - ) - metadata = self.ds.metadata.get( - 'databases', {} - ).get(name, {}).get('tables', {}).get(table, {}) - self.ds.update_with_inherited_metadata(metadata) - return { - 'database_hash': hash, - 'supports_search': bool(fts_table), - 'search': search or '', - 'use_rowid': use_rowid, - 'filters': filters, - 'display_columns': display_columns, - 'filter_columns': filter_columns, - 'display_rows': display_rows, - 'is_sortable': any(c['sortable'] for c in display_columns), - 'path_with_added_args': path_with_added_args, - 'request': request, - 'sort': sort, - 'sort_desc': sort_desc, - 'disable_sort': is_view, - 'custom_rows_and_columns_templates': [ - '_rows_and_columns-{}-{}.html'.format(to_css_class(name), to_css_class(table)), - '_rows_and_columns-table-{}-{}.html'.format(to_css_class(name), to_css_class(table)), - 
'_rows_and_columns.html', - ], - 'metadata': metadata, - } - - return { - 'database': name, - 'table': table, - 'is_view': is_view, - 'view_definition': view_definition, - 'table_definition': table_definition, - 'human_description_en': human_description_en, - 'rows': rows[:page_size], - 'truncated': truncated, - 'table_rows_count': table_rows_count, - 'filtered_table_rows_count': filtered_table_rows_count, - 'columns': columns, - 'primary_keys': pks, - 'units': units, - 'query': { - 'sql': sql, - 'params': params, - }, - 'facet_results': facet_results, - 'next': next_value and str(next_value) or None, - 'next_url': next_url, - }, extra_template, ( - 'table-{}-{}.html'.format(to_css_class(name), to_css_class(table)), - 'table.html' - ) - - -class RowView(RowTableShared): - async def data(self, request, name, hash, table, pk_path): - table = urllib.parse.unquote_plus(table) - pk_values = urlsafe_components(pk_path) - info = self.ds.inspect()[name] - table_info = info['tables'].get(table) or {} - pks = table_info.get('primary_keys') or [] - use_rowid = not pks - select = '*' - if use_rowid: - select = 'rowid, *' - pks = ['rowid'] - wheres = [ - '"{}"=:p{}'.format(pk, i) - for i, pk in enumerate(pks) - ] - sql = 'select {} from "{}" where {}'.format( - select, table, ' AND '.join(wheres) - ) - params = {} - for i, pk_value in enumerate(pk_values): - params['p{}'.format(i)] = pk_value - # rows, truncated, description = await self.execute(name, sql, params, truncate=True) - rows, truncated, description = await self.execute(name, sql, params, truncate=True) - columns = [r[0] for r in description] - rows = list(rows) - if not rows: - raise NotFound('Record not found: {}'.format(pk_values)) - - async def template_data(): - display_columns, display_rows = await self.display_columns_and_rows( - name, table, description, rows, link_column=False, expand_foreign_keys=True - ) - for column in display_columns: - column['sortable'] = False - return { - 'database_hash': hash, - 'foreign_key_tables': await self.foreign_key_tables(name, table, pk_values), - 'display_columns': display_columns, - 'display_rows': display_rows, - 'custom_rows_and_columns_templates': [ - '_rows_and_columns-{}-{}.html'.format(to_css_class(name), to_css_class(table)), - '_rows_and_columns-row-{}-{}.html'.format(to_css_class(name), to_css_class(table)), - '_rows_and_columns.html', - ], - 'metadata': self.ds.metadata.get( - 'databases', {} - ).get(name, {}).get('tables', {}).get(table, {}), - } - - data = { - 'database': name, - 'table': table, - 'rows': rows, - 'columns': columns, - 'primary_keys': pks, - 'primary_key_values': pk_values, - 'units': self.table_metadata(name, table).get('units', {}) - } - - if 'foreign_key_tables' in (request.raw_args.get('_extras') or '').split(','): - data['foreign_key_tables'] = await self.foreign_key_tables(name, table, pk_values) - - return data, template_data, ( - 'row-{}-{}.html'.format(to_css_class(name), to_css_class(table)), - 'row.html' - ) - - async def foreign_key_tables(self, name, table, pk_values): - if len(pk_values) != 1: - return [] - table_info = self.ds.inspect()[name]['tables'].get(table) - if not table_info: - return [] - foreign_keys = table_info['foreign_keys']['incoming'] - if len(foreign_keys) == 0: - return [] - - sql = 'select ' + ', '.join([ - '(select count(*) from {table} where "{column}"=:id)'.format( - table=escape_sqlite(fk['other_table']), - column=fk['other_column'], - ) - for fk in foreign_keys - ]) - try: - rows = list(await self.execute(name, sql, {'id': 
pk_values[0]})) - except sqlite3.OperationalError: - # Almost certainly hit the timeout - return [] - foreign_table_counts = dict( - zip( - [(fk['other_table'], fk['other_column']) for fk in foreign_keys], - list(rows[0]), - ) - ) - foreign_key_tables = [] - for fk in foreign_keys: - count = foreign_table_counts.get((fk['other_table'], fk['other_column'])) or 0 - foreign_key_tables.append({**fk, **{'count': count}}) - return foreign_key_tables + return response.text("") class Datasette: + def __init__( - self, files, num_threads=3, cache_headers=True, page_size=100, - max_returned_rows=1000, sql_time_limit_ms=1000, cors=False, - inspect_data=None, metadata=None, sqlite_extensions=None, - template_dir=None, plugins_dir=None, static_mounts=None): + self, + files, + num_threads=3, + cache_headers=True, + page_size=100, + max_returned_rows=1000, + sql_time_limit_ms=1000, + cors=False, + inspect_data=None, + metadata=None, + sqlite_extensions=None, + template_dir=None, + plugins_dir=None, + static_mounts=None, + ): self.files = files self.num_threads = num_threads - self.executor = futures.ThreadPoolExecutor( - max_workers=num_threads - ) + self.executor = futures.ThreadPoolExecutor(max_workers=num_threads) self.cache_headers = cache_headers self.page_size = page_size self.max_returned_rows = max_returned_rows @@ -1217,70 +109,68 @@ class Datasette: if self.plugins_dir: for filename in os.listdir(self.plugins_dir): filepath = os.path.join(self.plugins_dir, filename) - with open(filepath) as f: - mod = module_from_path(filepath, name=filename) - try: - pm.register(mod) - except ValueError: - # Plugin already registered - pass + mod = module_from_path(filepath, name=filename) + try: + pm.register(mod) + except ValueError: + # Plugin already registered + pass def app_css_hash(self): - if not hasattr(self, '_app_css_hash'): + if not hasattr(self, "_app_css_hash"): self._app_css_hash = hashlib.sha1( - open(os.path.join(str(app_root), 'datasette/static/app.css')).read().encode('utf8') - ).hexdigest()[:6] + open( + os.path.join(str(app_root), "datasette/static/app.css") + ).read().encode( + "utf8" + ) + ).hexdigest()[ + :6 + ] return self._app_css_hash def get_canned_query(self, database_name, query_name): - query = self.metadata.get( - 'databases', {} + query = self.metadata.get("databases", {}).get(database_name, {}).get( + "queries", {} ).get( - database_name, {} - ).get( - 'queries', {} - ).get(query_name) + query_name + ) if query: - return { - 'name': query_name, - 'sql': query, - } + return {"name": query_name, "sql": query} def asset_urls(self, key): urls_or_dicts = (self.metadata.get(key) or []) # Flatten list-of-lists from plugins: - urls_or_dicts += list( - itertools.chain.from_iterable(getattr(pm.hook, key)()) - ) + urls_or_dicts += list(itertools.chain.from_iterable(getattr(pm.hook, key)())) for url_or_dict in urls_or_dicts: if isinstance(url_or_dict, dict): - yield { - 'url': url_or_dict['url'], - 'sri': url_or_dict.get('sri'), - } + yield {"url": url_or_dict["url"], "sri": url_or_dict.get("sri")} + else: - yield { - 'url': url_or_dict, - } + yield {"url": url_or_dict} def extra_css_urls(self): - return self.asset_urls('extra_css_urls') + return self.asset_urls("extra_css_urls") def extra_js_urls(self): - return self.asset_urls('extra_js_urls') + return self.asset_urls("extra_js_urls") def update_with_inherited_metadata(self, metadata): # Fills in source/license with defaults, if available - metadata.update({ - 'source': metadata.get('source') or self.metadata.get('source'), - 
'source_url': metadata.get('source_url') or self.metadata.get('source_url'), - 'license': metadata.get('license') or self.metadata.get('license'), - 'license_url': metadata.get('license_url') or self.metadata.get('license_url'), - }) + metadata.update( + { + "source": metadata.get("source") or self.metadata.get("source"), + "source_url": metadata.get("source_url") + or self.metadata.get("source_url"), + "license": metadata.get("license") or self.metadata.get("license"), + "license_url": metadata.get("license_url") + or self.metadata.get("license_url"), + } + ) def prepare_connection(self, conn): conn.row_factory = sqlite3.Row - conn.text_factory = lambda x: str(x, 'utf-8', 'replace') + conn.text_factory = lambda x: str(x, "utf-8", "replace") for name, num_args, func in self.sqlite_functions: conn.create_function(name, num_args, func) if self.sqlite_extensions: @@ -1296,31 +186,44 @@ class Datasette: path = Path(filename) name = path.stem if name in self._inspect: - raise Exception('Multiple files with same stem %s' % name) + raise Exception("Multiple files with same stem %s" % name) + # Calculate hash, efficiently m = hashlib.sha256() - with path.open('rb') as fp: + with path.open("rb") as fp: while True: data = fp.read(HASH_BLOCK_SIZE) if not data: break + m.update(data) # List tables and their row counts - database_metadata = self.metadata.get('databases', {}).get(name, {}) + database_metadata = self.metadata.get("databases", {}).get(name, {}) tables = {} views = [] - with sqlite3.connect('file:{}?immutable=1'.format(path), uri=True) as conn: + with sqlite3.connect( + "file:{}?immutable=1".format(path), uri=True + ) as conn: self.prepare_connection(conn) table_names = [ - r['name'] - for r in conn.execute('select * from sqlite_master where type="table"') + r["name"] + for r in conn.execute( + 'select * from sqlite_master where type="table"' + ) + ] + views = [ + v[0] + for v in conn.execute( + 'select name from sqlite_master where type = "view"' + ) ] - views = [v[0] for v in conn.execute('select name from sqlite_master where type = "view"')] for table in table_names: try: count = conn.execute( - 'select count(*) from {}'.format(escape_sqlite(table)) - ).fetchone()[0] + "select count(*) from {}".format(escape_sqlite(table)) + ).fetchone()[ + 0 + ] except sqlite3.OperationalError: # This can happen when running against a FTS virtual tables # e.g. 
"select count(*) from some_fts;" @@ -1330,7 +233,8 @@ class Datasette: # Figure out primary keys table_info_rows = [ - row for row in conn.execute( + row + for row in conn.execute( 'PRAGMA table_info("{}")'.format(table) ).fetchall() if row[-1] @@ -1339,85 +243,97 @@ class Datasette: primary_keys = [str(r[1]) for r in table_info_rows] label_column = None # If table has two columns, one of which is ID, then label_column is the other one - column_names = [r[1] for r in conn.execute( - 'PRAGMA table_info({});'.format(escape_sqlite(table)) - ).fetchall()] - if column_names and len(column_names) == 2 and 'id' in column_names: - label_column = [c for c in column_names if c != 'id'][0] - table_metadata = database_metadata.get('tables', {}).get(table, {}) + column_names = [ + r[1] + for r in conn.execute( + "PRAGMA table_info({});".format(escape_sqlite(table)) + ).fetchall() + ] + if ( + column_names + and len(column_names) == 2 + and "id" in column_names + ): + label_column = [c for c in column_names if c != "id"][0] + table_metadata = database_metadata.get("tables", {}).get( + table, {} + ) tables[table] = { - 'name': table, - 'columns': column_names, - 'primary_keys': primary_keys, - 'count': count, - 'label_column': label_column, - 'hidden': table_metadata.get('hidden') or False, - 'fts_table': fts_table, + "name": table, + "columns": column_names, + "primary_keys": primary_keys, + "count": count, + "label_column": label_column, + "hidden": table_metadata.get("hidden") or False, + "fts_table": fts_table, } foreign_keys = get_all_foreign_keys(conn) for table, info in foreign_keys.items(): - tables[table]['foreign_keys'] = info + tables[table]["foreign_keys"] = info # Mark tables 'hidden' if they relate to FTS virtual tables hidden_tables = [ - r['name'] + r["name"] for r in conn.execute( - ''' + """ select name from sqlite_master where rootpage = 0 and sql like '%VIRTUAL TABLE%USING FTS%' - ''' + """ ) ] if detect_spatialite(conn): # Also hide Spatialite internal tables hidden_tables += [ - 'ElementaryGeometries', 'SpatialIndex', 'geometry_columns', - 'spatial_ref_sys', 'spatialite_history', 'sql_statements_log', - 'sqlite_sequence', 'views_geometry_columns', 'virts_geometry_columns' + "ElementaryGeometries", + "SpatialIndex", + "geometry_columns", + "spatial_ref_sys", + "spatialite_history", + "sql_statements_log", + "sqlite_sequence", + "views_geometry_columns", + "virts_geometry_columns", ] + [ - r['name'] + r["name"] for r in conn.execute( - ''' + """ select name from sqlite_master where name like "idx_%" and type = "table" - ''' + """ ) ] for t in tables.keys(): for hidden_table in hidden_tables: if t == hidden_table or t.startswith(hidden_table): - tables[t]['hidden'] = True + tables[t]["hidden"] = True continue self._inspect[name] = { - 'hash': m.hexdigest(), - 'file': str(path), - 'tables': tables, - 'views': views, - + "hash": m.hexdigest(), + "file": str(path), + "tables": tables, + "views": views, } return self._inspect def register_custom_units(self): "Register any custom units defined in the metadata.json with Pint" - for unit in self.metadata.get('custom_units', []): + for unit in self.metadata.get("custom_units", []): ureg.define(unit) def versions(self): - conn = sqlite3.connect(':memory:') + conn = sqlite3.connect(":memory:") self.prepare_connection(conn) - sqlite_version = conn.execute( - 'select sqlite_version()' - ).fetchone()[0] + sqlite_version = conn.execute("select sqlite_version()").fetchone()[0] sqlite_extensions = {} for extension, testsql, hasversion in ( - ('json1', 
"SELECT json('{}')", False), - ('spatialite', "SELECT spatialite_version()", True), + ("json1", "SELECT json('{}')", False), + ("spatialite", "SELECT spatialite_version()", True), ): try: result = conn.execute(testsql) @@ -1429,106 +345,107 @@ class Datasette: pass # Figure out supported FTS versions fts_versions = [] - for fts in ('FTS5', 'FTS4', 'FTS3'): + for fts in ("FTS5", "FTS4", "FTS3"): try: conn.execute( - 'CREATE VIRTUAL TABLE v{fts} USING {fts} (t TEXT)'.format( - fts=fts - ) + "CREATE VIRTUAL TABLE v{fts} USING {fts} (t TEXT)".format(fts=fts) ) fts_versions.append(fts) except sqlite3.OperationalError: continue + return { - 'python': { - 'version': '.'.join(map(str, sys.version_info[:3])), - 'full': sys.version, + "python": { + "version": ".".join(map(str, sys.version_info[:3])), "full": sys.version }, - 'datasette': { - 'version': __version__, + "datasette": {"version": __version__}, + "sqlite": { + "version": sqlite_version, + "fts_versions": fts_versions, + "extensions": sqlite_extensions, }, - 'sqlite': { - 'version': sqlite_version, - 'fts_versions': fts_versions, - 'extensions': sqlite_extensions, - } } + def plugins(self): + return [ + { + "name": p["name"], + "static": p["static_path"] is not None, + "templates": p["templates_path"] is not None, + "version": p.get("version"), + } + for p in get_plugins(pm) + ] + def app(self): app = Sanic(__name__) - default_templates = str(app_root / 'datasette' / 'templates') + default_templates = str(app_root / "datasette" / "templates") template_paths = [] if self.template_dir: template_paths.append(self.template_dir) - template_paths.extend([ - plugin['templates_path'] - for plugin in get_plugins(pm) - if plugin['templates_path'] - ]) - template_paths.append(default_templates) - template_loader = ChoiceLoader([ - FileSystemLoader(template_paths), - # Support {% extends "default:table.html" %}: - PrefixLoader({ - 'default': FileSystemLoader(default_templates), - }, delimiter=':') - ]) - self.jinja_env = Environment( - loader=template_loader, - autoescape=True, + template_paths.extend( + [ + plugin["templates_path"] + for plugin in get_plugins(pm) + if plugin["templates_path"] + ] ) - self.jinja_env.filters['escape_css_string'] = escape_css_string - self.jinja_env.filters['quote_plus'] = lambda u: urllib.parse.quote_plus(u) - self.jinja_env.filters['escape_sqlite'] = escape_sqlite - self.jinja_env.filters['to_css_class'] = to_css_class + template_paths.append(default_templates) + template_loader = ChoiceLoader( + [ + FileSystemLoader(template_paths), + # Support {% extends "default:table.html" %}: + PrefixLoader( + {"default": FileSystemLoader(default_templates)}, delimiter=":" + ), + ] + ) + self.jinja_env = Environment(loader=template_loader, autoescape=True) + self.jinja_env.filters["escape_css_string"] = escape_css_string + self.jinja_env.filters["quote_plus"] = lambda u: urllib.parse.quote_plus(u) + self.jinja_env.filters["escape_sqlite"] = escape_sqlite + self.jinja_env.filters["to_css_class"] = to_css_class pm.hook.prepare_jinja2_environment(env=self.jinja_env) - app.add_route(IndexView.as_view(self), '/') + app.add_route(IndexView.as_view(self), "/") # TODO: /favicon.ico and /-/static/ deserve far-future cache expires - app.add_route(favicon, '/favicon.ico') - app.static('/-/static/', str(app_root / 'datasette' / 'static')) + app.add_route(favicon, "/favicon.ico") + app.static("/-/static/", str(app_root / "datasette" / "static")) for path, dirname in self.static_mounts: app.static(path, dirname) # Mount any plugin static/ 
directories for plugin in get_plugins(pm): - if plugin['static_path']: - modpath = '/-/static-plugins/{}/'.format(plugin['name']) - app.static(modpath, plugin['static_path']) + if plugin["static_path"]: + modpath = "/-/static-plugins/{}/".format(plugin["name"]) + app.static(modpath, plugin["static_path"]) app.add_route( - JsonDataView.as_view(self, 'inspect.json', self.inspect), - '/-/inspect' + JsonDataView.as_view(self, "inspect.json", self.inspect), + "/-/inspect", ) app.add_route( - JsonDataView.as_view(self, 'metadata.json', lambda: self.metadata), - '/-/metadata' + JsonDataView.as_view(self, "metadata.json", lambda: self.metadata), + "/-/metadata", ) app.add_route( - JsonDataView.as_view(self, 'versions.json', self.versions), - '/-/versions' + JsonDataView.as_view(self, "versions.json", self.versions), + "/-/versions", ) app.add_route( - JsonDataView.as_view(self, 'plugins.json', lambda: [{ - 'name': p['name'], - 'static': p['static_path'] is not None, - 'templates': p['templates_path'] is not None, - 'version': p.get('version'), - } for p in get_plugins(pm)]), - '/-/plugins' + JsonDataView.as_view(self, "plugins.json", self.plugins), + "/-/plugins", ) app.add_route( - DatabaseView.as_view(self), - '/' + DatabaseView.as_view(self), "/" ) app.add_route( - DatabaseDownload.as_view(self), - '/' + DatabaseDownload.as_view(self), "/" ) app.add_route( TableView.as_view(self), - '//' + "//", ) app.add_route( RowView.as_view(self), - '///' + "///", ) self.register_custom_units() @@ -1554,17 +471,15 @@ class Datasette: info = {} message = str(exception) traceback.print_exc() - templates = ['500.html'] + templates = ["500.html"] if status != 500: - templates = ['{}.html'.format(status)] + templates - info.update({ - 'ok': False, - 'error': message, - 'status': status, - 'title': title, - }) - if (request.path.split('?')[0].endswith('.json')): + templates = ["{}.html".format(status)] + templates + info.update( + {"ok": False, "error": message, "status": status, "title": title} + ) + if request.path.split("?")[0].endswith(".json"): return response.json(info, status=status) + else: template = self.jinja_env.select_template(templates) return response.html(template.render(info), status=status) diff --git a/datasette/utils.py b/datasette/utils.py index 6138d353..118d9bd7 100644 --- a/datasette/utils.py +++ b/datasette/utils.py @@ -151,10 +151,9 @@ def path_with_added_args(request, args, path=None): args = args.items() arg_keys = set(a[0] for a in args) current = [] - for key, values in request.args.items(): - current.extend( - [(key, value) for value in values if key not in arg_keys] - ) + for key, value in urllib.parse.parse_qsl(request.query_string): + if key not in arg_keys: + current.append((key, value)) current.extend([ (key, value) for key, value in args diff --git a/datasette/views/__init__.py b/datasette/views/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/datasette/views/base.py b/datasette/views/base.py new file mode 100644 index 00000000..64a46808 --- /dev/null +++ b/datasette/views/base.py @@ -0,0 +1,373 @@ +import asyncio +import json +import re +import sqlite3 +import threading +import time + +import pint +from sanic import response +from sanic.exceptions import NotFound +from sanic.views import HTTPMethodView + +from datasette import __version__ +from datasette.utils import ( + CustomJSONEncoder, + InvalidSql, + path_from_row_pks, + path_with_added_args, + path_with_ext, + sqlite_timelimit, + to_css_class +) + +connections = threading.local() +ureg = 
pint.UnitRegistry() + +HASH_BLOCK_SIZE = 1024 * 1024 +HASH_LENGTH = 7 + + +class DatasetteError(Exception): + + def __init__(self, message, title=None, error_dict=None, status=500, template=None): + self.message = message + self.title = title + self.error_dict = error_dict or {} + self.status = status + + +class RenderMixin(HTTPMethodView): + + def render(self, templates, **context): + template = self.jinja_env.select_template(templates) + select_templates = [ + "{}{}".format("*" if template_name == template.name else "", template_name) + for template_name in templates + ] + return response.html( + template.render( + { + **context, + **{ + "app_css_hash": self.ds.app_css_hash(), + "select_templates": select_templates, + "zip": zip, + } + } + ) + ) + + +class BaseView(RenderMixin): + re_named_parameter = re.compile(":([a-zA-Z0-9_]+)") + + def __init__(self, datasette): + self.ds = datasette + self.files = datasette.files + self.jinja_env = datasette.jinja_env + self.executor = datasette.executor + self.page_size = datasette.page_size + self.max_returned_rows = datasette.max_returned_rows + + def table_metadata(self, database, table): + "Fetch table-specific metadata." + return self.ds.metadata.get("databases", {}).get(database, {}).get( + "tables", {} + ).get( + table, {} + ) + + def options(self, request, *args, **kwargs): + r = response.text("ok") + if self.ds.cors: + r.headers["Access-Control-Allow-Origin"] = "*" + return r + + def redirect(self, request, path, forward_querystring=True): + if request.query_string and "?" not in path and forward_querystring: + path = "{}?{}".format(path, request.query_string) + r = response.redirect(path) + r.headers["Link"] = "<{}>; rel=preload".format(path) + if self.ds.cors: + r.headers["Access-Control-Allow-Origin"] = "*" + return r + + def resolve_db_name(self, db_name, **kwargs): + databases = self.ds.inspect() + hash = None + name = None + if "-" in db_name: + # Might be name-and-hash, or might just be + # a name with a hyphen in it + name, hash = db_name.rsplit("-", 1) + if name not in databases: + # Try the whole name + name = db_name + hash = None + else: + name = db_name + # Verify the hash + try: + info = databases[name] + except KeyError: + raise NotFound("Database not found: {}".format(name)) + + expected = info["hash"][:HASH_LENGTH] + if expected != hash: + should_redirect = "/{}-{}".format(name, expected) + if "table" in kwargs: + should_redirect += "/" + kwargs["table"] + if "pk_path" in kwargs: + should_redirect += "/" + kwargs["pk_path"] + if "as_json" in kwargs: + should_redirect += kwargs["as_json"] + if "as_db" in kwargs: + should_redirect += kwargs["as_db"] + return name, expected, should_redirect + + return name, expected, None + + async def execute( + self, + db_name, + sql, + params=None, + truncate=False, + custom_time_limit=None, + page_size=None, + ): + """Executes sql against db_name in a thread""" + page_size = page_size or self.page_size + + def sql_operation_in_thread(): + conn = getattr(connections, db_name, None) + if not conn: + info = self.ds.inspect()[db_name] + conn = sqlite3.connect( + "file:{}?immutable=1".format(info["file"]), + uri=True, + check_same_thread=False, + ) + self.ds.prepare_connection(conn) + setattr(connections, db_name, conn) + + time_limit_ms = self.ds.sql_time_limit_ms + if custom_time_limit and custom_time_limit < self.ds.sql_time_limit_ms: + time_limit_ms = custom_time_limit + + with sqlite_timelimit(conn, time_limit_ms): + try: + cursor = conn.cursor() + cursor.execute(sql, params or {}) + 
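+                    # When truncate is requested, fetch one row beyond
+                    # max_returned_rows: getting that extra row back signals
+                    # the result set was cut off, and the surplus row is
+                    # sliced away before rows are returned to the caller.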
max_returned_rows = self.max_returned_rows + if max_returned_rows == page_size: + max_returned_rows += 1 + if max_returned_rows and truncate: + rows = cursor.fetchmany(max_returned_rows + 1) + truncated = len(rows) > max_returned_rows + rows = rows[:max_returned_rows] + else: + rows = cursor.fetchall() + truncated = False + except Exception as e: + print( + "ERROR: conn={}, sql = {}, params = {}: {}".format( + conn, repr(sql), params, e + ) + ) + raise + + if truncate: + return rows, truncated, cursor.description + + else: + return rows + + return await asyncio.get_event_loop().run_in_executor( + self.executor, sql_operation_in_thread + ) + + def get_templates(self, database, table=None): + assert NotImplemented + + async def get(self, request, db_name, **kwargs): + name, hash, should_redirect = self.resolve_db_name(db_name, **kwargs) + if should_redirect: + return self.redirect(request, should_redirect) + + return await self.view_get(request, name, hash, **kwargs) + + async def view_get(self, request, name, hash, **kwargs): + try: + as_json = kwargs.pop("as_json") + except KeyError: + as_json = False + extra_template_data = {} + start = time.time() + status_code = 200 + templates = [] + try: + response_or_template_contexts = await self.data( + request, name, hash, **kwargs + ) + if isinstance(response_or_template_contexts, response.HTTPResponse): + return response_or_template_contexts + + else: + data, extra_template_data, templates = response_or_template_contexts + except (sqlite3.OperationalError, InvalidSql) as e: + raise DatasetteError(str(e), title="Invalid SQL", status=400) + + except (sqlite3.OperationalError) as e: + raise DatasetteError(str(e)) + + except DatasetteError: + raise + + end = time.time() + data["query_ms"] = (end - start) * 1000 + for key in ("source", "source_url", "license", "license_url"): + value = self.ds.metadata.get(key) + if value: + data[key] = value + if as_json: + # Special case for .jsono extension - redirect to _shape=objects + if as_json == ".jsono": + return self.redirect( + request, + path_with_added_args( + request, + {"_shape": "objects"}, + path=request.path.rsplit(".jsono", 1)[0] + ".json", + ), + forward_querystring=False, + ) + + # Deal with the _shape option + shape = request.args.get("_shape", "arrays") + if shape in ("objects", "object", "array"): + columns = data.get("columns") + rows = data.get("rows") + if rows and columns: + data["rows"] = [dict(zip(columns, row)) for row in rows] + if shape == "object": + error = None + if "primary_keys" not in data: + error = "_shape=object is only available on tables" + else: + pks = data["primary_keys"] + if not pks: + error = "_shape=object not available for tables with no primary keys" + else: + object_rows = {} + for row in data["rows"]: + pk_string = path_from_row_pks(row, pks, not pks) + object_rows[pk_string] = row + data = object_rows + if error: + data = { + "ok": False, + "error": error, + "database": name, + "database_hash": hash, + } + elif shape == "array": + data = data["rows"] + elif shape == "arrays": + pass + else: + status_code = 400 + data = { + "ok": False, + "error": "Invalid _shape: {}".format(shape), + "status": 400, + "title": None, + } + headers = {} + if self.ds.cors: + headers["Access-Control-Allow-Origin"] = "*" + r = response.HTTPResponse( + json.dumps(data, cls=CustomJSONEncoder), + status=status_code, + content_type="application/json", + headers=headers, + ) + else: + extras = {} + if callable(extra_template_data): + extras = extra_template_data() + if 
asyncio.iscoroutine(extras): + extras = await extras + else: + extras = extra_template_data + context = { + **data, + **extras, + **{ + "url_json": path_with_ext(request, ".json"), + "url_jsono": path_with_ext(request, ".jsono"), + "extra_css_urls": self.ds.extra_css_urls(), + "extra_js_urls": self.ds.extra_js_urls(), + "datasette_version": __version__, + } + } + if "metadata" not in context: + context["metadata"] = self.ds.metadata + r = self.render(templates, **context) + r.status = status_code + # Set far-future cache expiry + if self.ds.cache_headers: + r.headers["Cache-Control"] = "max-age={}".format(365 * 24 * 60 * 60) + return r + + async def custom_sql( + self, request, name, hash, sql, editable=True, canned_query=None + ): + params = request.raw_args + if "sql" in params: + params.pop("sql") + if "_shape" in params: + params.pop("_shape") + # Extract any :named parameters + named_parameters = self.re_named_parameter.findall(sql) + named_parameter_values = { + named_parameter: params.get(named_parameter) or "" + for named_parameter in named_parameters + } + + # Set to blank string if missing from params + for named_parameter in named_parameters: + if named_parameter not in params: + params[named_parameter] = "" + + extra_args = {} + if params.get("_timelimit"): + extra_args["custom_time_limit"] = int(params["_timelimit"]) + rows, truncated, description = await self.execute( + name, sql, params, truncate=True, **extra_args + ) + columns = [r[0] for r in description] + + templates = ["query-{}.html".format(to_css_class(name)), "query.html"] + if canned_query: + templates.insert( + 0, + "query-{}-{}.html".format( + to_css_class(name), to_css_class(canned_query) + ), + ) + + return { + "database": name, + "rows": rows, + "truncated": truncated, + "columns": columns, + "query": {"sql": sql, "params": params}, + }, { + "database_hash": hash, + "custom_sql": True, + "named_parameter_values": named_parameter_values, + "editable": editable, + "canned_query": canned_query, + }, templates diff --git a/datasette/views/database.py b/datasette/views/database.py new file mode 100644 index 00000000..3ccd7ef0 --- /dev/null +++ b/datasette/views/database.py @@ -0,0 +1,50 @@ +import os + +from sanic import response + +from datasette.utils import to_css_class, validate_sql_select + +from .base import BaseView + + +class DatabaseView(BaseView): + + async def data(self, request, name, hash): + if request.args.get("sql"): + sql = request.raw_args.pop("sql") + validate_sql_select(sql) + return await self.custom_sql(request, name, hash, sql) + + info = self.ds.inspect()[name] + metadata = self.ds.metadata.get("databases", {}).get(name, {}) + self.ds.update_with_inherited_metadata(metadata) + tables = list(info["tables"].values()) + tables.sort(key=lambda t: (t["hidden"], t["name"])) + return { + "database": name, + "tables": tables, + "hidden_count": len([t for t in tables if t["hidden"]]), + "views": info["views"], + "queries": [ + {"name": query_name, "sql": query_sql} + for query_name, query_sql in (metadata.get("queries") or {}).items() + ], + }, { + "database_hash": hash, + "show_hidden": request.args.get("_show_hidden"), + "editable": True, + "metadata": metadata, + }, ( + "database-{}.html".format(to_css_class(name)), "database.html" + ) + + +class DatabaseDownload(BaseView): + + async def view_get(self, request, name, hash, **kwargs): + filepath = self.ds.inspect()[name]["file"] + return await response.file_stream( + filepath, + filename=os.path.basename(filepath), + 
mime_type="application/octet-stream", + ) diff --git a/datasette/views/index.py b/datasette/views/index.py new file mode 100644 index 00000000..c4ed3bef --- /dev/null +++ b/datasette/views/index.py @@ -0,0 +1,59 @@ +import json + +from sanic import response + +from datasette.utils import CustomJSONEncoder +from datasette.version import __version__ + +from .base import HASH_LENGTH, RenderMixin + + +class IndexView(RenderMixin): + + def __init__(self, datasette): + self.ds = datasette + self.files = datasette.files + self.jinja_env = datasette.jinja_env + self.executor = datasette.executor + + async def get(self, request, as_json): + databases = [] + for key, info in sorted(self.ds.inspect().items()): + tables = [t for t in info["tables"].values() if not t["hidden"]] + hidden_tables = [t for t in info["tables"].values() if t["hidden"]] + database = { + "name": key, + "hash": info["hash"], + "path": "{}-{}".format(key, info["hash"][:HASH_LENGTH]), + "tables_truncated": sorted( + tables, key=lambda t: t["count"], reverse=True + )[ + :5 + ], + "tables_count": len(tables), + "tables_more": len(tables) > 5, + "table_rows_sum": sum(t["count"] for t in tables), + "hidden_table_rows_sum": sum(t["count"] for t in hidden_tables), + "hidden_tables_count": len(hidden_tables), + "views_count": len(info["views"]), + } + databases.append(database) + if as_json: + headers = {} + if self.ds.cors: + headers["Access-Control-Allow-Origin"] = "*" + return response.HTTPResponse( + json.dumps({db["name"]: db for db in databases}, cls=CustomJSONEncoder), + content_type="application/json", + headers=headers, + ) + + else: + return self.render( + ["index.html"], + databases=databases, + metadata=self.ds.metadata, + datasette_version=__version__, + extra_css_urls=self.ds.extra_css_urls(), + extra_js_urls=self.ds.extra_js_urls(), + ) diff --git a/datasette/views/table.py b/datasette/views/table.py new file mode 100644 index 00000000..5369825c --- /dev/null +++ b/datasette/views/table.py @@ -0,0 +1,763 @@ +import sqlite3 +import urllib + +import jinja2 +from sanic.exceptions import NotFound +from sanic.request import RequestParameters + +from datasette.utils import ( + Filters, + compound_keys_after_sql, + escape_sqlite, + filters_should_redirect, + is_url, + path_from_row_pks, + path_with_added_args, + to_css_class, + urlsafe_components +) + +from .base import BaseView, DatasetteError, ureg + + +class RowTableShared(BaseView): + + def sortable_columns_for_table(self, name, table, use_rowid): + table_metadata = self.table_metadata(name, table) + if "sortable_columns" in table_metadata: + sortable_columns = set(table_metadata["sortable_columns"]) + else: + table_info = self.ds.inspect()[name]["tables"].get(table) or {} + sortable_columns = set(table_info.get("columns", [])) + if use_rowid: + sortable_columns.add("rowid") + return sortable_columns + + async def display_columns_and_rows( + self, + database, + table, + description, + rows, + link_column=False, + expand_foreign_keys=True, + ): + "Returns columns, rows for specified table - including fancy foreign key treatment" + table_metadata = self.table_metadata(database, table) + info = self.ds.inspect()[database] + sortable_columns = self.sortable_columns_for_table(database, table, True) + columns = [ + {"name": r[0], "sortable": r[0] in sortable_columns} for r in description + ] + tables = info["tables"] + table_info = tables.get(table) or {} + pks = table_info.get("primary_keys") or [] + + # Prefetch foreign key resolutions for later expansion: + fks = {} + 
+        labeled_fks = {}
+        if table_info and expand_foreign_keys:
+            foreign_keys = table_info["foreign_keys"]["outgoing"]
+            for fk in foreign_keys:
+                label_column = (
+                    # First look in metadata.json definition for this foreign key table:
+                    self.table_metadata(database, fk["other_table"]).get("label_column")
+                    # Fall back to label_column from .inspect() detection:
+                    or tables.get(fk["other_table"], {}).get("label_column")
+                )
+                if not label_column:
+                    # No label for this FK
+                    fks[fk["column"]] = fk["other_table"]
+                    continue
+
+                ids_to_lookup = set([row[fk["column"]] for row in rows])
+                sql = 'select "{other_column}", "{label_column}" from {other_table} where "{other_column}" in ({placeholders})'.format(
+                    other_column=fk["other_column"],
+                    label_column=label_column,
+                    other_table=escape_sqlite(fk["other_table"]),
+                    placeholders=", ".join(["?"] * len(ids_to_lookup)),
+                )
+                try:
+                    results = await self.execute(
+                        database, sql, list(set(ids_to_lookup))
+                    )
+                except sqlite3.OperationalError:
+                    # Probably hit the timelimit
+                    pass
+                else:
+                    for id, value in results:
+                        labeled_fks[(fk["column"], id)] = (fk["other_table"], value)
+
+        cell_rows = []
+        for row in rows:
+            cells = []
+            # Unless we are a view, the first column is a link - either to the rowid
+            # or to the simple or compound primary key
+            if link_column:
+                cells.append(
+                    {
+                        "column": pks[0] if len(pks) == 1 else "Link",
+                        "value": jinja2.Markup(
+                            '<a href="/{database}/{table}/{flat_pks_quoted}">{flat_pks}</a>'.format(
+                                database=database,
+                                table=urllib.parse.quote_plus(table),
+                                flat_pks=str(
+                                    jinja2.escape(
+                                        path_from_row_pks(row, pks, not pks, False)
+                                    )
+                                ),
+                                flat_pks_quoted=path_from_row_pks(row, pks, not pks),
+                            )
+                        ),
+                    }
+                )
+
+            for value, column_dict in zip(row, columns):
+                column = column_dict["name"]
+                if link_column and len(pks) == 1 and column == pks[0]:
+                    # If there's a simple primary key, don't repeat the value as it's
+                    # already shown in the link column.
+                    continue
+
+                if (column, value) in labeled_fks:
+                    other_table, label = labeled_fks[(column, value)]
+                    display_value = jinja2.Markup(
+                        '<a href="/{database}/{table}/{link_id}">{label}</a>&nbsp;<em>{id}</em>'.format(
+                            database=database,
+                            table=urllib.parse.quote_plus(other_table),
+                            link_id=urllib.parse.quote_plus(str(value)),
+                            id=str(jinja2.escape(value)),
+                            label=str(jinja2.escape(label)),
+                        )
+                    )
+                elif column in fks:
+                    display_value = jinja2.Markup(
+                        '<a href="/{database}/{table}/{link_id}">{id}</a>'.format(
+                            database=database,
+                            table=urllib.parse.quote_plus(fks[column]),
+                            link_id=urllib.parse.quote_plus(str(value)),
+                            id=str(jinja2.escape(value)),
+                        )
+                    )
+                elif value is None:
+                    display_value = jinja2.Markup("&nbsp;")
+                elif is_url(str(value).strip()):
+                    display_value = jinja2.Markup(
+                        '<a href="{url}">{url}</a>'.format(
+                            url=jinja2.escape(value.strip())
+                        )
+                    )
+                elif column in table_metadata.get("units", {}) and value != "":
+                    # Interpret units using pint
+                    value = value * ureg(table_metadata["units"][column])
+                    # Pint uses floating point which sometimes introduces errors in the compact
+                    # representation, which we have to round off to avoid ugliness. In the vast
+                    # majority of cases this rounding will be inconsequential. I hope.
+                    value = round(value.to_compact(), 6)
+                    display_value = jinja2.Markup(
+                        "{:~P}".format(value).replace(" ", "&nbsp;")
+                    )
+                else:
+                    display_value = str(value)
+
+                cells.append({"column": column, "value": display_value})
+            cell_rows.append(cells)
+
+        if link_column:
+            # Add the link column header.
+            # If it's a simple primary key, we have to remove and re-add that column name at
+            # the beginning of the header row.
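+            # For a single primary key the pk column is dropped here and
+            # re-added at position 0 as the link column; for compound keys a
+            # synthetic non-sortable "Link" column is prepended instead.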
+ if len(pks) == 1: + columns = [col for col in columns if col["name"] != pks[0]] + + columns = [ + {"name": pks[0] if len(pks) == 1 else "Link", "sortable": len(pks) == 1} + ] + columns + return columns, cell_rows + + +class TableView(RowTableShared): + + async def data(self, request, name, hash, table): + table = urllib.parse.unquote_plus(table) + canned_query = self.ds.get_canned_query(name, table) + if canned_query is not None: + return await self.custom_sql( + request, + name, + hash, + canned_query["sql"], + editable=False, + canned_query=table, + ) + + is_view = bool( + list( + await self.execute( + name, + "SELECT count(*) from sqlite_master WHERE type = 'view' and name=:n", + {"n": table}, + ) + )[ + 0 + ][ + 0 + ] + ) + view_definition = None + table_definition = None + if is_view: + view_definition = list( + await self.execute( + name, + 'select sql from sqlite_master where name = :n and type="view"', + {"n": table}, + ) + )[ + 0 + ][ + 0 + ] + else: + table_definition_rows = list( + await self.execute( + name, + 'select sql from sqlite_master where name = :n and type="table"', + {"n": table}, + ) + ) + if not table_definition_rows: + raise NotFound("Table not found: {}".format(table)) + + table_definition = table_definition_rows[0][0] + info = self.ds.inspect() + table_info = info[name]["tables"].get(table) or {} + pks = table_info.get("primary_keys") or [] + use_rowid = not pks and not is_view + if use_rowid: + select = "rowid, *" + order_by = "rowid" + order_by_pks = "rowid" + else: + select = "*" + order_by_pks = ", ".join([escape_sqlite(pk) for pk in pks]) + order_by = order_by_pks + + if is_view: + order_by = "" + + # We roll our own query_string decoder because by default Sanic + # drops anything with an empty value e.g. ?name__exact= + args = RequestParameters( + urllib.parse.parse_qs(request.query_string, keep_blank_values=True) + ) + + # Special args start with _ and do not contain a __ + # That's so if there is a column that starts with _ + # it can still be queried using ?_col__exact=blah + special_args = {} + special_args_lists = {} + other_args = {} + for key, value in args.items(): + if key.startswith("_") and "__" not in key: + special_args[key] = value[0] + special_args_lists[key] = value + else: + other_args[key] = value[0] + + # Handle ?_filter_column and redirect, if present + redirect_params = filters_should_redirect(special_args) + if redirect_params: + return self.redirect( + request, + path_with_added_args(request, redirect_params), + forward_querystring=False, + ) + + # Spot ?_sort_by_desc and redirect to _sort_desc=(_sort) + if "_sort_by_desc" in special_args: + return self.redirect( + request, + path_with_added_args( + request, + { + "_sort_desc": special_args.get("_sort"), + "_sort_by_desc": None, + "_sort": None, + }, + ), + forward_querystring=False, + ) + + table_metadata = self.table_metadata(name, table) + units = table_metadata.get("units", {}) + filters = Filters(sorted(other_args.items()), units, ureg) + where_clauses, params = filters.build_where_clauses() + + # _search support: + fts_table = info[name]["tables"].get(table, {}).get("fts_table") + search_args = dict( + pair for pair in special_args.items() if pair[0].startswith("_search") + ) + search_descriptions = [] + search = "" + if fts_table and search_args: + if "_search" in search_args: + # Simple ?_search=xxx + search = search_args["_search"] + where_clauses.append( + "rowid in (select rowid from [{fts_table}] where [{fts_table}] match :search)".format( + fts_table=fts_table + ) + ) + 
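+                # A bare ?_search= term (e.g. ?_search=paris, illustrative) is
+                # matched against the whole FTS table and joined back to the
+                # content table on rowid; per-column ?_search_colname= queries
+                # are handled by the else branch below.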
search_descriptions.append('search matches "{}"'.format(search)) + params["search"] = search + else: + # More complex: search against specific columns + valid_columns = set(info[name]["tables"][fts_table]["columns"]) + for i, (key, search_text) in enumerate(search_args.items()): + search_col = key.split("_search_", 1)[1] + if search_col not in valid_columns: + raise DatasetteError("Cannot search by that column", status=400) + + where_clauses.append( + "rowid in (select rowid from [{fts_table}] where [{search_col}] match :search_{i})".format( + fts_table=fts_table, search_col=search_col, i=i + ) + ) + search_descriptions.append( + 'search column "{}" matches "{}"'.format( + search_col, search_text + ) + ) + params["search_{}".format(i)] = search_text + + table_rows_count = None + sortable_columns = set() + if not is_view: + table_rows_count = table_info["count"] + sortable_columns = self.sortable_columns_for_table(name, table, use_rowid) + + # Allow for custom sort order + sort = special_args.get("_sort") + if sort: + if sort not in sortable_columns: + raise DatasetteError("Cannot sort table by {}".format(sort)) + + order_by = escape_sqlite(sort) + sort_desc = special_args.get("_sort_desc") + if sort_desc: + if sort_desc not in sortable_columns: + raise DatasetteError("Cannot sort table by {}".format(sort_desc)) + + if sort: + raise DatasetteError("Cannot use _sort and _sort_desc at the same time") + + order_by = "{} desc".format(escape_sqlite(sort_desc)) + + from_sql = "from {table_name} {where}".format( + table_name=escape_sqlite(table), + where=( + "where {} ".format(" and ".join(where_clauses)) + ) if where_clauses else "", + ) + count_sql = "select count(*) {}".format(from_sql) + + _next = special_args.get("_next") + offset = "" + if _next: + if is_view: + # _next is an offset + offset = " offset {}".format(int(_next)) + else: + components = urlsafe_components(_next) + # If a sort order is applied, the first of these is the sort value + if sort or sort_desc: + sort_value = components[0] + # Special case for if non-urlencoded first token was $null + if _next.split(",")[0] == "$null": + sort_value = None + components = components[1:] + + # Figure out the SQL for next-based-on-primary-key first + next_by_pk_clauses = [] + if use_rowid: + next_by_pk_clauses.append("rowid > :p{}".format(len(params))) + params["p{}".format(len(params))] = components[0] + else: + # Apply the tie-breaker based on primary keys + if len(components) == len(pks): + param_len = len(params) + next_by_pk_clauses.append( + compound_keys_after_sql(pks, param_len) + ) + for i, pk_value in enumerate(components): + params["p{}".format(param_len + i)] = pk_value + + # Now add the sort SQL, which may incorporate next_by_pk_clauses + if sort or sort_desc: + if sort_value is None: + if sort_desc: + # Just items where column is null ordered by pk + where_clauses.append( + "({column} is null and {next_clauses})".format( + column=escape_sqlite(sort_desc), + next_clauses=" and ".join(next_by_pk_clauses), + ) + ) + else: + where_clauses.append( + "({column} is not null or ({column} is null and {next_clauses}))".format( + column=escape_sqlite(sort), + next_clauses=" and ".join(next_by_pk_clauses), + ) + ) + else: + where_clauses.append( + "({column} {op} :p{p}{extra_desc_only} or ({column} = :p{p} and {next_clauses}))".format( + column=escape_sqlite(sort or sort_desc), + op=">" if sort else "<", + p=len(params), + extra_desc_only="" if sort else " or {column2} is null".format( + column2=escape_sqlite(sort or sort_desc) + ), + 
next_clauses=" and ".join(next_by_pk_clauses), + ) + ) + params["p{}".format(len(params))] = sort_value + order_by = "{}, {}".format(order_by, order_by_pks) + else: + where_clauses.extend(next_by_pk_clauses) + + where_clause = "" + if where_clauses: + where_clause = "where {} ".format(" and ".join(where_clauses)) + + if order_by: + order_by = "order by {} ".format(order_by) + + # _group_count=col1&_group_count=col2 + group_count = special_args_lists.get("_group_count") or [] + if group_count: + sql = 'select {group_cols}, count(*) as "count" from {table_name} {where} group by {group_cols} order by "count" desc limit 100'.format( + group_cols=", ".join( + '"{}"'.format(group_count_col) for group_count_col in group_count + ), + table_name=escape_sqlite(table), + where=where_clause, + ) + return await self.custom_sql(request, name, hash, sql, editable=True) + + extra_args = {} + # Handle ?_page_size=500 + page_size = request.raw_args.get("_size") + if page_size: + if page_size == "max": + page_size = self.max_returned_rows + try: + page_size = int(page_size) + if page_size < 0: + raise ValueError + + except ValueError: + raise DatasetteError("_size must be a positive integer", status=400) + + if page_size > self.max_returned_rows: + raise DatasetteError( + "_size must be <= {}".format(self.max_returned_rows), status=400 + ) + + extra_args["page_size"] = page_size + else: + page_size = self.page_size + + sql = "select {select} from {table_name} {where}{order_by}limit {limit}{offset}".format( + select=select, + table_name=escape_sqlite(table), + where=where_clause, + order_by=order_by, + limit=page_size + 1, + offset=offset, + ) + + if request.raw_args.get("_timelimit"): + extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"]) + + rows, truncated, description = await self.execute( + name, sql, params, truncate=True, **extra_args + ) + + # facets support + try: + facets = request.args["_facet"] + except KeyError: + facets = table_metadata.get("facets", []) + facet_results = {} + for column in facets: + facet_sql = """ + select {col} as value, count(*) as count + {from_sql} + group by {col} order by count desc limit 20 + """.format( + col=escape_sqlite(column), from_sql=from_sql + ) + try: + facet_rows = await self.execute( + name, facet_sql, params, truncate=False, custom_time_limit=200 + ) + facet_results[column] = [ + { + "value": row["value"], + "count": row["count"], + "toggle_url": urllib.parse.urljoin( + request.url, + path_with_added_args(request, {column: row["value"]}), + ), + } + for row in facet_rows + ] + except sqlite3.OperationalError: + # Hit time limit + pass + + columns = [r[0] for r in description] + rows = list(rows) + + filter_columns = columns[:] + if use_rowid and filter_columns[0] == "rowid": + filter_columns = filter_columns[1:] + + # Pagination next link + next_value = None + next_url = None + if len(rows) > page_size and page_size > 0: + if is_view: + next_value = int(_next or 0) + page_size + else: + next_value = path_from_row_pks(rows[-2], pks, use_rowid) + # If there's a sort or sort_desc, add that value as a prefix + if (sort or sort_desc) and not is_view: + prefix = rows[-2][sort or sort_desc] + if prefix is None: + prefix = "$null" + else: + prefix = urllib.parse.quote_plus(str(prefix)) + next_value = "{},{}".format(prefix, next_value) + added_args = {"_next": next_value} + if sort: + added_args["_sort"] = sort + else: + added_args["_sort_desc"] = sort_desc + else: + added_args = {"_next": next_value} + next_url = urllib.parse.urljoin( + 
request.url, path_with_added_args(request, added_args) + ) + rows = rows[:page_size] + + # Number of filtered rows in whole set: + filtered_table_rows_count = None + if count_sql: + try: + count_rows = list(await self.execute(name, count_sql, params)) + filtered_table_rows_count = count_rows[0][0] + except sqlite3.OperationalError: + # Almost certainly hit the timeout + pass + + # human_description_en combines filters AND search, if provided + human_description_en = filters.human_description_en(extra=search_descriptions) + + if sort or sort_desc: + sorted_by = "sorted by {}{}".format( + (sort or sort_desc), " descending" if sort_desc else "" + ) + human_description_en = " ".join( + [b for b in [human_description_en, sorted_by] if b] + ) + + async def extra_template(): + display_columns, display_rows = await self.display_columns_and_rows( + name, + table, + description, + rows, + link_column=not is_view, + expand_foreign_keys=True, + ) + metadata = self.ds.metadata.get("databases", {}).get(name, {}).get( + "tables", {} + ).get( + table, {} + ) + self.ds.update_with_inherited_metadata(metadata) + return { + "database_hash": hash, + "supports_search": bool(fts_table), + "search": search or "", + "use_rowid": use_rowid, + "filters": filters, + "display_columns": display_columns, + "filter_columns": filter_columns, + "display_rows": display_rows, + "is_sortable": any(c["sortable"] for c in display_columns), + "path_with_added_args": path_with_added_args, + "request": request, + "sort": sort, + "sort_desc": sort_desc, + "disable_sort": is_view, + "custom_rows_and_columns_templates": [ + "_rows_and_columns-{}-{}.html".format( + to_css_class(name), to_css_class(table) + ), + "_rows_and_columns-table-{}-{}.html".format( + to_css_class(name), to_css_class(table) + ), + "_rows_and_columns.html", + ], + "metadata": metadata, + } + + return { + "database": name, + "table": table, + "is_view": is_view, + "view_definition": view_definition, + "table_definition": table_definition, + "human_description_en": human_description_en, + "rows": rows[:page_size], + "truncated": truncated, + "table_rows_count": table_rows_count, + "filtered_table_rows_count": filtered_table_rows_count, + "columns": columns, + "primary_keys": pks, + "units": units, + "query": {"sql": sql, "params": params}, + "facet_results": facet_results, + "next": next_value and str(next_value) or None, + "next_url": next_url, + }, extra_template, ( + "table-{}-{}.html".format(to_css_class(name), to_css_class(table)), + "table.html", + ) + + +class RowView(RowTableShared): + + async def data(self, request, name, hash, table, pk_path): + table = urllib.parse.unquote_plus(table) + pk_values = urlsafe_components(pk_path) + info = self.ds.inspect()[name] + table_info = info["tables"].get(table) or {} + pks = table_info.get("primary_keys") or [] + use_rowid = not pks + select = "*" + if use_rowid: + select = "rowid, *" + pks = ["rowid"] + wheres = ['"{}"=:p{}'.format(pk, i) for i, pk in enumerate(pks)] + sql = 'select {} from "{}" where {}'.format(select, table, " AND ".join(wheres)) + params = {} + for i, pk_value in enumerate(pk_values): + params["p{}".format(i)] = pk_value + # rows, truncated, description = await self.execute(name, sql, params, truncate=True) + rows, truncated, description = await self.execute( + name, sql, params, truncate=True + ) + columns = [r[0] for r in description] + rows = list(rows) + if not rows: + raise NotFound("Record not found: {}".format(pk_values)) + + async def template_data(): + display_columns, display_rows = 
await self.display_columns_and_rows( + name, + table, + description, + rows, + link_column=False, + expand_foreign_keys=True, + ) + for column in display_columns: + column["sortable"] = False + return { + "database_hash": hash, + "foreign_key_tables": await self.foreign_key_tables( + name, table, pk_values + ), + "display_columns": display_columns, + "display_rows": display_rows, + "custom_rows_and_columns_templates": [ + "_rows_and_columns-{}-{}.html".format( + to_css_class(name), to_css_class(table) + ), + "_rows_and_columns-row-{}-{}.html".format( + to_css_class(name), to_css_class(table) + ), + "_rows_and_columns.html", + ], + "metadata": self.ds.metadata.get("databases", {}).get(name, {}).get( + "tables", {} + ).get( + table, {} + ), + } + + data = { + "database": name, + "table": table, + "rows": rows, + "columns": columns, + "primary_keys": pks, + "primary_key_values": pk_values, + "units": self.table_metadata(name, table).get("units", {}), + } + + if "foreign_key_tables" in (request.raw_args.get("_extras") or "").split(","): + data["foreign_key_tables"] = await self.foreign_key_tables( + name, table, pk_values + ) + + return data, template_data, ( + "row-{}-{}.html".format(to_css_class(name), to_css_class(table)), "row.html" + ) + + async def foreign_key_tables(self, name, table, pk_values): + if len(pk_values) != 1: + return [] + + table_info = self.ds.inspect()[name]["tables"].get(table) + if not table_info: + return [] + + foreign_keys = table_info["foreign_keys"]["incoming"] + if len(foreign_keys) == 0: + return [] + + sql = "select " + ", ".join( + [ + '(select count(*) from {table} where "{column}"=:id)'.format( + table=escape_sqlite(fk["other_table"]), column=fk["other_column"] + ) + for fk in foreign_keys + ] + ) + try: + rows = list(await self.execute(name, sql, {"id": pk_values[0]})) + except sqlite3.OperationalError: + # Almost certainly hit the timeout + return [] + + foreign_table_counts = dict( + zip( + [(fk["other_table"], fk["other_column"]) for fk in foreign_keys], + list(rows[0]), + ) + ) + foreign_key_tables = [] + for fk in foreign_keys: + count = foreign_table_counts.get( + (fk["other_table"], fk["other_column"]) + ) or 0 + foreign_key_tables.append({**fk, **{"count": count}}) + return foreign_key_tables diff --git a/tests/test_utils.py b/tests/test_utils.py index cc987723..3f51f87d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -29,9 +29,12 @@ def test_urlsafe_components(path, expected): ('/foo?bar=1&bar=2', {'baz': 3}, '/foo?bar=1&bar=2&baz=3'), ('/foo?bar=1', {'bar': None}, '/foo'), # Test order is preserved - ('/?_facet=prim_state&_facet=area_name', { - 'prim_state': 'GA' - }, '/?_facet=prim_state&_facet=area_name&prim_state=GA'), + ('/?_facet=prim_state&_facet=area_name', ( + ('prim_state', 'GA'), + ), '/?_facet=prim_state&_facet=area_name&prim_state=GA'), + ('/?_facet=state&_facet=city&state=MI', ( + ('city', 'Detroit'), + ), '/?_facet=state&_facet=city&state=MI&city=Detroit'), ]) def test_path_with_added_args(path, added_args, expected): request = Request(