Load SQLite extensions inside .inspect() too

This commit is contained in:
Simon Willison 2017-11-26 14:51:42 -08:00
commit de6c62ed9a
No known key found for this signature in database
GPG key ID: FBB38AFE227189DB

View file

@ -115,16 +115,6 @@ class BaseView(HTTPMethodView):
return name, expected, should_redirect
return name, expected, None
def prepare_connection(self, conn):
conn.row_factory = sqlite3.Row
conn.text_factory = lambda x: str(x, 'utf-8', 'replace')
for name, num_args, func in self.ds.sqlite_functions:
conn.create_function(name, num_args, func)
if self.ds.sqlite_extensions:
conn.enable_load_extension(True)
for extension in self.ds.sqlite_extensions:
conn.execute("SELECT load_extension('{}')".format(extension))
async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None):
"""Executes sql against db_name in a thread"""
def sql_operation_in_thread():
@ -136,7 +126,7 @@ class BaseView(HTTPMethodView):
uri=True,
check_same_thread=False,
)
self.prepare_connection(conn)
self.ds.prepare_connection(conn)
setattr(connections, db_name, conn)
time_limit_ms = self.ds.sql_time_limit_ms
@ -783,6 +773,16 @@ class Datasette:
self.sqlite_functions = []
self.sqlite_extensions = sqlite_extensions or []
def prepare_connection(self, conn):
conn.row_factory = sqlite3.Row
conn.text_factory = lambda x: str(x, 'utf-8', 'replace')
for name, num_args, func in self.sqlite_functions:
conn.create_function(name, num_args, func)
if self.sqlite_extensions:
conn.enable_load_extension(True)
for extension in self.sqlite_extensions:
conn.execute("SELECT load_extension('{}')".format(extension))
def inspect(self):
if not self._inspect:
self._inspect = {}
@ -803,7 +803,7 @@ class Datasette:
tables = {}
views = []
with sqlite3.connect('file:{}?immutable=1'.format(path), uri=True) as conn:
conn.row_factory = sqlite3.Row
self.prepare_connection(conn)
table_names = [
r['name']
for r in conn.execute('select * from sqlite_master where type="table"')