From 9bd6c31ba933800c8124e5f77f6ebde02854292e Mon Sep 17 00:00:00 2001 From: Florents Tselai Date: Sun, 23 Jun 2024 19:24:32 +0300 Subject: [PATCH] Use llm-markov for testing --- setup.py | 10 ++- tests/conftest.py | 153 ++++++++++++++++++++++++++++++++++++++----- tests/test_tsellm.py | 36 ++++------ tsellm/__init__.py | 16 ----- tsellm/cli.py | 3 +- tsellm/core.py | 16 +++++ 6 files changed, 171 insertions(+), 63 deletions(-) create mode 100644 tsellm/core.py diff --git a/setup.py b/setup.py index b95a8ac..d1bbdc1 100644 --- a/setup.py +++ b/setup.py @@ -6,15 +6,15 @@ def get_long_description(): with open( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md"), - encoding="utf8", + os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md"), + encoding="utf8", ) as fp: return fp.read() setup( name="tsellm", - description="LLM support in SQLite", + description="Interactive SQLite shell with LLM support", long_description=get_long_description(), long_description_content_type="text/markdown", author="Florents Tselai", @@ -32,8 +32,6 @@ def get_long_description(): version=VERSION, packages=["tsellm"], install_requires=["llm", "setuptools", "pip"], - extras_require={ - "test": ["pytest", "pytest-cov", "black", "ruff", "sqlite_utils"] - }, + extras_require={"test": ["pytest", "pytest-cov", "black", "ruff", "sqlite_utils", "llm-markov"]}, python_requires=">=3.7", ) diff --git a/tests/conftest.py b/tests/conftest.py index ffc8573..dc17987 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,14 @@ from sqlite_utils import Database -from sqlite_utils.utils import sqlite3 import pytest +import pytest +import json +import llm +from llm.plugins import pm +from typing import Optional +import sqlite_utils +from pydantic import Field + def pytest_configure(config): import sys @@ -10,26 +17,142 @@ def pytest_configure(config): @pytest.fixture -def fresh_db(): - return Database(memory=True) +def user_path(tmpdir): + dir = tmpdir / "tsellm" + dir.mkdir() + return dir @pytest.fixture -def existing_db(db_path): - database = Database(db_path) - database.executescript( - """ - CREATE TABLE foo (text TEXT); - INSERT INTO foo (text) values ("one"); - INSERT INTO foo (text) values ("two"); - INSERT INTO foo (text) values ("three"); - """ - ) - return database +def logs_db(user_path): + return sqlite_utils.Database(str(user_path / "logs.db")) + + +@pytest.fixture +def user_path_with_embeddings(user_path): + path = str(user_path / "embeddings.db") + db = sqlite_utils.Database(path) + collection = llm.Collection("demo", db, model_id="embed-demo") + collection.embed("1", "hello world") + collection.embed("2", "goodbye world") + + +class MockModel(llm.Model): + model_id = "mock-echo" + + class Options(llm.Options): + max_tokens: Optional[int] = Field( + description="Maximum number of tokens to generate.", default=None + ) + + def __init__(self): + self.history = [] + self._queue = [] + + def enqueue(self, messages): + assert isinstance(messages, list) + self._queue.append(messages) + + def execute(self, prompt, stream, response, conversation): + self.history.append((prompt, stream, response, conversation)) + while True: + try: + messages = self._queue.pop(0) + yield from messages + break + except IndexError: + break + + +class EmbedDemo(llm.EmbeddingModel): + model_id = "embed-demo" + batch_size = 10 + supports_binary = True + + def __init__(self): + self.embedded_content = [] + + def embed_batch(self, texts): + if not hasattr(self, "batch_count"): + self.batch_count = 0 + self.batch_count += 1 + for text in texts: + self.embedded_content.append(text) + words = text.split()[:16] + embedding = [len(word) for word in words] + # Pad with 0 up to 16 words + embedding += [0] * (16 - len(embedding)) + yield embedding + + +class EmbedBinaryOnly(EmbedDemo): + model_id = "embed-binary-only" + supports_text = False + supports_binary = True + + +class EmbedTextOnly(EmbedDemo): + model_id = "embed-text-only" + supports_text = True + supports_binary = False + + +@pytest.fixture +def embed_demo(): + return EmbedDemo() + + +@pytest.fixture +def mock_model(): + return MockModel() + + +@pytest.fixture(autouse=True) +def register_embed_demo_model(embed_demo, mock_model): + class MockModelsPlugin: + __name__ = "MockModelsPlugin" + + @llm.hookimpl + def register_embedding_models(self, register): + register(embed_demo) + register(EmbedBinaryOnly()) + register(EmbedTextOnly()) + + @llm.hookimpl + def register_models(self, register): + register(mock_model) + + pm.register(MockModelsPlugin(), name="undo-mock-models-plugin") + try: + yield + finally: + pm.unregister(name="undo-mock-models-plugin") @pytest.fixture def db_path(tmpdir): path = str(tmpdir / "test.db") - db = sqlite3.connect(path) return path + + +@pytest.fixture +def fresh_db_path(db_path): + return db_path + +@pytest.fixture +def existing_db_path(fresh_db_path): + db = Database(fresh_db_path) + table = db.create_table( + "prompts", + { + "prompt": str, + "generated": str, + "model": str, + "embedding": dict, + }, + ) + + table.insert({"prompt": "hello world!"}) + table.insert({"prompt": "how are you?"}) + table.insert({"prompt": "is this real life?"}) + return fresh_db_path diff --git a/tests/test_tsellm.py b/tests/test_tsellm.py index e2f27da..5f87e7d 100644 --- a/tests/test_tsellm.py +++ b/tests/test_tsellm.py @@ -2,32 +2,20 @@ from tsellm.cli import cli -def test_tsellm_cli(db_path): - db = Database(db_path) - assert [] == db.table_names() - table = db.create_table( - "prompts", - { - "prompt": str, - "generated": str, - "model": str, - "embedding": dict, - }, - ) - - assert ["prompts"] == db.table_names() - - table.insert({"prompt": "hello"}) - table.insert({"prompt": "world"}) +def test_cli_prompt_mock(existing_db_path): + db = Database(existing_db_path) assert db.execute("select prompt from prompts").fetchall() == [ - ("hello",), - ("world",), + ("hello world!",), + ("how are you?",), + ("is this real life?",), ] - cli([db_path, "UPDATE prompts SET generated=prompt(prompt)"]) + cli([existing_db_path, "UPDATE prompts SET generated=prompt(prompt, 'markov')"]) - assert db.execute("select prompt, generated from prompts").fetchall() == [ - ("hello", "hellohello"), - ("world", "worldworld"), - ] + for prompt, generated in db.execute("select prompt, generated from prompts").fetchall(): + words = generated.strip().split() + # Every word should be one of the original prompt (see https://github.com/simonw/llm-markov/blob/657ca504bcf9f0bfc1c6ee5fe838cde9a8976381/tests/test_llm_markov.py#L20) + prompt_words = prompt.split() + for word in words: + assert word in prompt_words diff --git a/tsellm/__init__.py b/tsellm/__init__.py index 3da56b1..e69de29 100644 --- a/tsellm/__init__.py +++ b/tsellm/__init__.py @@ -1,16 +0,0 @@ -def _prompt(p): - return p * 2 - - -TSELLM_CONFIG_SQL = """ -CREATE TABLE IF NOT EXISTS __tsellm ( -data -); - -""" - - -def _tsellm_init(con): - """Entry-point for tsellm initialization.""" - con.execute(TSELLM_CONFIG_SQL) - con.create_function("prompt", 1, _prompt) diff --git a/tsellm/cli.py b/tsellm/cli.py index 2926826..055d7ba 100644 --- a/tsellm/cli.py +++ b/tsellm/cli.py @@ -4,7 +4,7 @@ from argparse import ArgumentParser from code import InteractiveConsole from textwrap import dedent -from . import _prompt, _tsellm_init +from .core import _tsellm_init def execute(c, sql, suppress_errors=True): @@ -58,7 +58,6 @@ def runsource(self, source, filename="", symbol="single"): def cli(*args): - print(args) parser = ArgumentParser( description="tsellm sqlite3 CLI", prog="python -m tsellm", diff --git a/tsellm/core.py b/tsellm/core.py new file mode 100644 index 0000000..acfc0a5 --- /dev/null +++ b/tsellm/core.py @@ -0,0 +1,16 @@ +import llm + +TSELLM_CONFIG_SQL = """ +CREATE TABLE IF NOT EXISTS __tsellm ( +x text +); + +""" + +def _prompt_model(prompt, model): + return llm.get_model(model).prompt(prompt).text() + +def _tsellm_init(con): + """Entry-point for tsellm initialization.""" + con.execute(TSELLM_CONFIG_SQL) + con.create_function("prompt", 2, _prompt_model)