
Commit

Increase OUT_DIM to 2 to generalize KFAC test
runame committed Oct 19, 2023
1 parent cc4022d commit 6d1b7f8
Showing 2 changed files with 18 additions and 8 deletions.
23 changes: 16 additions & 7 deletions test/optim/test_kfac.py
@@ -15,7 +15,7 @@
IN_DIM = 3
HID_DIM = 5
REP_DIM = 2
-OUT_DIM = 1
+OUT_DIM = 2
N_SAMPLES = 4
C_in = 3
C_out = 2
@@ -63,7 +63,7 @@ def test_kfac_single_linear_module(
Js, f = jacobians_naive(model, x, setting)
assert f.shape == (n_loss_terms, OUT_DIM)
assert Js.shape == (n_loss_terms, OUT_DIM, num_params)
-Js = Js.flatten(start_dim=1)
+Js = Js.flatten(end_dim=-2)

# Exact Fisher/GGN.
exact_F = Js.T @ Js # regression
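
(Aside, not part of the diff: a minimal sketch with assumed shapes of why flatten(end_dim=-2) is the right generalization here. It collapses the loss-term and output dimensions into one leading axis, so Js.T @ Js sums the outer products over all n_loss_terms * OUT_DIM rows; the previous flatten(start_dim=1) only coincides with this for OUT_DIM = 1.)

import torch

n_loss_terms, out_dim, num_params = 4, 2, 7
Js = torch.randn(n_loss_terms, out_dim, num_params)

# (n_loss_terms, out_dim, num_params) -> (n_loss_terms * out_dim, num_params)
Js_flat = Js.flatten(end_dim=-2)
assert Js_flat.shape == (n_loss_terms * out_dim, num_params)

# The exact Fisher/GGN for MSE is the sum of per-row outer products,
# which is exactly Js_flat.T @ Js_flat.
exact_F = Js_flat.T @ Js_flat
reference = sum(
    Js[n, o].outer(Js[n, o])
    for n in range(n_loss_terms)
    for o in range(out_dim)
)
assert torch.allclose(exact_F, reference, atol=1e-6)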
@@ -123,7 +123,7 @@ def test_kfac_deep_linear(
Js, f = jacobians_naive(model, x, setting)
assert f.shape == (n_loss_terms, OUT_DIM)
assert Js.shape == (n_loss_terms, OUT_DIM, num_params)
-Js = Js.flatten(start_dim=1)
+Js = Js.flatten(end_dim=-2)

# Exact Fisher/GGN.
exact_F = Js.T @ Js # regression
@@ -192,7 +192,7 @@ def test_kfac_conv2d_module(
Js, f = jacobians_naive(model, x, setting)
assert f.shape == (n_loss_terms, OUT_DIM)
assert Js.shape == (n_loss_terms, OUT_DIM, num_params)
-Js = Js.flatten(start_dim=1)
+Js = Js.flatten(end_dim=-2)

# Exact Fisher/GGN.
exact_F = Js.T @ Js # regression
@@ -291,10 +291,16 @@ def forward_and_backward(self, x: Tensor):
else {}
)
logits: Tensor = self.model(x, **kwargs)
-# Backward pass for each output dimension.
+# Since we only consider the MSE loss, we do not need to explicitly
+# consider the loss function or labels, as the MSE loss Hessian w.r.t.
+# the logits is the precision matrix of the Gaussian likelihood.
+# In other words, we only need to compute the Jacobian of the logits
+# w.r.t. the parameters. This requires one backward pass per output
+# dimension.
n_dims = logits.size(-1)
for i in range(n_dims):
logits_i = logits[:, i]
+# Mean or sum reduction over `n_loss_terms`.
loss = logits_i.mean() if self.batch_averaged else logits_i.sum()
loss.backward(retain_graph=i < n_dims - 1)
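
(Aside, not part of the diff: a minimal sketch of this per-output-dimension backward scheme on a toy linear layer, using illustrative names only. Backpropagating logits[:, i].sum() makes the grad_output seen by a backward hook the i-th unit vector for every sample, so hooks collect exactly the ingredients of the logits' Jacobian without any loss function or labels.)

import torch
from torch import nn

model = nn.Linear(3, 2)
grad_outputs = []
model.register_full_backward_hook(
    lambda mod, g_in, g_out: grad_outputs.append(g_out[0].detach().clone())
)

x = torch.randn(4, 3, requires_grad=True)
logits = model(x)
n_dims = logits.size(-1)
for i in range(n_dims):
    # Summing over the batch sets grad_output to the i-th unit vector per sample.
    logits[:, i].sum().backward(retain_graph=i < n_dims - 1)

assert torch.equal(grad_outputs[0], torch.tensor([[1.0, 0.0]] * 4))
assert torch.equal(grad_outputs[1], torch.tensor([[0.0, 1.0]] * 4))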

@@ -316,7 +322,7 @@ def get_kfac_blocks(self) -> List[Tensor]:
raise ValueError("forward_and_backward() has to be called first.")
# Get Kronecker factor ingredients stored as module attributes.
a: Tensor = module.kfac_a
-g: Tensor = module.kfac_g
+g: Tensor = torch.cat(module.kfac_g)
# Compute Kronecker product of both factors.
block = torch.kron(g.T @ g, a.T @ a)
# When a bias is used we have to reorder the rows and columns of the
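
(Aside, not part of the diff: a shape-only sketch, with assumed dimensions, of the Kronecker block assembled above. For a linear layer, A = a.T @ a is built from the layer inputs and G = g.T @ g from the grad-outputs, and torch.kron(G, A) yields the (d_out*d_in) x (d_out*d_in) KFAC approximation of that layer's Fisher/GGN block.)

import torch

n, d_in, d_out = 4, 3, 2
a = torch.randn(n, d_in)   # layer inputs gathered in the forward pass
g = torch.randn(n, d_out)  # grad-outputs gathered over the backward passes

block = torch.kron(g.T @ g, a.T @ a)
assert block.shape == (d_out * d_in, d_out * d_in)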
@@ -391,4 +397,7 @@ def _set_g(
g = process_grad_output(
g, module, batch_averaged=self.batch_averaged, kfac_approx=self.setting
)
-module.kfac_g = g
+if hasattr(module, "kfac_g"):
+    module.kfac_g.append(g)
+else:
+    module.kfac_g = [g]
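
(Aside, not part of the diff: a quick check with made-up shapes of why collecting one grad-output tensor per backward pass and concatenating them is equivalent for the Kronecker factor. Concatenating along the loss-term axis and forming g.T @ g is the same as summing the per-pass contributions.)

import torch

n_samples, out_features, out_dim = 4, 5, 2
per_pass = [torch.randn(n_samples, out_features) for _ in range(out_dim)]

g = torch.cat(per_pass)  # (n_samples * out_dim, out_features)
assert torch.allclose(g.T @ g, sum(gi.T @ gi for gi in per_pass), atol=1e-6)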
3 changes: 2 additions & 1 deletion test/optim/utils.py
@@ -89,9 +89,10 @@ def jacobians_naive(model: Module, data: Tensor, setting: str) -> Tuple[Tensor,
f: Tensor = model(data)
# f: (batch_size/n_loss_terms, ..., out_dim)
out_dim = f.size(-1)
+last_f_dim = f.numel() - 1
jacs = []
for i, f_i in enumerate(f.flatten()):
-rg = i != (f.shape[0] - 1)
+rg = i != last_f_dim
jac = torch.autograd.grad(f_i, model.parameters(), retain_graph=rg)
jacs.append(torch.cat([j.flatten() for j in jac]))
# jacs: (n_loss_terms, out_dim, num_params)
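
(Aside, not part of the diff: a simplified, standalone version of this naive Jacobian loop, omitting the test's setting handling and using a toy model. With OUT_DIM > 1, f.flatten() has f.numel() elements rather than f.shape[0], so the graph must be retained until the last flattened output element; that is exactly what last_f_dim fixes.)

import torch
from torch import nn

def naive_jacobian(model, data):
    # Jacobian of the flattened model output w.r.t. all parameters,
    # shape (n_outputs, num_params), via one autograd.grad call per output.
    f = model(data)
    last_f_dim = f.numel() - 1  # retain the graph until the last output element
    jacs = []
    for i, f_i in enumerate(f.flatten()):
        jac = torch.autograd.grad(
            f_i, list(model.parameters()), retain_graph=i != last_f_dim
        )
        jacs.append(torch.cat([j.flatten() for j in jac]))
    return torch.stack(jacs), f

model = nn.Linear(3, 2)
Js, f = naive_jacobian(model, torch.randn(4, 3))
assert Js.shape == (f.numel(), sum(p.numel() for p in model.parameters()))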
