diff --git a/test/optim/test_kfac.py b/test/optim/test_kfac.py
index 73ec23e..4ba7dde 100644
--- a/test/optim/test_kfac.py
+++ b/test/optim/test_kfac.py
@@ -15,7 +15,7 @@
 IN_DIM = 3
 HID_DIM = 5
 REP_DIM = 2
-OUT_DIM = 1
+OUT_DIM = 2
 N_SAMPLES = 4
 C_in = 3
 C_out = 2
@@ -63,7 +63,7 @@ def test_kfac_single_linear_module(
     Js, f = jacobians_naive(model, x, setting)
     assert f.shape == (n_loss_terms, OUT_DIM)
     assert Js.shape == (n_loss_terms, OUT_DIM, num_params)
-    Js = Js.flatten(start_dim=1)
+    Js = Js.flatten(end_dim=-2)

     # Exact Fisher/GGN.
     exact_F = Js.T @ Js  # regression
@@ -123,7 +123,7 @@ def test_kfac_deep_linear(
     Js, f = jacobians_naive(model, x, setting)
     assert f.shape == (n_loss_terms, OUT_DIM)
     assert Js.shape == (n_loss_terms, OUT_DIM, num_params)
-    Js = Js.flatten(start_dim=1)
+    Js = Js.flatten(end_dim=-2)

     # Exact Fisher/GGN.
     exact_F = Js.T @ Js  # regression
@@ -192,7 +192,7 @@ def test_kfac_conv2d_module(
     Js, f = jacobians_naive(model, x, setting)
     assert f.shape == (n_loss_terms, OUT_DIM)
     assert Js.shape == (n_loss_terms, OUT_DIM, num_params)
-    Js = Js.flatten(start_dim=1)
+    Js = Js.flatten(end_dim=-2)

     # Exact Fisher/GGN.
     exact_F = Js.T @ Js  # regression
@@ -291,10 +291,16 @@ def forward_and_backward(self, x: Tensor):
             else {}
         )
         logits: Tensor = self.model(x, **kwargs)
-        # Backward pass for each output dimension.
+        # Since we only consider the MSE loss, we do not need to explicitly
+        # consider the loss function or labels, as the MSE loss Hessian w.r.t.
+        # the logits is the precision matrix of the Gaussian likelihood.
+        # In other words, we only need to compute the Jacobian of the logits
+        # w.r.t. the parameters. This requires one backward pass per output
+        # dimension.
         n_dims = logits.size(-1)
         for i in range(n_dims):
             logits_i = logits[:, i]
+            # Mean or sum reduction over `n_loss_terms`.
             loss = logits_i.mean() if self.batch_averaged else logits_i.sum()
             loss.backward(retain_graph=i < n_dims - 1)

@@ -316,7 +322,7 @@ def get_kfac_blocks(self) -> List[Tensor]:
                 raise ValueError("forward_and_backward() has to be called first.")
             # Get Kronecker factor ingredients stored as module attributes.
             a: Tensor = module.kfac_a
-            g: Tensor = module.kfac_g
+            g: Tensor = torch.cat(module.kfac_g)
             # Compute Kronecker product of both factors.
             block = torch.kron(g.T @ g, a.T @ a)
             # When a bias is used we have to reorder the rows and columns of the
@@ -391,4 +397,7 @@ def _set_g(
         g = process_grad_output(
             g, module, batch_averaged=self.batch_averaged, kfac_approx=self.setting
         )
-        module.kfac_g = g
+        if hasattr(module, "kfac_g"):
+            module.kfac_g.append(g)
+        else:
+            module.kfac_g = [g]
diff --git a/test/optim/utils.py b/test/optim/utils.py
index f08ead8..d26ef92 100644
--- a/test/optim/utils.py
+++ b/test/optim/utils.py
@@ -89,9 +89,10 @@ def jacobians_naive(model: Module, data: Tensor, setting: str) -> Tuple[Tensor,
     f: Tensor = model(data)
     # f: (batch_size/n_loss_terms, ..., out_dim)
     out_dim = f.size(-1)
+    last_f_dim = f.numel() - 1
     jacs = []
     for i, f_i in enumerate(f.flatten()):
-        rg = i != (f.shape[0] - 1)
+        rg = i != last_f_dim
         jac = torch.autograd.grad(f_i, model.parameters(), retain_graph=rg)
         jacs.append(torch.cat([j.flatten() for j in jac]))
     # jacs: (n_loss_terms, out_dim, num_params)
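
Side note on the shape bookkeeping changed above: with OUT_DIM greater than one, the naive Jacobian tensor has shape (n_loss_terms, OUT_DIM, num_params), so the exact Fisher/GGN must contract over both leading dimensions. `Js.flatten(end_dim=-2)` merges them, while the previous `flatten(start_dim=1)` only happened to give the right result because OUT_DIM was 1. Below is a minimal, self-contained sketch of that computation, not part of the patch; the toy `nn.Linear` model and all variable names are illustrative assumptions, not code from the test suite.

import torch
from torch import nn

torch.manual_seed(0)
n_samples, in_dim, out_dim = 4, 3, 2  # toy sizes, for illustration only
model = nn.Linear(in_dim, out_dim)  # hypothetical stand-in for the test models
x = torch.randn(n_samples, in_dim)
num_params = sum(p.numel() for p in model.parameters())

# Naive Jacobian of the logits w.r.t. the parameters: one autograd.grad call
# per scalar output, retaining the graph for all but the last call (this is
# what the `last_f_dim = f.numel() - 1` fix accounts for when out_dim > 1).
f = model(x)
last_f_dim = f.numel() - 1
rows = []
for i, f_i in enumerate(f.flatten()):
    grads = torch.autograd.grad(f_i, model.parameters(), retain_graph=i != last_f_dim)
    rows.append(torch.cat([g.flatten() for g in grads]))
Js = torch.stack(rows).reshape(n_samples, out_dim, num_params)

# Exact Fisher/GGN for the MSE loss (unit observation noise): sum J^T J over
# all loss terms *and* output dimensions, i.e. flatten every dimension except
# the parameter one; this is what `Js.flatten(end_dim=-2)` does in the tests.
Js_flat = Js.flatten(end_dim=-2)  # (n_samples * out_dim, num_params)
exact_F = Js_flat.T @ Js_flat  # (num_params, num_params)
assert exact_F.shape == (num_params, num_params)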