Merge branch 'development' into kfac-ce-loss
f-dangel committed Oct 30, 2023
2 parents 97693b9 + a331b05 commit beda64b
Showing 1 changed file with 11 additions and 2 deletions.
curvlinops/kfac.py (11 additions, 2 deletions)
@@ -99,7 +99,7 @@ def __init__(
        - Weights and biases are treated separately.
        - No weight sharing is supported.
        - Only the Monte-Carlo sampled version is supported.
-        - Only the ``'expand'`` setting is supported.
+        - Only the ``'expand'`` approximation is supported.

    Args:
        model_func: The neural network. Must consist of modules.
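For orientation, a minimal usage sketch of the operator whose docstring is edited above. This assumes the class defined in curvlinops/kfac.py is exported as KFACLinearOperator with a (model_func, loss_func, params, data) constructor, as in the released curvlinops API; the toy model and batch below are made up:

from torch import nn, rand, randint
from curvlinops import KFACLinearOperator  # assumed export; see curvlinops/kfac.py

# Toy classifier plus one (X, y) batch, matching the CrossEntropyLoss case
# this branch adds support for.
model_func = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 3))
loss_func = nn.CrossEntropyLoss()
X, y = rand(32, 10), randint(0, 3, (32,))
params = [p for p in model_func.parameters() if p.requires_grad]

# MC-sampled KFAC approximation of the Fisher in the 'expand' setting described above.
kfac = KFACLinearOperator(model_func, loss_func, params, [(X, y)])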
@@ -280,9 +280,14 @@ def draw_label(self, output: Tensor) -> Tensor:
                generator=self._generator,
            )
            return output.clone().detach() + perturbation

        elif isinstance(self._loss_func, CrossEntropyLoss):
            # TODO For output.ndim > 2, the scale of the 'would-be' gradient resulting
            # from these labels might be off
            if output.ndim != 2:
                raise NotImplementedError(
                    "Only 2D output is supported for CrossEntropyLoss for now."
                )
            probs = output.softmax(dim=1)
            # each row contains a vector describing a categorical
            probs_as_mat = rearrange(probs, "n c ... -> (n ...) c")
@@ -291,6 +296,7 @@ def draw_label(self, output: Tensor) -> Tensor:
            ).squeeze(-1)
            label_shape = output.shape[:1] + output.shape[2:]
            return labels.reshape(label_shape)

        else:
            raise NotImplementedError
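To illustrate the label-drawing path added here, a self-contained sketch of the CrossEntropyLoss branch. The sampling call sits in the lines elided between the two hunks and is assumed to be Tensor.multinomial, consistent with the .squeeze(-1) shown above:

from einops import rearrange
from torch import manual_seed, randn

manual_seed(0)
output = randn(8, 5)  # batch of 8 logit vectors over 5 classes (2D, as required above)
probs = output.softmax(dim=1)
# Flatten trailing dimensions so each row is one categorical distribution.
probs_as_mat = rearrange(probs, "n c ... -> (n ...) c")
# Draw one 'would-be' label per row from the model's own predictive distribution.
labels = probs_as_mat.multinomial(num_samples=1).squeeze(-1)
label_shape = output.shape[:1] + output.shape[2:]
labels = labels.reshape(label_shape)  # shape (8,), dtype int64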

@@ -333,7 +339,10 @@ def _hook_accumulate_gradient_covariance(
            covariance = einsum("bi,bj->ij", g, g).mul_(correction)
        else:
            # TODO Support convolutions
-            raise NotImplementedError(f"Layer of type {type(module)} is unsupported.")
+            raise NotImplementedError(
+                f"Layer of type {type(module)} is unsupported. "
+                + f"Supported layers: {self._SUPPORTED_MODULES}."
+            )

        idx = tuple(p.data_ptr() for p in module.parameters())
        if idx not in self._gradient_covariances:
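As a hedged illustration of the quantity accumulated in this hook for a Linear layer: g stands in for the layer's output gradients and correction for the scaling factor, whose exact value depends on the loss reduction and the number of MC samples:

from torch import einsum, randn

batch_size, out_features = 8, 4
g = randn(batch_size, out_features)  # stand-in for the hook's grad_output
correction = 1.0 / batch_size  # placeholder; the real factor is set elsewhere in kfac.py
# Kronecker factor from output gradients: sum over the batch of g_b g_b^T, rescaled in place.
covariance = einsum("bi,bj->ij", g, g).mul_(correction)
print(covariance.shape)  # torch.Size([4, 4])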
