diff --git a/tests/unittest/test_utils.py b/tests/unittest/test_utils.py index 79e1ff7..64b0097 100644 --- a/tests/unittest/test_utils.py +++ b/tests/unittest/test_utils.py @@ -1,7 +1,8 @@ import numpy as np import torch -from mohou.utils import splitting_slices +from mohou.model.autoencoder import AutoEncoder, AutoEncoderBase, VariationalAutoEncoder +from mohou.utils import get_all_concrete_leaftypes, splitting_slices def test_splitting_slicers(): @@ -24,3 +25,9 @@ def test_splitting_slicers(): assert obj3[0][0][0] == 6 assert obj3[-1][0][0] == 9 + + +def test_get_all_concreate_leaftypes(): + types = get_all_concrete_leaftypes(AutoEncoderBase) + print(types) + assert set(types) == set([AutoEncoder, VariationalAutoEncoder])