Skip to content

Commit

Permalink
Check if covariance and mapping keys match when loading state dict
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed May 23, 2024
1 parent d5cecfc commit fb6ac4b
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
17 changes: 17 additions & 0 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,6 +1148,23 @@ def load_state_dict(self, state_dict: Dict[str, Any]):
self._N_data = state_dict["num_data"]

# Set Kronecker factors (if computed)
if self._input_covariances or self._gradient_covariances:
# If computed, check if the keys match the mapping keys
input_covariances_keys = set(self._input_covariances.keys())
gradient_covariances_keys = set(self._gradient_covariances.keys())
mapping_keys = set(self._mapping.keys())
if (
input_covariances_keys != mapping_keys
or gradient_covariances_keys != mapping_keys
):
raise ValueError(
"Input or gradient covariance keys in state dict do not match "
"mapping keys of linear operator. "
"Difference between input covariance and mapping keys: "
f"{input_covariances_keys - mapping_keys}. "
"Difference between gradient covariance and mapping keys: "
f"{gradient_covariances_keys - mapping_keys}."
)
self._input_covariances = state_dict["input_covariances"]
self._gradient_covariances = state_dict["gradient_covariances"]

Expand Down
4 changes: 2 additions & 2 deletions test/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,11 +708,11 @@ def test_KFAC_inverse_save_and_load_state_dict():
wrong_kfac = KFACLinearOperator(model, CrossEntropyLoss(), params, [(X, y)])
inv_kfac_wrong = KFACInverseLinearOperator(wrong_kfac)
with raises(ValueError, match="mismatch"):
inv_kfac_wrong.load_state_dict(torch.load(state_dict))
inv_kfac_wrong.load_state_dict(torch.load("inv_kfac_state_dict.pt"))

# create new inverse KFAC and load state dict
inv_kfac_new = KFACInverseLinearOperator(kfac)
inv_kfac_new.load_state_dict(torch.load(state_dict))
inv_kfac_new.load_state_dict(torch.load("inv_kfac_state_dict.pt"))

# check that the two inverse KFACs are equal
compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict())
Expand Down
8 changes: 4 additions & 4 deletions test/test_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,7 +1253,7 @@ def test_save_and_load_state_dict():
[(X, y)],
)
with raises(ValueError, match="loss"):
kfac_new.load_state_dict(load(state_dict))
kfac_new.load_state_dict(load("kfac_state_dict.pt"))

# create new KFAC with different loss reduction and try to load state dict
kfac_new = KFACLinearOperator(
Expand All @@ -1263,7 +1263,7 @@ def test_save_and_load_state_dict():
[(X, y)],
)
with raises(ValueError, match="reduction"):
kfac_new.load_state_dict(load(state_dict))
kfac_new.load_state_dict(load("kfac_state_dict.pt"))

# create new KFAC with different model and try to load state dict
wrong_model = Sequential(Linear(D_in, 10), ReLU(), Linear(10, D_out))
Expand All @@ -1276,7 +1276,7 @@ def test_save_and_load_state_dict():
loss_average=None,
)
with raises(RuntimeError, match="loading state_dict"):
kfac_new.load_state_dict(load(state_dict))
kfac_new.load_state_dict(load("kfac_state_dict.pt"))

# create new KFAC and load state dict
kfac_new = KFACLinearOperator(
Expand All @@ -1287,7 +1287,7 @@ def test_save_and_load_state_dict():
loss_average=None,
check_deterministic=False, # turn off to avoid computing KFAC again
)
kfac_new.load_state_dict(load(state_dict))
kfac_new.load_state_dict(load("kfac_state_dict.pt"))

# check that the two KFACs are equal
assert len(kfac.state_dict()) == len(kfac_new.state_dict())
Expand Down

0 comments on commit fb6ac4b

Please sign in to comment.