Skip to content

Commit

Permalink
test: expect new class RuntimeStopSignal, integrate with termination …
Browse files Browse the repository at this point in the history
…conditions & update regressions
  • Loading branch information
boromir674 committed Oct 30, 2024
1 parent 20877c9 commit 239c403
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 32 deletions.
30 changes: 21 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import os
import typing as t
from pathlib import Path

import pytest

# runtime directory of the `tests` directory
tests_root: Path = Path(__file__).parent


@pytest.fixture
def test_suite():
Expand All @@ -12,6 +16,20 @@ def test_suite():
return os.path.dirname(os.path.realpath(__file__))


@pytest.fixture
def test_root() -> Path:
"""Root directory Path of the test suite; ie /projects/my-project/tests/"""
return tests_root


@pytest.fixture
def get_test_data_image() -> t.Callable[[str], Path]:
"""Get the path of an image file in the test data directory."""
def get_image_file_path(file_name: str):
return tests_root / "data" / file_name
return get_image_file_path


@pytest.fixture
def test_image(test_suite):
import os
Expand Down Expand Up @@ -86,13 +104,15 @@ def termination_condition_module():
TerminationConditionFacility,
TerminationConditionInterface,
TimeLimit,
RuntimeStopSignal,
)

# all tests require that the Facility already contains some implementations of TerminationCondition
# all tests require that the Facility already contains all implementations of TerminationCondition
assert TerminationConditionFacility.class_registry.subclasses == {
"max-iterations": MaxIterations,
"time-limit": TimeLimit,
"convergence": Convergence,
"stop-signal": RuntimeStopSignal
}
return type(
"M",
Expand Down Expand Up @@ -230,9 +250,6 @@ def vgg_layers():
return tuple((layer_id for _, layer_id in VGG_LAYERS))


PRODUCTION_IMAGE_MODEL = os.environ.get("AA_VGG_19", "PRETRAINED_MODEL_NOT_FOUND")


@pytest.fixture
def pre_trained_models_1(vgg_layers, toy_model_data, toy_network_design):
import typing as t
Expand Down Expand Up @@ -303,11 +320,6 @@ def model_routines(self):

@pytest.fixture
def model(pre_trained_models_1):
import os

print(f"\n -- PROD IM MODEL: {PRODUCTION_IMAGE_MODEL}")
print(f"Selected Prod?: {os.path.isfile(PRODUCTION_IMAGE_MODEL)}")

return pre_trained_models_1["toy"]
# return {
# True: pre_trained_models_1['vgg'],
Expand Down
30 changes: 13 additions & 17 deletions tests/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,29 +44,23 @@ def create_max_iterations_termination_condition_adapter(iterations):


@pytest.fixture
def algorithm_parameters_class():
def algorithm():
from artificial_artwork.algorithm import AlogirthmParameters

return AlogirthmParameters


@pytest.fixture
def algorithm(algorithm_parameters_class):
from artificial_artwork.algorithm import NSTAlgorithm

def _create_algorithm(*parameters):
return NSTAlgorithm(algorithm_parameters_class(*parameters))
return NSTAlgorithm(AlogirthmParameters(*parameters))

return _create_algorithm


@pytest.fixture
def create_algorithm(algorithm, tmpdir):
def _create_algorithm(image_manager, termination_condition_adapter):
def _create_algorithm(image_manager, termination_condition_adapters):
return algorithm(
image_manager.content_image,
image_manager.style_image,
termination_condition_adapter,
termination_condition_adapters,
tmpdir,
)

Expand All @@ -82,13 +76,13 @@ def create_production_algorithm_runner():

noisy_ratio = 0.6

def _create_production_algorithm_runner(termination_condition_adapter, max_iterations):
def _create_production_algorithm_runner(termination_condition_adapters, max_iterations):
algorithm_runner = NSTAlgorithmRunner.default(
lambda matrix: noisy(matrix, noisy_ratio),
)

algorithm_runner.progress_subject.add(
termination_condition_adapter,
*termination_condition_adapters
)
algorithm_runner.persistance_subject.add(
StylingObserver(Disk.save_image, convert_to_uint8, max_iterations)
Expand All @@ -100,9 +94,9 @@ def _create_production_algorithm_runner(termination_condition_adapter, max_itera

@pytest.fixture
def get_algorithm_runner(create_production_algorithm_runner):
def _get_algorithm_runner(termination_condition_adapter, max_iterations):
def _get_algorithm_runner(termination_condition_adapters, max_iterations):
algorithm_runner = create_production_algorithm_runner(
termination_condition_adapter,
termination_condition_adapters,
max_iterations,
)
return algorithm_runner
Expand All @@ -128,7 +122,7 @@ def test_nst_runner(
max_iterations_adapter_factory_method,
image_manager,
test_image,
model,
model, # Tiny Neural Network with 1 Layer
tmpdir,
):
"""Test nst algorithm runner.
Expand All @@ -146,9 +140,11 @@ def test_nst_runner(

termination_condition_adapter = max_iterations_adapter_factory_method(ITERATIONS)

algorithm_runner = get_algorithm_runner(termination_condition_adapter, ITERATIONS)
termination_conditions = [termination_condition_adapter] # they implement listener interface

algorithm_runner = get_algorithm_runner(termination_conditions, ITERATIONS)

algorithm = create_algorithm(image_manager, termination_condition_adapter)
algorithm = create_algorithm(image_manager, termination_conditions)

model_design = get_model_design(
model.pretrained_model.handler,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_algorithm_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ def algorithm_parameters():
return AlogirthmParameters(
"content_image",
"style_image",
"termination_condition",
[], # some list of termination conditions
"output_path",
)


def test_algorithm_parameters(algorithm_parameters):
assert hasattr(algorithm_parameters, "content_image")
assert hasattr(algorithm_parameters, "style_image")
assert hasattr(algorithm_parameters, "termination_condition")
assert hasattr(algorithm_parameters, "termination_conditions")
assert hasattr(algorithm_parameters, "output_path")
4 changes: 2 additions & 2 deletions tests/test_cv_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from artificial_artwork.production_networks import NetworkDesign


class TestPreTrainedModel(t.Protocol):
class PreTrainedModelForTesting(t.Protocol):
# property/attribute for doing assertions
expected_layers: t.Tuple[str] # ie (conv1_1, conv1_2, .., avgpool5)

Expand All @@ -18,7 +18,7 @@ class TestPreTrainedModel(t.Protocol):


class NSTTestingModelProtocol(t.Protocol):
pretrained_model: TestPreTrainedModel
pretrained_model: PreTrainedModelForTesting
network_design: NetworkDesign


Expand Down
14 changes: 12 additions & 2 deletions tests/test_image/test_image_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,22 @@ def test_uint8_convertion(test_image, expected_image, image_operations):
assert np.nanmax(runtime_image) < np.power(2.0, 8)
assert runtime_image.tolist() == expected_image


@pytest.mark.parametrize(
"test_image",
[
([[np.nan, np.nan], [np.nan, np.nan]],),
],
)
def test_all_nan_image_raises_error(test_image, image_operations):
with pytest.warns(RuntimeWarning, match=r"All-NaN slice encountered"):
with pytest.raises(ValueError, match=r"Minimum image value is not finite"):
runtime_image = image_operations.convert_to_uint8(
np.array([[np.nan, np.nan], [np.nan, np.nan]], dtype=np.float32)
)

@pytest.mark.parametrize(
"test_image",
[
([[1, -float("inf")], [2, 3]],),
],
)
Expand All @@ -127,7 +138,6 @@ def test_non_finite_minimum_value(test_image, image_operations):
np.array(test_image, dtype=np.float32)
)


@pytest.mark.parametrize(
"test_image",
[
Expand Down

0 comments on commit 239c403

Please sign in to comment.