From 0209a0a344503157351e625f0629b686961763c9 Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Sun, 31 Mar 2019 11:02:22 -0700
Subject: [PATCH] table_exists() now uses async SQL, refs #420

---
 datasette/app.py        | 46 ++++++++++++++++++++++++++---------------
 datasette/utils.py      |  5 +++--
 datasette/views/base.py | 16 ++++++++------
 setup.py                |  1 +
 tests/test_utils.py     |  7 ++++---
 5 files changed, 47 insertions(+), 28 deletions(-)

diff --git a/datasette/app.py b/datasette/app.py
index fcffb36a..d5994e5e 100644
--- a/datasette/app.py
+++ b/datasette/app.py
@@ -300,8 +300,13 @@ class Datasette:
             conn.execute('PRAGMA cache_size=-{}'.format(self.config("cache_size_kb")))
         pm.hook.prepare_connection(conn=conn)
 
-    def table_exists(self, database, table):
-        return table in self.inspect().get(database, {}).get("tables")
+    async def table_exists(self, database, table):
+        results = await self.execute(
+            database,
+            "select 1 from sqlite_master where type='table' and name=?",
+            params=(table,)
+        )
+        return bool(results.rows)
 
     def inspect(self):
         " Inspect the database and return a dictionary of table metadata "
@@ -410,19 +415,8 @@ class Datasette:
             for p in ps
         ]
 
-    async def execute(
-        self,
-        db_name,
-        sql,
-        params=None,
-        truncate=False,
-        custom_time_limit=None,
-        page_size=None,
-    ):
-        """Executes sql against db_name in a thread"""
-        page_size = page_size or self.page_size
-
-        def sql_operation_in_thread():
+    async def execute_against_connection_in_thread(self, db_name, fn):
+        def in_thread():
             conn = getattr(connections, db_name, None)
             if not conn:
                 info = self.inspect()[db_name]
@@ -441,7 +435,25 @@ class Datasette:
                 )
                 self.prepare_connection(conn)
                 setattr(connections, db_name, conn)
+            return fn(conn)
 
+        return await asyncio.get_event_loop().run_in_executor(
+            self.executor, in_thread
+        )
+
+    async def execute(
+        self,
+        db_name,
+        sql,
+        params=None,
+        truncate=False,
+        custom_time_limit=None,
+        page_size=None,
+    ):
+        """Executes sql against db_name in a thread"""
+        page_size = page_size or self.page_size
+
+        def sql_operation_in_thread(conn):
             time_limit_ms = self.sql_time_limit_ms
             if custom_time_limit and custom_time_limit < time_limit_ms:
                 time_limit_ms = custom_time_limit
@@ -476,8 +488,8 @@ class Datasette:
             else:
                 return Results(rows, False, cursor.description)
 
-        return await asyncio.get_event_loop().run_in_executor(
-            self.executor, sql_operation_in_thread
+        return await self.execute_against_connection_in_thread(
+            db_name, sql_operation_in_thread
         )
 
     def app(self):
diff --git a/datasette/utils.py b/datasette/utils.py
index 8bcaefc2..3a7e90c4 100644
--- a/datasette/utils.py
+++ b/datasette/utils.py
@@ -801,10 +801,11 @@ def get_plugins(pm):
 FORMATS = ('csv', 'json', 'jsono')
 
 
-def resolve_table_and_format(table_and_format, table_exists):
+async def resolve_table_and_format(table_and_format, table_exists):
     if '.' in table_and_format:
         # Check if a table exists with this exact name
-        if table_exists(table_and_format):
+        it_exists = await table_exists(table_and_format)
+        if it_exists:
             return table_and_format, None
         # Check if table ends with a known format
         for _format in FORMATS:
diff --git a/datasette/views/base.py b/datasette/views/base.py
index 33b7524e..e79a46ec 100644
--- a/datasette/views/base.py
+++ b/datasette/views/base.py
@@ -155,7 +155,7 @@ class BaseView(RenderMixin):
             r.headers["Access-Control-Allow-Origin"] = "*"
         return r
 
-    def resolve_db_name(self, request, db_name, **kwargs):
+    async def resolve_db_name(self, request, db_name, **kwargs):
         databases = self.ds.inspect()
         hash = None
         name = None
@@ -180,11 +180,13 @@ class BaseView(RenderMixin):
 
         if not correct_hash_provided:
             if "table_and_format" in kwargs:
-                table, _format = resolve_table_and_format(
+                async def async_table_exists(t):
+                    return await self.ds.table_exists(name, t)
+                table, _format = await resolve_table_and_format(
                     table_and_format=urllib.parse.unquote_plus(
                         kwargs["table_and_format"]
                     ),
-                    table_exists=lambda t: self.ds.table_exists(name, t)
+                    table_exists=async_table_exists
                 )
                 kwargs["table"] = table
                 if _format:
@@ -221,7 +223,7 @@ class BaseView(RenderMixin):
         assert NotImplemented
 
     async def get(self, request, db_name, **kwargs):
-        database, hash, correct_hash_provided, should_redirect = self.resolve_db_name(
+        database, hash, correct_hash_provided, should_redirect = await self.resolve_db_name(
             request, db_name, **kwargs
         )
         if should_redirect:
@@ -328,11 +330,13 @@ class BaseView(RenderMixin):
         if not _format:
             _format = (kwargs.pop("as_format", None) or "").lstrip(".")
         if "table_and_format" in kwargs:
-            table, _ext_format = resolve_table_and_format(
+            async def async_table_exists(t):
+                return await self.ds.table_exists(database, t)
+            table, _ext_format = await resolve_table_and_format(
                 table_and_format=urllib.parse.unquote_plus(
                     kwargs["table_and_format"]
                 ),
-                table_exists=lambda t: self.ds.table_exists(database, t)
+                table_exists=async_table_exists
             )
             _format = _format or _ext_format
             kwargs["table"] = table
diff --git a/setup.py b/setup.py
index fb00f2d0..524f5243 100644
--- a/setup.py
+++ b/setup.py
@@ -50,6 +50,7 @@ setup(
     extras_require={
         'test': [
             'pytest==4.0.2',
+            'pytest-asyncio==0.10.0',
             'aiohttp==3.5.3',
             'beautifulsoup4==4.6.1',
         ]
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 1f0079c9..fad1ac84 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -332,10 +332,11 @@ def test_compound_keys_after_sql():
     '''.strip() == utils.compound_keys_after_sql(['a', 'b', 'c'])
 
 
-def table_exists(table):
+async def table_exists(table):
     return table == "exists.csv"
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     "table_and_format,expected_table,expected_format",
     [
@@ -346,10 +347,10 @@ def table_exists(table):
         ("exists.csv", "exists.csv", None),
     ],
 )
-def test_resolve_table_and_format(
+async def test_resolve_table_and_format(
     table_and_format, expected_table, expected_format
 ):
-    actual_table, actual_format = utils.resolve_table_and_format(
+    actual_table, actual_format = await utils.resolve_table_and_format(
         table_and_format, table_exists
     )
     assert expected_table == actual_table