Skip to content

Commit

Permalink
[DOC] Add more matrices to visual tour, prettify plots
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Oct 6, 2024
1 parent 8e47e57 commit 896f565
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 16 deletions.
81 changes: 65 additions & 16 deletions docs/examples/basic_usage/example_visual_tour.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from torch import nn

from curvlinops import EFLinearOperator, GGNLinearOperator, HessianLinearOperator
from tueplots import bundles

from curvlinops import (
EFLinearOperator,
FisherMCLinearOperator,
GGNLinearOperator,
HessianLinearOperator,
KFACLinearOperator,
)

# make deterministic
torch.manual_seed(0)
Expand Down Expand Up @@ -61,6 +68,7 @@
num_params_layer = [
sum(p.numel() for p in child.parameters()) for child in model.children()
]
num_tensors_layer = [len(list(child.parameters())) for child in model.children()]

loss_function = nn.CrossEntropyLoss(reduction="mean").to(DEVICE)

Expand All @@ -80,6 +88,17 @@
Hessian_linop = HessianLinearOperator(model, loss_function, params, dataloader)
GGN_linop = GGNLinearOperator(model, loss_function, params, dataloader)
EF_linop = EFLinearOperator(model, loss_function, params, dataloader)
Hessian_blocked_linop = HessianLinearOperator(
model,
loss_function,
params,
dataloader,
block_sizes=[s for s in num_tensors_layer if s != 0],
)
F_linop = FisherMCLinearOperator(model, loss_function, params, dataloader)
KFAC_linop = KFACLinearOperator(
model, loss_function, params, dataloader, separate_weight_and_bias=False
)

# %%
#
Expand All @@ -90,18 +109,27 @@
Hessian_mat = Hessian_linop @ identity
GGN_mat = GGN_linop @ identity
EF_mat = EF_linop @ identity
Hessian_blocked_mat = Hessian_blocked_linop @ identity
F_mat = F_linop @ identity
KFAC_mat = KFAC_linop @ identity

# %%
# Visualization
# -------------
#
# We will show the matrix entries on a shared domain for better comparability.

matrices = [Hessian_mat, GGN_mat, EF_mat]
titles = ["Hessian", "GGN", "Empirical Fisher"]
matrices = [Hessian_mat, GGN_mat, EF_mat, Hessian_blocked_mat, F_mat, KFAC_mat]
titles = [
"Hessian",
"Generalized Gauss-Newton",
"Empirical Fisher",
"Block-diagonal Hessian",
"Monte-Carlo Fisher",
"KFAC",
]

rows, columns = 1, 3
img_width = 7
rows, columns = 2, 3


def plot(
Expand All @@ -121,21 +149,34 @@ def plot(
min_value = min(transform(mat).min() for mat in matrices)
max_value = max(transform(mat).max() for mat in matrices)

fig, axes = plt.subplots(
nrows=rows, ncols=columns, figsize=(columns * img_width, rows * img_width)
)
fig, axes = plt.subplots(nrows=rows, ncols=columns, sharex=True, sharey=True)
fig.supxlabel("Layer")
fig.supylabel("Layer")

for idx, (ax, mat, title) in enumerate(zip(axes.flat, matrices, titles)):
ax.set_title(title)
img = ax.imshow(transform(mat), vmin=min_value, vmax=max_value)

# layer structure
for pos in numpy.cumsum(num_params_layer):
# layer blocks
boundaries = [0] + numpy.cumsum(num_params_layer).tolist()
for pos in boundaries:
if pos not in [0, num_params]:
style = {"color": "w", "lw": 0.5, "ls": "--"}
style = {"color": "w", "lw": 0.5, "ls": "-"}
ax.axhline(y=pos - 1, xmin=0, xmax=num_params - 1, **style)
ax.axvline(x=pos - 1, ymin=0, ymax=num_params - 1, **style)

# label positions
label_positions = [
(boundaries[layer_idx] + boundaries[layer_idx + 1]) / 2
for layer_idx in range(len(boundaries) - 1)
if boundaries[layer_idx] != boundaries[layer_idx + 1]
]
labels = [str(i + 1) for i in range(len(label_positions))]
ax.set_xticks(label_positions)
ax.set_xticklabels(labels)
ax.set_yticks(label_positions)
ax.set_yticklabels(labels)

# colorbar
last = idx == len(matrices) - 1
if last:
Expand All @@ -146,16 +187,23 @@ def plot(
return fig, axes


# use `tueplots` to make the plot look pretty
plot_config = plt.rc_context(
bundles.icml2024(column="full", nrows=1.5 * rows, ncols=columns)
)

# %%
#
# We will show their logarithmic absolute value:


def logabs(mat, epsilon=1e-5):
return numpy.log10(numpy.abs(mat) + epsilon)
def logabs(mat, epsilon=1e-6):
return numpy.log10(numpy.clip(numpy.abs(mat), a_min=epsilon, a_max=None))


plot(logabs, transform_title="Logarithmic absolute entries")
with plot_config:
plot(logabs, transform_title="Logarithmic absolute entries")
plt.savefig("curvature_matrices_log_abs.pdf", bbox_inches="tight")

# %%
#
Expand All @@ -166,7 +214,8 @@ def unchanged(mat):
return mat


plot(unchanged, transform_title="Unaltered matrix entries")
with plot_config:
plot(unchanged, transform_title="Unaltered matrix entries")

# %%
#
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ docs = [
"matplotlib",
"sphinx-gallery",
"sphinx-rtd-theme",
"tueplots"
]

###############################################################################
Expand Down

0 comments on commit 896f565

Please sign in to comment.