Skip to content

Commit

Permalink
default models, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
asg017 committed Aug 24, 2024
1 parent 02821ba commit 1d7de47
Show file tree
Hide file tree
Showing 4 changed files with 263 additions and 13 deletions.
12 changes: 10 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,16 @@ sqlite-lembed.h: sqlite-lembed.h.tmpl VERSION
SOURCE=$(shell git log -n 1 --pretty=format:%H -- VERSION) \
envsubst < $< > $@

test-loadable:
echo 4
MODELS_DIR=$(prefix)/.models

$(MODELS_DIR): $(BUILD_DIR)
mkdir -p $@

$(MODELS_DIR)/all-MiniLM-L6-v2.e4ce9877.q8_0.gguf: $(MODELS_DIR)
curl -L -o $@ https://huggingface.co/asg017/sqlite-lembed-model-examples/resolve/main/all-MiniLM-L6-v2/all-MiniLM-L6-v2.e4ce9877.q8_0.gguf

test-loadable: $(TARGET_LOADABLE) $(MODELS_DIR)/all-MiniLM-L6-v2.e4ce9877.q8_0.gguf
python -m pytest tests/test-loadable.py


FORMAT_FILES=sqlite-lembed.c
Expand Down
34 changes: 24 additions & 10 deletions sqlite-lembed.c
Original file line number Diff line number Diff line change
Expand Up @@ -274,15 +274,32 @@ int api_model_from_name(struct Api *api, const char *name, int name_length,
static void lembed(sqlite3_context *context, int argc, sqlite3_value **argv) {
struct llama_model *model;
struct llama_context *ctx;
int rc = api_model_from_name((struct Api *)sqlite3_user_data(context),
int rc;
const char * input;
sqlite3_int64 input_len;
if(argc == 1) {
input = (const char *)sqlite3_value_text(argv[0]);
input_len = sqlite3_value_bytes(argv[0]);
rc = api_model_from_name((struct Api *)sqlite3_user_data(context), "default", strlen("default"), &model, &ctx);
if(rc != SQLITE_OK) {
sqlite3_result_error(context, "No default model has been registered yet with lembed_models", -1);
return;
}
}else {
input = (const char *)sqlite3_value_text(argv[1]);
input_len = sqlite3_value_bytes(argv[1]);
rc = api_model_from_name((struct Api *)sqlite3_user_data(context),
(const char *)sqlite3_value_text(argv[0]),
sqlite3_value_bytes(argv[0]), &model, &ctx);
if(rc != SQLITE_OK) {
sqlite3_result_error(context, "Unknown model name. Was it registered with lembed_models?", -1);
return;

if(rc != SQLITE_OK) {
char * zSql = sqlite3_mprintf("Unknown model name '%s'. Was it registered with lembed_models?", sqlite3_value_text(argv[0]));
sqlite3_result_error(context, zSql, -1);
sqlite3_free(zSql);
return;
}
}
const char *input = (const char *)sqlite3_value_text(argv[1]);
sqlite3_int64 input_len = sqlite3_value_bytes(argv[1]);

int dimensions;
float *embedding;
rc = embed_single(model, ctx, input, input_len, &embedding, &dimensions);
Expand Down Expand Up @@ -483,10 +500,6 @@ static int lembed_modelsUpdate(sqlite3_vtab *pVTab, int argc,
}
p->api->models[idx].model = model;
p->api->models[idx].context = ctx;

if (strcmp(key, "default") == 0) {
printf("default detected\n");
}
return SQLITE_OK;
}
// UPDATE operation
Expand Down Expand Up @@ -880,6 +893,7 @@ __declspec(dllexport)
int nArg;
} aFuncApi[] = {
// clang-format off
{"lembed", lembed, 1},
{"lembed", lembed, 2},
{"lembed_tokenize_json", lembed_tokenize_json, 2},
{"lembed_token_score", lembed_token_score, 2},
Expand Down
2 changes: 1 addition & 1 deletion test.sql
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ with models as (
)
select
name,
lembed_model_from_file(model_path),`
lembed_model_from_file(model_path),
lembed_model_options(
'n_gpu_layers', 99
),
Expand Down
228 changes: 228 additions & 0 deletions tests/test-loadable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
# ruff: noqa: E731
import struct
import re
import pytest
import sqlite3
import inspect
from contextlib import contextmanager

# Path of the loadable extension handed to sqlite3's load_extension() by connect().
EXT_PATH = "./dist/lembed0"
# GGUF model file that test_lembed registers through lembed_model_from_file().
# NOTE(review): presumably downloaded by the Makefile's test-loadable
# prerequisites -- confirm that $(prefix) resolves to ./dist.
MODEL1_PATH = "./dist/.models/all-MiniLM-L6-v2.e4ce9877.q8_0.gguf"


def connect(ext, path=":memory:", extra_entrypoint=None):
    """Open a SQLite connection at ``path`` with the extension ``ext`` loaded.

    Before loading, the names of all functions and modules known to SQLite
    are snapshotted into temp tables (``base_functions`` / ``base_modules``).
    After loading, the names that were not in the snapshot are stored, ordered
    by name, in ``loaded_functions`` / ``loaded_modules``.

    Args:
        ext: Path of the loadable extension.
        path: Database path; defaults to an in-memory database.
        extra_entrypoint: Optional entrypoint name. When truthy, the same
            extension file is loaded a second time through the SQL
            ``load_extension(?, ?)`` function with this entrypoint.

    Returns:
        The open connection, with ``row_factory`` set to ``sqlite3.Row``.
    """
    conn = sqlite3.connect(path)

    # Snapshot what SQLite offers before the extension is loaded.
    before_load = (
        "create temp table base_functions as select name from pragma_function_list",
        "create temp table base_modules as select name from pragma_module_list",
    )
    for statement in before_load:
        conn.execute(statement)

    conn.enable_load_extension(True)
    conn.load_extension(ext)

    if extra_entrypoint:
        conn.execute("select load_extension(?, ?)", [ext, extra_entrypoint])

    # Anything not present in the snapshot was contributed by the extension.
    after_load = (
        "create temp table loaded_functions as select name from pragma_function_list where name not in (select name from base_functions) order by name",
        "create temp table loaded_modules as select name from pragma_module_list where name not in (select name from base_modules) order by name",
    )
    for statement in after_load:
        conn.execute(statement)

    conn.row_factory = sqlite3.Row
    return conn


# Module-level connection shared by the helpers and tests in this file; the
# extension is loaded once, when the module is imported.
db = connect(EXT_PATH)


def explain_query_plan(sql):
    """Return the ``detail`` column of the first ``explain query plan`` row for ``sql``."""
    first_row = db.execute("explain query plan " + sql).fetchone()
    return first_row["detail"]


def execute_all(cursor, sql, args=None):
    """Run ``sql`` on ``cursor`` and return every row as a plain ``dict``.

    Args:
        cursor: Object exposing ``execute(sql, params)`` whose rows can be
            passed to ``dict()`` (e.g. rows produced with ``sqlite3.Row``).
        sql: Statement to execute.
        args: Bind parameters; ``None`` means no parameters.

    Returns:
        A list with one dict per fetched row.
    """
    params = [] if args is None else args
    return [dict(row) for row in cursor.execute(sql, params).fetchall()]


def spread_args(args):
    """Return one ``?`` placeholder per element of ``args``, comma-separated."""
    # Joining the characters of "???" yields "?,?,?".
    placeholders = "?" * len(args)
    return ",".join(placeholders)


@contextmanager
def _raises(message, error=sqlite3.OperationalError):
    """Context manager asserting that the body raises ``error`` matching ``message``.

    ``message`` is treated as literal text: it is regex-escaped before being
    handed to ``pytest.raises(match=...)``.
    """
    literal_pattern = re.escape(message)
    with pytest.raises(error, match=literal_pattern):
        yield


# Names expected in the `loaded_functions` temp table, in the same
# name-sorted order that table is built with; test_funcs compares against
# this list exactly.
# NOTE(review): "lembed" appears twice, presumably one row per registered
# argument count (1-arg and 2-arg forms) -- confirm against the C source.
FUNCTIONS = [
"_lembed_api",
"lembed",
"lembed",
"lembed_context_options",
"lembed_debug",
"lembed_model_from_file",
"lembed_model_options",
"lembed_model_size",
"lembed_token_score",
"lembed_token_to_piece",
"lembed_tokenize_json",
"lembed_version",
]
# Names expected in the `loaded_modules` temp table, in name-sorted order;
# test_modules compares against this list exactly.
MODULES = [
"lembed_chunks",
"lembed_models",
]


def test_funcs():
    """The functions added by the extension match FUNCTIONS exactly, in order."""
    rows = db.execute("select name from loaded_functions").fetchall()
    funcs = [row[0] for row in rows]
    assert funcs == FUNCTIONS


def test_modules():
    """The modules added by the extension match MODULES exactly, in order."""
    rows = db.execute("select name from loaded_modules").fetchall()
    modules = [row[0] for row in rows]
    assert modules == MODULES


def test_lembed_version():
    """``lembed_version()`` returns a string whose first character is ``v``."""

    def lembed_version(*args):
        row = db.execute("select lembed_version()", args).fetchone()
        return row[0]

    version = lembed_version()
    assert version[0] == "v"


def test_lembed_debug():
    """``lembed_debug()`` output splits into exactly four newline-separated parts."""

    def lembed_debug(*args):
        row = db.execute("select lembed_debug()", args).fetchone()
        return row[0]

    debug_lines = lembed_debug().split("\n")
    assert len(debug_lines) == 4


def test_lembed():
    """Exercise ``lembed()`` with a named model and with the ``default`` model.

    Covers, in order:
      * embedding with an explicitly named, registered model;
      * the error raised for an unregistered model name;
      * the error raised by the 1-argument form before a ``default`` model
        has been registered;
      * the 1-argument form once ``default`` is registered.

    Note: this test registers models on the shared module-level ``db``
    connection, so the rows it inserts persist for any test that runs later.
    """
    register_sql = "insert into temp.lembed_models(name, model) values (?, lembed_model_from_file(?))"

    # The embedding comes out of floating-point model inference, so its exact
    # bits can differ between CPUs/backends. Compare with a tolerance instead
    # of exact equality. NOTE(review): tighten or loosen the tolerance if the
    # supported platforms warrant it.
    expected_first_component = pytest.approx(-0.09205757826566696, abs=1e-4)

    def lembed(*args):
        return db.execute(
            "select lembed({})".format(spread_args(args)), args
        ).fetchone()[0]

    def assert_expected_embedding(blob):
        # 384 components, 4 bytes each.
        assert len(blob) == (384 * 4)
        assert struct.unpack("1f", blob[0:4])[0] == expected_first_component

    db.execute(register_sql, ["aaa", MODEL1_PATH])
    assert_expected_embedding(lembed("aaa", "alex garcia"))

    with _raises(
        "Unknown model name 'aaaaaaaaa'. Was it registered with lembed_models?"
    ):
        lembed("aaaaaaaaa", "alex garcia")

    with _raises("No default model has been registered yet with lembed_models"):
        lembed("alex garcia")

    db.execute(register_sql, ["default", MODEL1_PATH])
    assert_expected_embedding(lembed("alex garcia"))


@pytest.mark.skip(reason="TODO")
def test__lembed_api():
    """Placeholder for ``_lembed_api()``; skipped until a real test is written."""

    def _lembed_api(*args):
        row = db.execute("select _lembed_api()", args).fetchone()
        return row[0]


@pytest.mark.skip(reason="TODO")
def test_lembed_context_options():
    """Placeholder for ``lembed_context_options()``; skipped until a real test is written."""

    def lembed_context_options(*args):
        row = db.execute("select lembed_context_options()", args).fetchone()
        return row[0]


@pytest.mark.skip(reason="TODO")
def test_lembed_model_size():
    """Placeholder for ``lembed_model_size()``; skipped until a real test is written."""

    def lembed_model_size(*args):
        row = db.execute("select lembed_model_size()", args).fetchone()
        return row[0]


@pytest.mark.skip(reason="TODO")
def test_lembed_model_from_file():
    """Placeholder for ``lembed_model_from_file()``; skipped until a real test is written."""

    def lembed_model_from_file(*args):
        row = db.execute("select lembed_model_from_file()", args).fetchone()
        return row[0]


@pytest.mark.skip(reason="TODO")
def test_lembed_model_options():
    """Placeholder for ``lembed_model_options()``; skipped until a real test is written."""

    def lembed_model_options(*args):
        row = db.execute("select lembed_model_options()", args).fetchone()
        return row[0]


@pytest.mark.skip(reason="TODO")
def test_lembed_tokenize_json():
    """Placeholder for ``lembed_tokenize_json()``; skipped until a real test is written."""

    def lembed_tokenize_json(*args):
        row = db.execute("select lembed_tokenize_json()", args).fetchone()
        return row[0]


@pytest.mark.skip(reason="TODO")
def test_lembed_token_score():
    """Placeholder for ``lembed_token_score()``; skipped until a real test is written."""

    def lembed_token_score(*args):
        row = db.execute("select lembed_token_score()", args).fetchone()
        return row[0]


@pytest.mark.skip(reason="TODO")
def test_lembed_token_to_piece():
    """Placeholder for ``lembed_token_to_piece()``; skipped until a real test is written."""

    def lembed_token_to_piece(*args):
        row = db.execute("select lembed_token_to_piece()", args).fetchone()
        return row[0]


@pytest.mark.skip(reason="TODO")
def test_lembed_chunks():
    """Placeholder for the ``lembed_chunks`` module; skipped until a real test is written."""

    def lembed_chunks(*args):
        row = db.execute("select * from lembed_chunks()", args).fetchone()
        return row[0]


@pytest.mark.skip(reason="TODO")
def test_lembed_models():
    """Placeholder for the ``lembed_models`` module; skipped until a real test is written."""

    # The helper previously queried lembed_chunks() (copy-paste slip); it now
    # targets the lembed_models table, addressed the same way test_lembed
    # inserts into it.
    def lembed_models(*args):
        row = db.execute("select * from temp.lembed_models", args).fetchone()
        return row[0]


def test_coverage():
    """Fail if any name in FUNCTIONS or MODULES has no ``test_<name>`` in this module."""
    prefix = "test_"
    current_module = inspect.getmodule(inspect.currentframe())

    # Strip only the leading "test_". str.replace() would also remove later
    # occurrences and mis-derive names such as "test_foo_test_bar".
    funcs_with_tests = {
        member_name[len(prefix):]
        for member_name, _ in inspect.getmembers(current_module)
        if member_name.startswith(prefix)
    }

    for func in [*FUNCTIONS, *MODULES]:
        assert func in funcs_with_tests, f"{func} is not tested"

0 comments on commit 1d7de47

Please sign in to comment.