Skip to content

Commit

Permalink
Add MetaEnum to get desired behavior of in-operator
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Jun 12, 2024
1 parent 441fcf3 commit 188f84e
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from __future__ import annotations

from collections.abc import MutableMapping
from enum import Enum
from enum import Enum, EnumMeta
from functools import partial
from math import sqrt
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
Expand Down Expand Up @@ -52,7 +52,18 @@
ParameterMatrixType = TypeVar("ParameterMatrixType", Tensor, List[Tensor])


class FisherType(str, Enum):
class MetaEnum(EnumMeta):
"""Metaclass for the Enum class for desired behavior of the `in` operator."""

def __contains__(cls, item):
try:
cls(item)
except ValueError:
return False
return True


class FisherType(str, Enum, metaclass=MetaEnum):
"""Enum for the Fisher type."""

TYPE2 = "type-2"
Expand All @@ -61,7 +72,7 @@ class FisherType(str, Enum):
FORWARD_ONLY = "forward-only"


class KFACType(str, Enum):
class KFACType(str, Enum, metaclass=MetaEnum):
"""Enum for the KFAC approximation type."""

EXPAND = "expand"
Expand Down

0 comments on commit 188f84e

Please sign in to comment.