tests for yi, stable diffusion, timm models, etc
Former-commit-id: dfea671
kyegomez committed Nov 14, 2023
1 parent 59f3b4c commit 48643b3
Showing 10 changed files with 1,024 additions and 124 deletions.
2 changes: 1 addition & 1 deletion swarms/models/autotemp.py
@@ -1,6 +1,6 @@
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
-from swarms.models.auto_temp import OpenAIChat
+from swarms.models.openai_models import OpenAIChat


class AutoTempAgent:
1 change: 1 addition & 0 deletions swarms/models/simple_ada.py
@@ -2,6 +2,7 @@

client = OpenAI()


def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"):
    """
    Simple function to get embeddings from ada
161 changes: 161 additions & 0 deletions tests/models/bioclip.py
@@ -0,0 +1,161 @@
# Import necessary modules and define fixtures if needed
import os
import pytest
import torch
from PIL import Image
from swarms.models.bioclip import BioClip


# Define fixtures if needed
@pytest.fixture
def sample_image_path():
    return "path_to_sample_image.jpg"


@pytest.fixture
def clip_instance():
    return BioClip("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224")
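# NOTE: the image path above is a placeholder. For a fully self-contained run,
# the fixture could generate a file on the fly instead; a sketch, assuming only
# PIL (already imported above) and pytest's built-in tmp_path fixture:
#
#     @pytest.fixture
#     def sample_image_path(tmp_path):
#         path = tmp_path / "sample.jpg"
#         Image.new("RGB", (224, 224), color="white").save(path)
#         return str(path)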


# Basic tests for the BioClip class
def test_clip_initialization(clip_instance):
    assert isinstance(clip_instance.model, torch.nn.Module)
    assert hasattr(clip_instance, "model_path")
    assert hasattr(clip_instance, "preprocess_train")
    assert hasattr(clip_instance, "preprocess_val")
    assert hasattr(clip_instance, "tokenizer")
    assert hasattr(clip_instance, "device")


def test_clip_call_method(clip_instance, sample_image_path):
    labels = [
        "adenocarcinoma histopathology",
        "brain MRI",
        "covid line chart",
        "squamous cell carcinoma histopathology",
        "immunohistochemistry histopathology",
        "bone X-ray",
        "chest X-ray",
        "pie chart",
        "hematoxylin and eosin histopathology",
    ]
    result = clip_instance(sample_image_path, labels)
    assert isinstance(result, dict)
    assert len(result) == len(labels)


def test_clip_plot_image_with_metadata(clip_instance, sample_image_path):
    metadata = {
        "filename": "sample_image.jpg",
        "top_probs": {"label1": 0.75, "label2": 0.65},
    }
    clip_instance.plot_image_with_metadata(sample_image_path, metadata)
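# (plot_image_with_metadata is only checked for not raising here; on a
# headless CI machine a non-interactive matplotlib backend such as Agg
# may be needed for this to pass)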


# More test cases can be added to cover additional functionality and edge cases


# Parameterized tests for different image and label combinations
@pytest.mark.parametrize(
    "image_path, labels",
    [
        ("image1.jpg", ["label1", "label2"]),
        ("image2.jpg", ["label3", "label4"]),
        # Add more image and label combinations
    ],
)
def test_clip_parameterized_calls(clip_instance, image_path, labels):
    result = clip_instance(image_path, labels)
    assert isinstance(result, dict)
    assert len(result) == len(labels)


# Test image preprocessing
def test_clip_image_preprocessing(clip_instance, sample_image_path):
    image = Image.open(sample_image_path)
    processed_image = clip_instance.preprocess_val(image)
    assert isinstance(processed_image, torch.Tensor)


# Test label tokenization
def test_clip_label_tokenization(clip_instance):
    labels = ["label1", "label2"]
    tokenized_labels = clip_instance.tokenizer(labels)
    assert isinstance(tokenized_labels, torch.Tensor)
    assert tokenized_labels.shape[0] == len(labels)
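# (these shape checks assume an open_clip-style tokenizer that returns a
# [batch, context_length] tensor for a list of strings)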


# More tests can be added to cover other methods and edge cases


# End-to-end tests with actual images and labels
def test_clip_end_to_end(clip_instance, sample_image_path):
    labels = [
        "adenocarcinoma histopathology",
        "brain MRI",
        "covid line chart",
        "squamous cell carcinoma histopathology",
        "immunohistochemistry histopathology",
        "bone X-ray",
        "chest X-ray",
        "pie chart",
        "hematoxylin and eosin histopathology",
    ]
    result = clip_instance(sample_image_path, labels)
    assert isinstance(result, dict)
    assert len(result) == len(labels)


# Test label tokenization with long labels
def test_clip_long_labels(clip_instance):
    labels = ["label" + str(i) for i in range(100)]
    tokenized_labels = clip_instance.tokenizer(labels)
    assert isinstance(tokenized_labels, torch.Tensor)
    assert tokenized_labels.shape[0] == len(labels)


# Test handling of multiple image files
def test_clip_multiple_images(clip_instance, sample_image_path):
    labels = ["label1", "label2"]
    image_paths = [sample_image_path, "image2.jpg"]
    results = clip_instance(image_paths, labels)
    assert isinstance(results, list)
    assert len(results) == len(image_paths)
    for result in results:
        assert isinstance(result, dict)
        assert len(result) == len(labels)


# Test model inference performance
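# (the `benchmark` fixture used below is provided by the pytest-benchmark
# plugin, which must be installed for this test to run)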
def test_clip_inference_performance(clip_instance, sample_image_path, benchmark):
    labels = [
        "adenocarcinoma histopathology",
        "brain MRI",
        "covid line chart",
        "squamous cell carcinoma histopathology",
        "immunohistochemistry histopathology",
        "bone X-ray",
        "chest X-ray",
        "pie chart",
        "hematoxylin and eosin histopathology",
    ]
    result = benchmark(clip_instance, sample_image_path, labels)
    assert isinstance(result, dict)
    assert len(result) == len(labels)


# Test different preprocessing pipelines
def test_clip_preprocessing_pipelines(clip_instance, sample_image_path):
    labels = ["label1", "label2"]
    image = Image.open(sample_image_path)

    # Test preprocessing for training
    processed_image_train = clip_instance.preprocess_train(image)
    assert isinstance(processed_image_train, torch.Tensor)

    # Test preprocessing for validation
    processed_image_val = clip_instance.preprocess_val(image)
    assert isinstance(processed_image_val, torch.Tensor)


# ...
118 changes: 114 additions & 4 deletions tests/models/distill_whisper.py
@@ -1,13 +1,14 @@
import os
import tempfile
from functools import wraps
-from unittest.mock import patch
+from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

-from swarms.models.distill_whisperx import DistilWhisperModel, async_retry
+from swarms.models.distilled_whisperx import DistilWhisperModel, async_retry


@pytest.fixture
@@ -150,5 +151,114 @@ def test_create_audio_file():
    os.remove(audio_file_path)


-if __name__ == "__main__":
-    pytest.main()
# test_distilled_whisperx.py


# Fixtures for setting up model, processor, and audio files
@pytest.fixture(scope="module")
def model_id():
return "distil-whisper/distil-large-v2"


@pytest.fixture(scope="module")
def whisper_model(model_id):
return DistilWhisperModel(model_id)


@pytest.fixture(scope="session")
def audio_file_path(tmp_path_factory):
# You would create a small temporary MP3 file here for testing
# or use a public domain MP3 file's path
return "path/to/valid_audio.mp3"


@pytest.fixture(scope="session")
def invalid_audio_file_path():
return "path/to/invalid_audio.mp3"


@pytest.fixture(scope="session")
def audio_dict():
# This should represent a valid audio dictionary as expected by the model
return {"array": torch.randn(1, 16000), "sampling_rate": 16000}


# Test initialization
def test_initialization(whisper_model):
    assert whisper_model.model is not None
    assert whisper_model.processor is not None


# Test successful transcription with file path
def test_transcribe_with_file_path(whisper_model, audio_file_path):
    transcription = whisper_model.transcribe(audio_file_path)
    assert isinstance(transcription, str)


# Test successful transcription with audio dict
def test_transcribe_with_audio_dict(whisper_model, audio_dict):
    transcription = whisper_model.transcribe(audio_dict)
    assert isinstance(transcription, str)


# Test for file not found error
def test_file_not_found(whisper_model, invalid_audio_file_path):
    with pytest.raises(Exception):
        whisper_model.transcribe(invalid_audio_file_path)


# Asynchronous tests
@pytest.mark.asyncio
async def test_async_transcription_success(whisper_model, audio_file_path):
    transcription = await whisper_model.async_transcribe(audio_file_path)
    assert isinstance(transcription, str)


@pytest.mark.asyncio
async def test_async_transcription_failure(whisper_model, invalid_audio_file_path):
    with pytest.raises(Exception):
        await whisper_model.async_transcribe(invalid_audio_file_path)


# Testing real-time transcription simulation
def test_real_time_transcription(whisper_model, audio_file_path, capsys):
    whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1)
    captured = capsys.readouterr()
    assert "Starting real-time transcription..." in captured.out


# Testing retry decorator for asynchronous function
@pytest.mark.asyncio
async def test_async_retry():
    @async_retry(max_retries=2, exceptions=(ValueError,), delay=0)
    async def failing_func():
        raise ValueError("Test")

    with pytest.raises(ValueError):
        await failing_func()
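# (the pytest.raises above relies on async_retry re-raising the ValueError
# once its retries are exhausted)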


# Mocking the actual model to avoid GPU/CPU intensive operations during test
@pytest.fixture
def mocked_model(monkeypatch):
    model_mock = AsyncMock(AutoModelForSpeechSeq2Seq)
    processor_mock = MagicMock(AutoProcessor)
    monkeypatch.setattr(
        "swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained",
        model_mock,
    )
    monkeypatch.setattr(
        "swarms.models.distilled_whisperx.AutoProcessor.from_pretrained", processor_mock
    )
    return model_mock, processor_mock


@pytest.mark.asyncio
async def test_async_transcribe_with_mocked_model(mocked_model, audio_file_path):
    model_mock, processor_mock = mocked_model
    # Set up what the mock should return when it's called
    model_mock.return_value.generate.return_value = torch.tensor([[0]])
    processor_mock.return_value.batch_decode.return_value = ["mocked transcription"]
    model_wrapper = DistilWhisperModel()
    transcription = await model_wrapper.async_transcribe(audio_file_path)
    assert transcription == "mocked transcription"
(6 of the 10 changed files are not shown above)
