Skip to content

Commit

Permalink
Tests for the Storage class
Browse files Browse the repository at this point in the history
  • Loading branch information
goldpulpy committed Oct 9, 2024
1 parent aac646c commit 0c4962b
Showing 1 changed file with 88 additions and 0 deletions.
88 changes: 88 additions & 0 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Test cases for the Storage class."""
import unittest
import os
import numpy as np
from pysentence_similarity._storage import Storage, InvalidDataError


class TestStorage(unittest.TestCase):
"""Test cases for the Storage class."""

def setUp(self) -> None:
"""Set up test data before each test."""
self.sentences = [
"This is a test sentence.",
"This is another sentence."
]
self.embeddings = [np.random.rand(3), np.random.rand(3)]
self.storage = Storage(
sentences=self.sentences,
embeddings=self.embeddings
)
self.test_filename = "test_storage.h5"

def tearDown(self) -> None:
"""Remove test file after each test."""
if os.path.exists(self.test_filename):
os.remove(self.test_filename)

def test_initialization_valid(self) -> None:
"""Test initialization with valid data."""
self.assertEqual(len(self.storage), 2)
self.assertEqual(self.storage._sentences, self.sentences)

def test_initialization_invalid_sentences(self) -> None:
"""Test initialization with invalid sentences."""
with self.assertRaises(InvalidDataError):
Storage(sentences="Not a list", embeddings=self.embeddings)

def test_initialization_invalid_embeddings(self) -> None:
"""Test initialization with invalid embeddings."""
with self.assertRaises(InvalidDataError):
Storage(sentences=self.sentences, embeddings="Not a list")

def test_save_and_load(self) -> None:
"""Test saving and loading data."""
self.storage.save(self.test_filename)
loaded_storage = Storage.load(self.test_filename)
self.assertEqual(len(loaded_storage), 2)
self.assertEqual(loaded_storage._sentences, self.sentences)

def test_save_invalid_data(self) -> None:
"""Test saving with invalid data."""
self.storage._sentences.append(123)
with self.assertRaises(InvalidDataError):
self.storage.save(self.test_filename)

def test_add_sentences_and_embeddings(self) -> None:
"""Test adding valid sentences and embeddings."""
new_sentence = "This is a new sentence."
new_embedding = np.random.rand(3)
self.storage.add(new_sentence, new_embedding)

self.assertEqual(len(self.storage), 3)
self.assertEqual(self.storage._sentences[-1], new_sentence)

def test_add_and_save(self) -> None:
"""Test adding and saving."""
new_sentence = "This is another new sentence."
new_embedding = np.random.rand(3)
self.storage.add(
new_sentence,
new_embedding,
save=True,
filename=self.test_filename
)

loaded_storage = Storage.load(self.test_filename)
self.assertEqual(len(loaded_storage), 3)
self.assertEqual(loaded_storage._sentences[-1], new_sentence)

def test_index_out_of_range(self) -> None:
"""Test index out of range."""
with self.assertRaises(IndexError):
self.storage[10]


if __name__ == "__main__":
unittest.main()

0 comments on commit 0c4962b

Please sign in to comment.