Moved .execute() method from BaseView to Datasette class

Also introduced new Results() class with results.truncated, results.description, results.rows
This commit is contained in:
Simon Willison 2018-05-24 17:15:37 -07:00
commit 81df47e8d9
No known key found for this signature in database
GPG key ID: 17E2DEA2588B7F52
4 changed files with 105 additions and 90 deletions

View file

@ -1,3 +1,4 @@
import asyncio
import collections import collections
import hashlib import hashlib
import itertools import itertools
@ -5,6 +6,7 @@ import json
import os import os
import sqlite3 import sqlite3
import sys import sys
import threading
import traceback import traceback
import urllib.parse import urllib.parse
from concurrent import futures from concurrent import futures
@ -26,10 +28,13 @@ from .views.table import RowView, TableView
from . import hookspecs from . import hookspecs
from .utils import ( from .utils import (
InterruptedError,
Results,
escape_css_string, escape_css_string,
escape_sqlite, escape_sqlite,
get_plugins, get_plugins,
module_from_path, module_from_path,
sqlite_timelimit,
to_css_class to_css_class
) )
from .inspect import inspect_hash, inspect_views, inspect_tables from .inspect import inspect_hash, inspect_views, inspect_tables
@ -37,6 +42,7 @@ from .version import __version__
app_root = Path(__file__).parent.parent app_root = Path(__file__).parent.parent
connections = threading.local()
pm = pluggy.PluginManager("datasette") pm = pluggy.PluginManager("datasette")
pm.add_hookspecs(hookspecs) pm.add_hookspecs(hookspecs)
@ -285,6 +291,68 @@ class Datasette:
for p in get_plugins(pm) for p in get_plugins(pm)
] ]
async def execute(
    self,
    db_name,
    sql,
    params=None,
    truncate=False,
    custom_time_limit=None,
    page_size=None,
):
    """Executes sql against db_name in a thread.

    :param db_name: key into ``self.inspect()`` identifying the database
    :param sql: SQL string to execute
    :param params: optional dict of named SQL parameters
    :param truncate: if True, cap the result at ``max_returned_rows`` and
        record whether the result set was cut short
    :param custom_time_limit: optional per-call limit in ms; honoured only
        when it is stricter than the configured ``sql_time_limit_ms``
    :param page_size: rows per page; defaults to ``self.page_size``
    :returns: a ``Results`` object (rows, truncated flag, cursor description)
    :raises InterruptedError: if SQLite aborts the query at the time limit
    """
    page_size = page_size or self.page_size
    def sql_operation_in_thread():
        # One connection per (thread, database), cached in the module-level
        # threading.local() — SQLite connections must not be shared across
        # threads.
        conn = getattr(connections, db_name, None)
        if not conn:
            info = self.inspect()[db_name]
            # immutable=1 in the URI tells SQLite the file will never change
            conn = sqlite3.connect(
                "file:{}?immutable=1".format(info["file"]),
                uri=True,
                check_same_thread=False,
            )
            self.prepare_connection(conn)
            setattr(connections, db_name, conn)
        time_limit_ms = self.sql_time_limit_ms
        if custom_time_limit and custom_time_limit < time_limit_ms:
            time_limit_ms = custom_time_limit
        with sqlite_timelimit(conn, time_limit_ms):
            try:
                cursor = conn.cursor()
                cursor.execute(sql, params or {})
                max_returned_rows = self.max_returned_rows
                if max_returned_rows == page_size:
                    # Allow one extra row so pagination can detect a next page
                    max_returned_rows += 1
                if max_returned_rows and truncate:
                    # Probe one row past the limit to detect truncation
                    rows = cursor.fetchmany(max_returned_rows + 1)
                    truncated = len(rows) > max_returned_rows
                    rows = rows[:max_returned_rows]
                else:
                    rows = cursor.fetchall()
                    truncated = False
            except sqlite3.OperationalError as e:
                # args == ('interrupted',) means sqlite_timelimit's progress
                # handler cancelled the query — surface as InterruptedError
                if e.args == ('interrupted',):
                    raise InterruptedError(e)
                print(
                    "ERROR: conn={}, sql = {}, params = {}: {}".format(
                        conn, repr(sql), params, e
                    )
                )
                raise
        if truncate:
            return Results(rows, truncated, cursor.description)
        else:
            return Results(rows, False, cursor.description)
    # Run the blocking SQLite work off the event loop in the thread pool
    return await asyncio.get_event_loop().run_in_executor(
        self.executor, sql_operation_in_thread
    )
