Skip to content

Commit

Permalink
Add test for issue #118
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Jun 12, 2024
1 parent 13b1082 commit 441fcf3
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions test/test_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,3 +1270,21 @@ def test_from_state_dict():
compare_state_dicts(kfac.state_dict(), kfac_new.state_dict())
test_vec = rand(kfac.shape[1])
report_nonclose(kfac @ test_vec, kfac_new @ test_vec)


@mark.parametrize("fisher_type", ["type-2", "mc", "empirical", "forward-only"])
@mark.parametrize("kfac_approx", ["expand", "reduce"])
def test_string_in_enum(fisher_type: str, kfac_approx: str):
"""Test whether checking if a string is contained in enum works.
To reproduce issue #118.
"""
model = Linear(2, 2)
KFACLinearOperator(
model,
MSELoss(),
list(model.parameters()),
[(rand(2, 2), rand(2, 2))],
fisher_type=fisher_type,
kfac_approx=kfac_approx,
)

0 comments on commit 441fcf3

Please sign in to comment.