From 2f3bdec79165fa94ac0f6fa8e06f0f3899b3b22f Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 7 Nov 2023 21:30:09 -0500 Subject: [PATCH] [FIX] Make tests work --- curvlinops/kfac.py | 4 ++-- test/test_kfac.py | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 3151030..37976a7 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -197,7 +197,7 @@ def _matvec(self, x: ndarray) -> ndarray: mod.weight, mod.bias ): w_pos, b_pos = self.param_pos(mod.weight), self.param_pos(mod.bias) - x_joint = cat([x_torch[w_pos], x_torch[b_pos]], dim=1) + x_joint = cat([x_torch[w_pos], x_torch[b_pos].unsqueeze(-1)], dim=1) aaT = self._input_covariances[name] ggT = self._gradient_covariances[name] x_joint = ggT @ x_joint @ aaT @@ -422,7 +422,7 @@ def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor self.in_params(module.weight, module.bias) and not self._separate_weight_and_bias ): - x = cat([x, x.new_ones(x.shape[1], 1)], dim=1) + x = cat([x, x.new_ones(x.shape[0], 1)], dim=1) covariance = einsum("bi,bj->ij", x, x).div_(self._N_data) else: diff --git a/test/test_kfac.py b/test/test_kfac.py index 1e984ca..983eeeb 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -57,7 +57,14 @@ def test_kfac( separate_weight_and_bias=separate_weight_and_bias, ) - kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=2_000) + kfac = KFACLinearOperator( + model, + loss_func, + params, + data, + mc_samples=2_000, + separate_weight_and_bias=separate_weight_and_bias, + ) kfac_mat = kfac @ eye(kfac.shape[1]) atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction]