diff --git a/test/test_inverse.py b/test/test_inverse.py index 6299355..f064ccc 100644 --- a/test/test_inverse.py +++ b/test/test_inverse.py @@ -737,5 +737,6 @@ def test_torch_matvec_list_output_shapes(cnn_case): inv_kfac = KFACInverseLinearOperator(kfac, damping=1e-2) vec = [torch.rand_like(p) for p in kfac._params] out_list = inv_kfac.torch_matvec(vec) - for out_i, p_i in zip(out_list, kfac._params, strict=True): + assert len(out_list) == len(kfac._params) + for out_i, p_i in zip(out_list, kfac._params): assert out_i.shape == p_i.shape diff --git a/test/test_kfac.py b/test/test_kfac.py index b085f82..e7f8d61 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -695,7 +695,8 @@ def test_torch_matvec_list_output_shapes(cnn_case): ) vec = [rand_like(p) for p in kfac._params] out_list = kfac.torch_matvec(vec) - for out_i, p_i in zip(out_list, kfac._params, strict=True): + assert len(out_list) == len(kfac._params) + for out_i, p_i in zip(out_list, kfac._params): assert out_i.shape == p_i.shape