From 141bf7276da7a2f3b1bc7c75ebcec11bf2e97309 Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 7 Aug 2024 12:21:58 -0400 Subject: [PATCH] Only cache weight shape when the weight exists --- curvlinops/inverse.py | 10 ++++++---- curvlinops/kfac.py | 10 ++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/curvlinops/inverse.py b/curvlinops/inverse.py index 4d67fd3..cd5482b 100644 --- a/curvlinops/inverse.py +++ b/curvlinops/inverse.py @@ -628,7 +628,8 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: # retrieve the inverses of the Kronecker factors from cache or invert them aaT_inv, ggT_inv = self._compute_or_get_cached_inverse(mod_name) # cache the weight shape to ensure correct shapes are returned - weight_shape = M_torch[param_pos["weight"]].shape + if "weight" in param_pos: + weight_shape = M_torch[param_pos["weight"]].shape # bias and weights are treated jointly if ( @@ -648,9 +649,10 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: ) # restore original shapes - M_torch[param_pos["weight"]] = M_torch[param_pos["weight"]].view( - weight_shape - ) + if "weight" in param_pos: + M_torch[param_pos["weight"]] = M_torch[param_pos["weight"]].view( + weight_shape + ) if return_tensor: M_torch = cat([rearrange(M, "k ... -> (...) 
k") for M in M_torch]) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 0a79722..413c86b 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -434,7 +434,8 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: for mod_name, param_pos in self._mapping.items(): # cache the weight shape to ensure correct shapes are returned - weight_shape = M_torch[param_pos["weight"]].shape + if "weight" in param_pos: + weight_shape = M_torch[param_pos["weight"]].shape # bias and weights are treated jointly if ( @@ -472,9 +473,10 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: ) # restore original shapes - M_torch[param_pos["weight"]] = M_torch[param_pos["weight"]].view( - weight_shape - ) + if "weight" in param_pos: + M_torch[param_pos["weight"]] = M_torch[param_pos["weight"]].view( + weight_shape + ) if return_tensor: M_torch = cat([rearrange(M, "k ... -> (...) k") for M in M_torch])