table_exists() now uses async SQL, refs #420

This commit is contained in:
Simon Willison 2019-03-31 11:02:22 -07:00
commit 0209a0a344
5 changed files with 47 additions and 28 deletions

View file

@ -300,8 +300,13 @@ class Datasette:
conn.execute('PRAGMA cache_size=-{}'.format(self.config("cache_size_kb"))) conn.execute('PRAGMA cache_size=-{}'.format(self.config("cache_size_kb")))
pm.hook.prepare_connection(conn=conn) pm.hook.prepare_connection(conn=conn)
def table_exists(self, database, table): async def table_exists(self, database, table):
return table in self.inspect().get(database, {}).get("tables") results = await self.execute(
database,
"select 1 from sqlite_master where type='table' and name=?",
params=(table,)
)
return bool(results.rows)
def inspect(self): def inspect(self):
" Inspect the database and return a dictionary of table metadata " " Inspect the database and return a dictionary of table metadata "
@ -410,19 +415,8 @@ class Datasette:
for p in ps for p in ps
] ]
async def execute( async def execute_against_connection_in_thread(self, db_name, fn):
self, def in_thread():
db_name,
sql,
params=None,
truncate=False,
custom_time_limit=None,
page_size=None,
):
"""Executes sql against db_name in a thread"""
page_size = page_size or self.page_size
def sql_operation_in_thread():
conn = getattr(connections, db_name, None) conn = getattr(connections, db_name, None)
if not conn: if not conn:
info = self.inspect()[db_name] info = self.inspect()[db_name]
@ -441,7 +435,25 @@ class Datasette:
) )
self.prepare_connection(conn) self.prepare_connection(conn)
setattr(connections, db_name, conn) setattr(connections, db_name, conn)
return fn(conn)
return await asyncio.get_event_loop().run_in_executor(
self.executor, in_thread
)
async def execute(
self,
db_name,
sql,
params=None,
truncate=False,
custom_time_limit=None,
page_size=None,
):
"""Executes sql against db_name in a thread"""
page_size = page_size or self.page_size
def sql_operation_in_thread(conn):
time_limit_ms = self.sql_time_limit_ms time_limit_ms = self.sql_time_limit_ms
if custom_time_limit and custom_time_limit < time_limit_ms: if custom_time_limit and custom_time_limit < time_limit_ms:
time_limit_ms = custom_time_limit time_limit_ms = custom_time_limit
@ -476,8 +488,8 @@ class Datasette:
else: else:
return Results(rows, False, cursor.description) return Results(rows, False, cursor.description)
return await asyncio.get_event_loop().run_in_executor( return await self.execute_against_connection_in_thread(
self.executor, sql_operation_in_thread db_name, sql_operation_in_thread
) )
def app(self): def app(self):

View file

@ -801,10 +801,11 @@ def get_plugins(pm):
FORMATS = ('csv', 'json', 'jsono') FORMATS = ('csv', 'json', 'jsono')
def resolve_table_and_format(table_and_format, table_exists): async def resolve_table_and_format(table_and_format, table_exists):
if '.' in table_and_format: if '.' in table_and_format:
# Check if a table exists with this exact name # Check if a table exists with this exact name
if table_exists(table_and_format): it_exists = await table_exists(table_and_format)
if it_exists:
return table_and_format, None return table_and_format, None
# Check if table ends with a known format # Check if table ends with a known format
for _format in FORMATS: for _format in FORMATS:

View file

@ -155,7 +155,7 @@ class BaseView(RenderMixin):
r.headers["Access-Control-Allow-Origin"] = "*" r.headers["Access-Control-Allow-Origin"] = "*"
return r return r
def resolve_db_name(self, request, db_name, **kwargs): async def resolve_db_name(self, request, db_name, **kwargs):
databases = self.ds.inspect() databases = self.ds.inspect()
hash = None hash = None
name = None name = None
@ -180,11 +180,13 @@ class BaseView(RenderMixin):
if not correct_hash_provided: if not correct_hash_provided:
if "table_and_format" in kwargs: if "table_and_format" in kwargs:
table, _format = resolve_table_and_format( async def async_table_exists(t):
return await self.ds.table_exists(name, t)
table, _format = await resolve_table_and_format(
table_and_format=urllib.parse.unquote_plus( table_and_format=urllib.parse.unquote_plus(
kwargs["table_and_format"] kwargs["table_and_format"]
), ),
table_exists=lambda t: self.ds.table_exists(name, t) table_exists=async_table_exists
) )
kwargs["table"] = table kwargs["table"] = table
if _format: if _format:
@ -221,7 +223,7 @@ class BaseView(RenderMixin):
assert NotImplemented assert NotImplemented
async def get(self, request, db_name, **kwargs): async def get(self, request, db_name, **kwargs):
database, hash, correct_hash_provided, should_redirect = self.resolve_db_name( database, hash, correct_hash_provided, should_redirect = await self.resolve_db_name(
request, db_name, **kwargs request, db_name, **kwargs
) )
if should_redirect: if should_redirect:
@ -328,11 +330,13 @@ class BaseView(RenderMixin):
if not _format: if not _format:
_format = (kwargs.pop("as_format", None) or "").lstrip(".") _format = (kwargs.pop("as_format", None) or "").lstrip(".")
if "table_and_format" in kwargs: if "table_and_format" in kwargs:
table, _ext_format = resolve_table_and_format( async def async_table_exists(t):
return await self.ds.table_exists(database, t)
table, _ext_format = await resolve_table_and_format(
table_and_format=urllib.parse.unquote_plus( table_and_format=urllib.parse.unquote_plus(
kwargs["table_and_format"] kwargs["table_and_format"]
), ),
table_exists=lambda t: self.ds.table_exists(database, t) table_exists=async_table_exists
) )
_format = _format or _ext_format _format = _format or _ext_format
kwargs["table"] = table kwargs["table"] = table

View file

@ -50,6 +50,7 @@ setup(
extras_require={ extras_require={
'test': [ 'test': [
'pytest==4.0.2', 'pytest==4.0.2',
'pytest-asyncio==0.10.0',
'aiohttp==3.5.3', 'aiohttp==3.5.3',
'beautifulsoup4==4.6.1', 'beautifulsoup4==4.6.1',
] ]

View file

@ -332,10 +332,11 @@ def test_compound_keys_after_sql():
'''.strip() == utils.compound_keys_after_sql(['a', 'b', 'c']) '''.strip() == utils.compound_keys_after_sql(['a', 'b', 'c'])
def table_exists(table): async def table_exists(table):
return table == "exists.csv" return table == "exists.csv"
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"table_and_format,expected_table,expected_format", "table_and_format,expected_table,expected_format",
[ [
@ -346,10 +347,10 @@ def table_exists(table):
("exists.csv", "exists.csv", None), ("exists.csv", "exists.csv", None),
], ],
) )
def test_resolve_table_and_format( async def test_resolve_table_and_format(
table_and_format, expected_table, expected_format table_and_format, expected_table, expected_format
): ):
actual_table, actual_format = utils.resolve_table_and_format( actual_table, actual_format = await utils.resolve_table_and_format(
table_and_format, table_exists table_and_format, table_exists
) )
assert expected_table == actual_table assert expected_table == actual_table