diff --git a/curvlinops/__init__.py b/curvlinops/__init__.py index 0f8cbb3..9f4fafa 100644 --- a/curvlinops/__init__.py +++ b/curvlinops/__init__.py @@ -12,7 +12,7 @@ NeumannInverseLinearOperator, ) from curvlinops.jacobian import JacobianLinearOperator, TransposedJacobianLinearOperator -from curvlinops.kfac import KFACLinearOperator +from curvlinops.kfac import FisherType, KFACLinearOperator, KFACType from curvlinops.norm.hutchinson import HutchinsonSquaredFrobeniusNormEstimator from curvlinops.papyan2020traces.spectrum import ( LanczosApproximateLogSpectrumCached, @@ -33,6 +33,9 @@ "KFACLinearOperator", "JacobianLinearOperator", "TransposedJacobianLinearOperator", + # Enums + "FisherType", + "KFACType", # inversion "CGInverseLinearOperator", "LSMRInverseLinearOperator", diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 19f664e..fbee73b 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -19,7 +19,7 @@ from __future__ import annotations from collections.abc import MutableMapping -from enum import Enum +from enum import Enum, EnumMeta from functools import partial from math import sqrt from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union @@ -52,7 +52,18 @@ ParameterMatrixType = TypeVar("ParameterMatrixType", Tensor, List[Tensor]) -class FisherType(str, Enum): +class MetaEnum(EnumMeta): + """Metaclass for the Enum class for desired behavior of the `in` operator.""" + + def __contains__(cls, item): + try: + cls(item) + except ValueError: + return False + return True + + +class FisherType(str, Enum, metaclass=MetaEnum): """Enum for the Fisher type.""" TYPE2 = "type-2" @@ -61,7 +72,7 @@ class FisherType(str, Enum): FORWARD_ONLY = "forward-only" -class KFACType(str, Enum): +class KFACType(str, Enum, metaclass=MetaEnum): """Enum for the KFAC approximation type.""" EXPAND = "expand" diff --git a/test/test_kfac.py b/test/test_kfac.py index 1ddb0e2..99f5993 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -1270,3 +1270,21 @@ def test_from_state_dict(): compare_state_dicts(kfac.state_dict(), kfac_new.state_dict()) test_vec = rand(kfac.shape[1]) report_nonclose(kfac @ test_vec, kfac_new @ test_vec) + + +@mark.parametrize("fisher_type", ["type-2", "mc", "empirical", "forward-only"]) +@mark.parametrize("kfac_approx", ["expand", "reduce"]) +def test_string_in_enum(fisher_type: str, kfac_approx: str): + """Test whether checking if a string is contained in enum works. + + To reproduce issue #118. + """ + model = Linear(2, 2) + KFACLinearOperator( + model, + MSELoss(), + list(model.parameters()), + [(rand(2, 2), rand(2, 2))], + fisher_type=fisher_type, + kfac_approx=kfac_approx, + )