Skip to content

Commit

Permalink
Only cache the weight shape when the weight exists
Browse files (browse the repository at this point in the history)
  • Loading branch information
runame committed Aug 7, 2024
1 parent 8baa121 commit 141bf72
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
10 changes: 6 additions & 4 deletions curvlinops/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,8 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType:
# retrieve the inverses of the Kronecker factors from cache or invert them
aaT_inv, ggT_inv = self._compute_or_get_cached_inverse(mod_name)
# cache the weight shape to ensure correct shapes are returned
weight_shape = M_torch[param_pos["weight"]].shape
if "weight" in param_pos:
weight_shape = M_torch[param_pos["weight"]].shape

# bias and weights are treated jointly
if (
Expand All @@ -648,9 +649,10 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType:
)

# restore original shapes
M_torch[param_pos["weight"]] = M_torch[param_pos["weight"]].view(
weight_shape
)
if "weight" in param_pos:
M_torch[param_pos["weight"]] = M_torch[param_pos["weight"]].view(
weight_shape
)

if return_tensor:
M_torch = cat([rearrange(M, "k ... -> (...) k") for M in M_torch])
Expand Down
10 changes: 6 additions & 4 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,8 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType:

for mod_name, param_pos in self._mapping.items():
# cache the weight shape to ensure correct shapes are returned
weight_shape = M_torch[param_pos["weight"]].shape
if "weight" in param_pos:
weight_shape = M_torch[param_pos["weight"]].shape

# bias and weights are treated jointly
if (
Expand Down Expand Up @@ -472,9 +473,10 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType:
)

# restore original shapes
M_torch[param_pos["weight"]] = M_torch[param_pos["weight"]].view(
weight_shape
)
if "weight" in param_pos:
M_torch[param_pos["weight"]] = M_torch[param_pos["weight"]].view(
weight_shape
)

if return_tensor:
M_torch = cat([rearrange(M, "k ... -> (...) k") for M in M_torch])
Expand Down

0 comments on commit 141bf72

Please sign in to comment.