Clean up jacobians_naive
runame committed Oct 19, 2023
1 parent 5b1ac3e commit 96a8e82
Showing 1 changed file with 11 additions and 24 deletions.
35 changes: 11 additions & 24 deletions test/optim/utils.py
@@ -81,32 +81,19 @@ def verify_dtype(mat: StructuredMatrix, dtype: torch.dtype):
         raise RuntimeError(f"Expected dtype {dtype}, got {dtypes}.")
 
 
-def grad(model: Module) -> Tensor:
-    return torch.cat([p.grad.data.flatten() for p in model.parameters()]).detach()
-
-
 def jacobians_naive(model: Module, data: Tensor, setting: str) -> Tuple[Tensor, Tensor]:
-    model.zero_grad()
     num_params = sum(p.numel() for p in model.parameters())
     try:
         f: Tensor = model(data, setting)
     except TypeError:
         f: Tensor = model(data)
-    Jacs = list()
-    for i in range(f.shape[0]):
-        if len(f.shape) > 1:
-            jacs = list()
-            for j in range(f.shape[1]):
-                rg = i != (f.shape[0] - 1) or j != (f.shape[1] - 1)
-                torch.autograd.backward(f[i, j], torch.tensor(1.0), retain_graph=rg)
-                Jij = grad(model)
-                jacs.append(Jij)
-                model.zero_grad()
-            jacs = torch.stack(jacs).t()
-        else:
-            rg = i != (f.shape[0] - 1)
-            f[i].backward(retain_graph=rg)
-            jacs = grad(model)
-            model.zero_grad()
-        Jacs.append(jacs)
-    Jacs = torch.stack(Jacs).transpose(1, 2)
-    return Jacs.detach(), f.detach()
+    # f: (batch_size/n_loss_terms, ..., out_dim)
+    out_dim = f.size(-1)
+    jacs = []
+    for i, f_i in enumerate(f.flatten()):
+        # Free the graph only after the gradient of the last of the
+        # f.numel() output elements has been computed.
+        rg = i != (f.numel() - 1)
+        jac = torch.autograd.grad(f_i, model.parameters(), retain_graph=rg)
+        jacs.append(torch.cat([j.flatten() for j in jac]))
+    # jacs: (n_loss_terms, out_dim, num_params)
+    jacs = torch.stack(jacs).view(-1, out_dim, num_params)
+    return jacs.detach(), f.detach()
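
As a quick sanity check (not part of this commit), the cleaned-up function can be compared against torch.autograd.functional.jacobian. A minimal sketch, assuming jacobians_naive is importable from test/optim/utils.py; the toy Linear model, the tensor shapes, the random seed, and the "dummy" setting string are illustrative only (Linear raises a TypeError on the extra argument, so the fallback call model(data) is taken):

    import torch
    from torch.nn import Linear

    from test.optim.utils import jacobians_naive  # assumed import path

    torch.manual_seed(0)
    model = Linear(3, 2)  # num_params = 2 * 3 + 2 = 8
    data = torch.rand(5, 3)  # batch of 5 inputs

    # jacs: (n_loss_terms, out_dim, num_params) = (5, 2, 8)
    jacs, f = jacobians_naive(model, data, "dummy")
    assert jacs.shape == (5, 2, 8) and f.shape == (5, 2)

    # Reference: Jacobian of the output w.r.t. (weight, bias), flattened and
    # concatenated in parameter order to match jacobians_naive's layout.
    ref = torch.autograd.functional.jacobian(
        lambda w, b: torch.nn.functional.linear(data, w, b),
        tuple(model.parameters()),
    )
    ref = torch.cat([j.reshape(5, 2, -1) for j in ref], dim=-1)
    assert torch.allclose(jacs, ref)

Computing per-element gradients with torch.autograd.grad, instead of calling backward and scraping p.grad as the removed grad helper did, leaves the parameters' .grad fields untouched, which is why the model.zero_grad() bookkeeping could be dropped as well.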
