Skip to content

Commit

Permalink
Tests: add partial_name_map_composite fixture
Browse files Browse the repository at this point in the history
- the partial_name_map_composites fixture creates a random (but seeded)
  partial name_map by the actual associations (i.e. one name to one
  rule)
- this is used to correctly test MixedComposites/NameLayerMapComposites
  to fall back to the LayerMap
- add pyrng fixture, based on the same seed as the numpy RNG
- removed unneeded input fixtures for some fixtures in conftest.py
  • Loading branch information
chr5tphr committed Feb 16, 2023
1 parent cbfad69 commit 3121bb2
Showing 1 changed file with 32 additions and 11 deletions.
43 changes: 32 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
'''Configuration and fixtures for testing'''
from itertools import product
import random
from itertools import product, groupby
from collections import OrderedDict

import pytest
Expand Down Expand Up @@ -50,10 +51,16 @@ def pytest_generate_tests(metafunc):
ids=hex
)
def rng(request):
'''Random number generator fixture.'''
'''Fixture for the NumPy random number generator.'''
return torch.manual_seed(request.param)


@pytest.fixture(scope='session')
def pyrng(rng):
'''Fixture for the Python random number generator.'''
return random.Random(rng.initial_seed())


@pytest.fixture(
scope='session',
params=[
Expand Down Expand Up @@ -226,7 +233,7 @@ def any_composite(request):


@pytest.fixture(scope='session')
def name_map_composite(request, model_vision, layer_map_composite):
def name_map_composite(model_vision, layer_map_composite):
'''Fixture to create NameMapComposites based on explicit LayerMapComposites.'''
rule_map = {}
for name, child in model_vision.named_modules():
Expand All @@ -239,19 +246,33 @@ def name_map_composite(request, model_vision, layer_map_composite):


@pytest.fixture(scope='session')
def mixed_composite(request, model_vision, name_map_composite, special_first_layer_map_composite):
'''Fixture to create NameLayerMapComposites based on an explicit
NameMapComposite and SpecialFirstLayerMapComposites.'''
composites = [name_map_composite, special_first_layer_map_composite]
def partial_name_map_composite(name_map_composite, pyrng):
'''Fixture to create a randomly sampled partial NameMapComposites.'''
name_map = name_map_composite.name_map
assocs = [(i, j) for i, (keys, _) in enumerate(name_map) for j in range(len(keys))]
accepted_assocs = sorted(pyrng.sample(assocs, len(assocs) // 2))
partial_name_map = [
(tuple(name_map[k][0][n] for _, n in g), name_map[k][1].copy())
for k, g in groupby(accepted_assocs, lambda o: o[0])
]

return NameMapComposite(name_map=partial_name_map)


@pytest.fixture(scope='session')
def mixed_composite(partial_name_map_composite, special_first_layer_map_composite):
'''Fixture to create NameLayerMapComposites based on an explicit NameMapComposite and
SpecialFirstLayerMapComposites.
'''
composites = [partial_name_map_composite, special_first_layer_map_composite]
return MixedComposite(composites)


@pytest.fixture(scope='session')
def name_layer_map_composite(request, model_vision, name_map_composite, layer_map_composite):
'''Fixture to create NameLayerMapComposites based on an explicit
NameMapComposite and LayerMapComposite.'''
def name_layer_map_composite(partial_name_map_composite, layer_map_composite):
'''Fixture to create NameLayerMapComposites based on an explicit NameMapComposite and LayerMapComposite.'''
return NameLayerMapComposite(
name_map=name_map_composite.name_map,
name_map=partial_name_map_composite.name_map,
layer_map=layer_map_composite.layer_map,
)

Expand Down

0 comments on commit 3121bb2

Please sign in to comment.