def app(self): def app(self):
app = Sanic(__name__) app = Sanic(__name__)
default_templates = str(app_root / "datasette" / "templates") default_templates = str(app_root / "datasette" / "templates")

View file

@ -36,6 +36,19 @@ class InterruptedError(Exception):
pass pass
class Results:
    """Container for the outcome of a SQL query.

    Bundles the fetched rows with the truncation flag and the DB-API
    ``cursor.description`` (column metadata), so callers no longer need to
    juggle an inconsistent tuple-or-rows return value.

    Iterating or calling ``len()`` on a Results delegates to its rows.
    """

    def __init__(self, rows, truncated, description):
        # rows: list of result rows fetched from the cursor
        self.rows = rows
        # truncated: True if the result set was cut off at max_returned_rows
        self.truncated = truncated
        # description: DB-API cursor.description column metadata
        self.description = description

    def __iter__(self):
        return iter(self.rows)

    def __len__(self):
        return len(self.rows)

    def __repr__(self):
        # Debug-friendly summary without dumping every row
        return "<Results: {} row{}{}>".format(
            len(self.rows),
            "" if len(self.rows) == 1 else "s",
            ", truncated" if self.truncated else "",
        )
def urlsafe_components(token): def urlsafe_components(token):
"Splits token on commas and URL decodes each component" "Splits token on commas and URL decodes each component"
return [ return [

View file

@ -2,7 +2,6 @@ import asyncio
import json import json
import re import re
import sqlite3 import sqlite3
import threading
import time import time
import pint import pint
@ -18,11 +17,9 @@ from datasette.utils import (
path_from_row_pks, path_from_row_pks,
path_with_added_args, path_with_added_args,
path_with_ext, path_with_ext,
sqlite_timelimit,
to_css_class to_css_class
) )
connections = threading.local()
ureg = pint.UnitRegistry() ureg = pint.UnitRegistry()
HASH_LENGTH = 7 HASH_LENGTH = 7
@ -128,68 +125,6 @@ class BaseView(RenderMixin):
return name, expected, None return name, expected, None
async def execute(
    self,
    db_name,
    sql,
    params=None,
    truncate=False,
    custom_time_limit=None,
    page_size=None,
):
    """Executes sql against db_name in a thread"""
    # Pre-refactor implementation living on BaseView; reads configuration
    # (inspect data, time limits, prepare_connection) via self.ds.
    page_size = page_size or self.page_size
    def sql_operation_in_thread():
        # Per-thread cached connection keyed by database name
        conn = getattr(connections, db_name, None)
        if not conn:
            info = self.ds.inspect()[db_name]
            # immutable=1 in the URI tells SQLite the file will never change
            conn = sqlite3.connect(
                "file:{}?immutable=1".format(info["file"]),
                uri=True,
                check_same_thread=False,
            )
            self.ds.prepare_connection(conn)
            setattr(connections, db_name, conn)
        time_limit_ms = self.ds.sql_time_limit_ms
        if custom_time_limit and custom_time_limit < self.ds.sql_time_limit_ms:
            time_limit_ms = custom_time_limit
        with sqlite_timelimit(conn, time_limit_ms):
            try:
                cursor = conn.cursor()
                cursor.execute(sql, params or {})
                max_returned_rows = self.max_returned_rows
                if max_returned_rows == page_size:
                    # Allow one extra row so pagination can detect a next page
                    max_returned_rows += 1
                if max_returned_rows and truncate:
                    # Probe one row past the limit to detect truncation
                    rows = cursor.fetchmany(max_returned_rows + 1)
                    truncated = len(rows) > max_returned_rows
                    rows = rows[:max_returned_rows]
                else:
                    rows = cursor.fetchall()
                    truncated = False
            except sqlite3.OperationalError as e:
                # args == ('interrupted',) means sqlite_timelimit's progress
                # handler cancelled the query — surface as InterruptedError
                if e.args == ('interrupted',):
                    raise InterruptedError(e)
                print(
                    "ERROR: conn={}, sql = {}, params = {}: {}".format(
                        conn, repr(sql), params, e
                    )
                )
                raise
        # NOTE(review): inconsistent return shape — a 3-tuple when
        # truncate=True but bare rows otherwise; callers must know which
        # form to expect. (Fixed by the Results wrapper in this commit.)
        if truncate:
            return rows, truncated, cursor.description
        else:
            return rows
    # Run the blocking SQLite work off the event loop in the thread pool
    return await asyncio.get_event_loop().run_in_executor(
        self.executor, sql_operation_in_thread
    )
def get_templates(self, database, table=None): def get_templates(self, database, table=None):
assert NotImplemented assert NotImplemented
@ -348,10 +283,10 @@ class BaseView(RenderMixin):
extra_args = {} extra_args = {}
if params.get("_timelimit"): if params.get("_timelimit"):
extra_args["custom_time_limit"] = int(params["_timelimit"]) extra_args["custom_time_limit"] = int(params["_timelimit"])
rows, truncated, description = await self.execute( results = await self.ds.execute(
name, sql, params, truncate=True, **extra_args name, sql, params, truncate=True, **extra_args
) )
columns = [r[0] for r in description] columns = [r[0] for r in results.description]
templates = ["query-{}.html".format(to_css_class(name)), "query.html"] templates = ["query-{}.html".format(to_css_class(name)), "query.html"]
if canned_query: if canned_query:
@ -364,8 +299,8 @@ class BaseView(RenderMixin):
return { return {
"database": name, "database": name,
"rows": rows, "rows": results.rows,
"truncated": truncated, "truncated": results.truncated,
"columns": columns, "columns": columns,
"query": {"sql": sql, "params": params}, "query": {"sql": sql, "params": params},
}, { }, {

View file

@ -73,7 +73,7 @@ class RowTableShared(BaseView):
placeholders=", ".join(["?"] * len(set(values))), placeholders=", ".join(["?"] * len(set(values))),
) )
try: try:
results = await self.execute( results = await self.ds.execute(
database, sql, list(set(values)) database, sql, list(set(values))
) )
except InterruptedError: except InterruptedError:
@ -132,7 +132,7 @@ class RowTableShared(BaseView):
placeholders=", ".join(["?"] * len(ids_to_lookup)), placeholders=", ".join(["?"] * len(ids_to_lookup)),
) )
try: try:
results = await self.execute( results = await self.ds.execute(
database, sql, list(set(ids_to_lookup)) database, sql, list(set(ids_to_lookup))
) )
except InterruptedError: except InterruptedError:
@ -246,7 +246,7 @@ class TableView(RowTableShared):
is_view = bool( is_view = bool(
list( list(
await self.execute( await self.ds.execute(
name, name,
"SELECT count(*) from sqlite_master WHERE type = 'view' and name=:n", "SELECT count(*) from sqlite_master WHERE type = 'view' and name=:n",
{"n": table}, {"n": table},
@ -257,7 +257,7 @@ class TableView(RowTableShared):
table_definition = None table_definition = None
if is_view: if is_view:
view_definition = list( view_definition = list(
await self.execute( await self.ds.execute(
name, name,
'select sql from sqlite_master where name = :n and type="view"', 'select sql from sqlite_master where name = :n and type="view"',
{"n": table}, {"n": table},
@ -265,7 +265,7 @@ class TableView(RowTableShared):
)[0][0] )[0][0]
else: else:
table_definition_rows = list( table_definition_rows = list(
await self.execute( await self.ds.execute(
name, name,
'select sql from sqlite_master where name = :n and type="table"', 'select sql from sqlite_master where name = :n and type="table"',
{"n": table}, {"n": table},
@ -534,7 +534,7 @@ class TableView(RowTableShared):
if request.raw_args.get("_timelimit"): if request.raw_args.get("_timelimit"):
extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"]) extra_args["custom_time_limit"] = int(request.raw_args["_timelimit"])
rows, truncated, description = await self.execute( results = await self.ds.execute(
name, sql, params, truncate=True, **extra_args name, sql, params, truncate=True, **extra_args
) )
@ -560,7 +560,7 @@ class TableView(RowTableShared):
limit=facet_size+1, limit=facet_size+1,
) )
try: try:
facet_rows = await self.execute( facet_rows_results = await self.ds.execute(
name, facet_sql, params, name, facet_sql, params,
truncate=False, truncate=False,
custom_time_limit=self.ds.config["facet_time_limit_ms"], custom_time_limit=self.ds.config["facet_time_limit_ms"],
@ -569,9 +569,9 @@ class TableView(RowTableShared):
facet_results[column] = { facet_results[column] = {
"name": column, "name": column,
"results": facet_results_values, "results": facet_results_values,
"truncated": len(facet_rows) > facet_size, "truncated": len(facet_rows_results) > facet_size,
} }
facet_rows = facet_rows[:facet_size] facet_rows = facet_rows_results.rows[:facet_size]
# Attempt to expand foreign keys into labels # Attempt to expand foreign keys into labels
values = [row["value"] for row in facet_rows] values = [row["value"] for row in facet_rows]
expanded = (await self.expand_foreign_keys( expanded = (await self.expand_foreign_keys(
@ -602,8 +602,8 @@ class TableView(RowTableShared):
except InterruptedError: except InterruptedError:
facets_timed_out.append(column) facets_timed_out.append(column)
columns = [r[0] for r in description] columns = [r[0] for r in results.description]
rows = list(rows) rows = list(results.rows)
filter_columns = columns[:] filter_columns = columns[:]
if use_rowid and filter_columns[0] == "rowid": if use_rowid and filter_columns[0] == "rowid":
@ -641,7 +641,7 @@ class TableView(RowTableShared):
filtered_table_rows_count = None filtered_table_rows_count = None
if count_sql: if count_sql:
try: try:
count_rows = list(await self.execute( count_rows = list(await self.ds.execute(
name, count_sql, from_sql_params name, count_sql, from_sql_params
)) ))
filtered_table_rows_count = count_rows[0][0] filtered_table_rows_count = count_rows[0][0]
@ -665,7 +665,7 @@ class TableView(RowTableShared):
) )
distinct_values = None distinct_values = None
try: try:
distinct_values = await self.execute( distinct_values = await self.ds.execute(
name, suggested_facet_sql, from_sql_params, name, suggested_facet_sql, from_sql_params,
truncate=False, truncate=False,
custom_time_limit=self.ds.config["facet_suggest_time_limit_ms"], custom_time_limit=self.ds.config["facet_suggest_time_limit_ms"],
@ -701,7 +701,7 @@ class TableView(RowTableShared):
display_columns, display_rows = await self.display_columns_and_rows( display_columns, display_rows = await self.display_columns_and_rows(
name, name,
table, table,
description, results.description,
rows, rows,
link_column=not is_view, link_column=not is_view,
expand_foreign_keys=True, expand_foreign_keys=True,
@ -755,7 +755,7 @@ class TableView(RowTableShared):
"table_definition": table_definition, "table_definition": table_definition,
"human_description_en": human_description_en, "human_description_en": human_description_en,
"rows": rows[:page_size], "rows": rows[:page_size],
"truncated": truncated, "truncated": results.truncated,
"table_rows_count": table_rows_count, "table_rows_count": table_rows_count,
"filtered_table_rows_count": filtered_table_rows_count, "filtered_table_rows_count": filtered_table_rows_count,
"columns": columns, "columns": columns,
@ -790,12 +790,11 @@ class RowView(RowTableShared):
params = {} params = {}
for i, pk_value in enumerate(pk_values): for i, pk_value in enumerate(pk_values):
params["p{}".format(i)] = pk_value params["p{}".format(i)] = pk_value
# rows, truncated, description = await self.execute(name, sql, params, truncate=True) results = await self.ds.execute(
rows, truncated, description = await self.execute(
name, sql, params, truncate=True name, sql, params, truncate=True
) )
columns = [r[0] for r in description] columns = [r[0] for r in results.description]
rows = list(rows) rows = list(results.rows)
if not rows: if not rows:
raise NotFound("Record not found: {}".format(pk_values)) raise NotFound("Record not found: {}".format(pk_values))
@ -803,7 +802,7 @@ class RowView(RowTableShared):
display_columns, display_rows = await self.display_columns_and_rows( display_columns, display_rows = await self.display_columns_and_rows(
name, name,
table, table,
description, results.description,
rows, rows,
link_column=False, link_column=False,
expand_foreign_keys=True, expand_foreign_keys=True,
@ -874,7 +873,7 @@ class RowView(RowTableShared):
] ]
) )
try: try:
rows = list(await self.execute(name, sql, {"id": pk_values[0]})) rows = list(await self.ds.execute(name, sql, {"id": pk_values[0]}))
except sqlite3.OperationalError: except sqlite3.OperationalError:
# Almost certainly hit the timeout # Almost certainly hit the timeout
return [] return []