Skip to content

Commit

Permalink
[FIX] Make tests work
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Nov 8, 2023
1 parent ca547e5 commit 2f3bdec
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
4 changes: 2 additions & 2 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _matvec(self, x: ndarray) -> ndarray:
mod.weight, mod.bias
):
w_pos, b_pos = self.param_pos(mod.weight), self.param_pos(mod.bias)
x_joint = cat([x_torch[w_pos], x_torch[b_pos]], dim=1)
x_joint = cat([x_torch[w_pos], x_torch[b_pos].unsqueeze(-1)], dim=1)
aaT = self._input_covariances[name]
ggT = self._gradient_covariances[name]
x_joint = ggT @ x_joint @ aaT
Expand Down Expand Up @@ -422,7 +422,7 @@ def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor
self.in_params(module.weight, module.bias)
and not self._separate_weight_and_bias
):
x = cat([x, x.new_ones(x.shape[1], 1)], dim=1)
x = cat([x, x.new_ones(x.shape[0], 1)], dim=1)

covariance = einsum("bi,bj->ij", x, x).div_(self._N_data)
else:
Expand Down
9 changes: 8 additions & 1 deletion test/test_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,14 @@ def test_kfac(
separate_weight_and_bias=separate_weight_and_bias,
)

kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=2_000)
kfac = KFACLinearOperator(
model,
loss_func,
params,
data,
mc_samples=2_000,
separate_weight_and_bias=separate_weight_and_bias,
)
kfac_mat = kfac @ eye(kfac.shape[1])

atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction]
Expand Down

0 comments on commit 2f3bdec

Please sign in to comment.