Add new entrypoint option to --load-extensions. (#1789)

Thanks, @asg017
This commit is contained in:
Alex Garcia 2022-08-23 11:34:30 -07:00 committed by GitHub
commit 1d64c9a8da
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 140 additions and 2 deletions

48
tests/ext.c Normal file
View file

@ -0,0 +1,48 @@
/*
** This file implements a SQLite extension with multiple entrypoints.
**
** The default entrypoint, sqlite3_ext_init, has a single function "a".
** The 1st alternate entrypoint, sqlite3_ext_b_init, has a single function "b".
** The 2nd alternate entrypoint, sqlite3_ext_c_init, has a single function "c".
**
** Compiling instructions:
** https://www.sqlite.org/loadext.html#compiling_a_loadable_extension
**
*/
#include "sqlite3ext.h"
SQLITE_EXTENSION_INIT1
// SQL function that returns back the value supplied during sqlite3_create_function()
static void func(sqlite3_context *context, int argc, sqlite3_value **argv) {
sqlite3_result_text(context, (char *) sqlite3_user_data(context), -1, SQLITE_STATIC);
}
// The default entrypoint, since it matches the "ext.dylib"/"ext.so" name
#ifdef _WIN32
__declspec(dllexport)
#endif
int sqlite3_ext_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi) {
SQLITE_EXTENSION_INIT2(pApi);
return sqlite3_create_function(db, "a", 0, 0, "a", func, 0, 0);
}
// Alternate entrypoint #1
#ifdef _WIN32
__declspec(dllexport)
#endif
int sqlite3_ext_b_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi) {
SQLITE_EXTENSION_INIT2(pApi);
return sqlite3_create_function(db, "b", 0, 0, "b", func, 0, 0);
}
// Alternate entrypoint #2
#ifdef _WIN32
__declspec(dllexport)
#endif
int sqlite3_ext_c_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi) {
SQLITE_EXTENSION_INIT2(pApi);
return sqlite3_create_function(db, "c", 0, 0, "c", func, 0, 0);
}

View file

@ -0,0 +1,65 @@
from datasette.app import Datasette
import pytest
from pathlib import Path
# not necessarily a full path - the full compiled path looks like "ext.dylib"
# or another suffix, but sqlite will, under the hood, decide which file
# extension to use based on the operating system (apple=dylib, windows=dll etc)
# this resolves to "./ext", which is enough for SQLite to calculate the rest
COMPILED_EXTENSION_PATH = str(Path(__file__).parent / "ext")
# See if ext.c has been compiled, based off the different possible suffixes.
def has_compiled_ext():
for ext in ["dylib", "so", "dll"]:
path = Path(__file__).parent / f"ext.{ext}"
if path.is_file():
return True
return False
@pytest.mark.asyncio
@pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c")
async def test_load_extension_default_entrypoint():
# The default entrypoint only loads a() and NOT b() or c(), so those
# should fail.
ds = Datasette(sqlite_extensions=[COMPILED_EXTENSION_PATH])
response = await ds.client.get("/_memory.json?sql=select+a()")
assert response.status_code == 200
assert response.json()["rows"][0][0] == "a"
response = await ds.client.get("/_memory.json?sql=select+b()")
assert response.status_code == 400
assert response.json()["error"] == "no such function: b"
response = await ds.client.get("/_memory.json?sql=select+c()")
assert response.status_code == 400
assert response.json()["error"] == "no such function: c"
@pytest.mark.asyncio
@pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c")
async def test_load_extension_multiple_entrypoints():
# Load in the default entrypoint and the other 2 custom entrypoints, now
# all a(), b(), and c() should run successfully.
ds = Datasette(
sqlite_extensions=[
COMPILED_EXTENSION_PATH,
(COMPILED_EXTENSION_PATH, "sqlite3_ext_b_init"),
(COMPILED_EXTENSION_PATH, "sqlite3_ext_c_init"),
]
)
response = await ds.client.get("/_memory.json?sql=select+a()")
assert response.status_code == 200
assert response.json()["rows"][0][0] == "a"
response = await ds.client.get("/_memory.json?sql=select+b()")
assert response.status_code == 200
assert response.json()["rows"][0][0] == "b"
response = await ds.client.get("/_memory.json?sql=select+c()")
assert response.status_code == 200
assert response.json()["rows"][0][0] == "c"