Skip to content

Commit

Permalink
Use llm-markov for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
Florents-Tselai committed Jun 23, 2024
1 parent 30bf219 commit 9bd6c31
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 63 deletions.
10 changes: 4 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@

def get_long_description():
with open(
os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md"),
encoding="utf8",
os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md"),
encoding="utf8",
) as fp:
return fp.read()


setup(
name="tsellm",
description="LLM support in SQLite",
description="Interactive SQLite shell with LLM support",
long_description=get_long_description(),
long_description_content_type="text/markdown",
author="Florents Tselai",
Expand All @@ -32,8 +32,6 @@ def get_long_description():
version=VERSION,
packages=["tsellm"],
install_requires=["llm", "setuptools", "pip"],
extras_require={
"test": ["pytest", "pytest-cov", "black", "ruff", "sqlite_utils"]
},
extras_require={"test": ["pytest", "pytest-cov", "black", "ruff", "sqlite_utils", "llm-markov"]},
python_requires=">=3.7",
)
153 changes: 138 additions & 15 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from sqlite_utils import Database
from sqlite_utils.utils import sqlite3
import pytest

import pytest
import json
import llm
from llm.plugins import pm
from typing import Optional
import sqlite_utils
from pydantic import Field


def pytest_configure(config):
import sys
Expand All @@ -10,26 +17,142 @@ def pytest_configure(config):


@pytest.fixture
def fresh_db():
return Database(memory=True)
def user_path(tmpdir):
dir = tmpdir / "tsellm"
dir.mkdir()
return dir


@pytest.fixture
def existing_db(db_path):
database = Database(db_path)
database.executescript(
"""
CREATE TABLE foo (text TEXT);
INSERT INTO foo (text) values ("one");
INSERT INTO foo (text) values ("two");
INSERT INTO foo (text) values ("three");
"""
)
return database
def logs_db(user_path):
return sqlite_utils.Database(str(user_path / "logs.db"))


@pytest.fixture
def user_path_with_embeddings(user_path):
path = str(user_path / "embeddings.db")
db = sqlite_utils.Database(path)
collection = llm.Collection("demo", db, model_id="embed-demo")
collection.embed("1", "hello world")
collection.embed("2", "goodbye world")


class MockModel(llm.Model):
model_id = "mock-echo"

class Options(llm.Options):
max_tokens: Optional[int] = Field(
description="Maximum number of tokens to generate.", default=None
)

def __init__(self):
self.history = []
self._queue = []

def enqueue(self, messages):
assert isinstance(messages, list)
self._queue.append(messages)

def execute(self, prompt, stream, response, conversation):
self.history.append((prompt, stream, response, conversation))
while True:
try:
messages = self._queue.pop(0)
yield from messages
break
except IndexError:
break


class EmbedDemo(llm.EmbeddingModel):
model_id = "embed-demo"
batch_size = 10
supports_binary = True

def __init__(self):
self.embedded_content = []

def embed_batch(self, texts):
if not hasattr(self, "batch_count"):
self.batch_count = 0
self.batch_count += 1
for text in texts:
self.embedded_content.append(text)
words = text.split()[:16]
embedding = [len(word) for word in words]
# Pad with 0 up to 16 words
embedding += [0] * (16 - len(embedding))
yield embedding


class EmbedBinaryOnly(EmbedDemo):
model_id = "embed-binary-only"
supports_text = False
supports_binary = True


class EmbedTextOnly(EmbedDemo):
model_id = "embed-text-only"
supports_text = True
supports_binary = False


@pytest.fixture
def embed_demo():
return EmbedDemo()


@pytest.fixture
def mock_model():
return MockModel()


@pytest.fixture(autouse=True)
def register_embed_demo_model(embed_demo, mock_model):
class MockModelsPlugin:
__name__ = "MockModelsPlugin"

@llm.hookimpl
def register_embedding_models(self, register):
register(embed_demo)
register(EmbedBinaryOnly())
register(EmbedTextOnly())

@llm.hookimpl
def register_models(self, register):
register(mock_model)

pm.register(MockModelsPlugin(), name="undo-mock-models-plugin")
try:
yield
finally:
pm.unregister(name="undo-mock-models-plugin")


@pytest.fixture
def db_path(tmpdir):
path = str(tmpdir / "test.db")
db = sqlite3.connect(path)
return path


@pytest.fixture
def fresh_db_path(db_path):
return db_path

@pytest.fixture
def existing_db_path(fresh_db_path):
db = Database(fresh_db_path)
table = db.create_table(
"prompts",
{
"prompt": str,
"generated": str,
"model": str,
"embedding": dict,
},
)

table.insert({"prompt": "hello world!"})
table.insert({"prompt": "how are you?"})
table.insert({"prompt": "is this real life?"})
return fresh_db_path
36 changes: 12 additions & 24 deletions tests/test_tsellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,20 @@
from tsellm.cli import cli


def test_tsellm_cli(db_path):
db = Database(db_path)
assert [] == db.table_names()
table = db.create_table(
"prompts",
{
"prompt": str,
"generated": str,
"model": str,
"embedding": dict,
},
)

assert ["prompts"] == db.table_names()

table.insert({"prompt": "hello"})
table.insert({"prompt": "world"})
def test_cli_prompt_mock(existing_db_path):
db = Database(existing_db_path)

assert db.execute("select prompt from prompts").fetchall() == [
("hello",),
("world",),
("hello world!",),
("how are you?",),
("is this real life?",),
]

cli([db_path, "UPDATE prompts SET generated=prompt(prompt)"])
cli([existing_db_path, "UPDATE prompts SET generated=prompt(prompt, 'markov')"])

assert db.execute("select prompt, generated from prompts").fetchall() == [
("hello", "hellohello"),
("world", "worldworld"),
]
for prompt, generated in db.execute("select prompt, generated from prompts").fetchall():
words = generated.strip().split()
# Every word should be one of the original prompt (see https://github.com/simonw/llm-markov/blob/657ca504bcf9f0bfc1c6ee5fe838cde9a8976381/tests/test_llm_markov.py#L20)
prompt_words = prompt.split()
for word in words:
assert word in prompt_words
16 changes: 0 additions & 16 deletions tsellm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +0,0 @@
def _prompt(p):
return p * 2


TSELLM_CONFIG_SQL = """
CREATE TABLE IF NOT EXISTS __tsellm (
data
);
"""


def _tsellm_init(con):
"""Entry-point for tsellm initialization."""
con.execute(TSELLM_CONFIG_SQL)
con.create_function("prompt", 1, _prompt)
3 changes: 1 addition & 2 deletions tsellm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from argparse import ArgumentParser
from code import InteractiveConsole
from textwrap import dedent
from . import _prompt, _tsellm_init
from .core import _tsellm_init


def execute(c, sql, suppress_errors=True):
Expand Down Expand Up @@ -58,7 +58,6 @@ def runsource(self, source, filename="<input>", symbol="single"):


def cli(*args):
print(args)
parser = ArgumentParser(
description="tsellm sqlite3 CLI",
prog="python -m tsellm",
Expand Down
16 changes: 16 additions & 0 deletions tsellm/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import llm

TSELLM_CONFIG_SQL = """
CREATE TABLE IF NOT EXISTS __tsellm (
x text
);
"""

def _prompt_model(prompt, model):
return llm.get_model(model).prompt(prompt).text()

def _tsellm_init(con):
"""Entry-point for tsellm initialization."""
con.execute(TSELLM_CONFIG_SQL)
con.create_function("prompt", 2, _prompt_model)

0 comments on commit 9bd6c31

Please sign in to comment.