diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index a598c5d..42c5bf9 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -60,6 +60,7 @@ class MetaEnum(EnumMeta): """Metaclass for the Enum class for desired behavior of the `in` operator.""" def __contains__(cls, item): + """Check if an item is a valid member of the Enum.""" try: cls(item) except ValueError: diff --git a/test/test_kfac.py b/test/test_kfac.py index 13d8f63..98426d6 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -412,7 +412,7 @@ def test_kfac_ef_one_datum( ) kfac_mat = kfac @ eye(kfac.shape[1]) - report_nonclose(ef, kfac_mat) + report_nonclose(ef, kfac_mat, atol=1e-7) @mark.parametrize("dev", DEVICES, ids=DEVICES_IDS)