Skip to content

Commit

Permalink
[FIX] Examples
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Sep 23, 2024
1 parent e678ea2 commit 32523ad
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions docs/examples/basic_usage/example_fisher_monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
# identity matrix:

GGN = GGNLinearOperator(model, loss_function, params, data).to_scipy()
F = FisherMCLinearOperator(model, loss_function, params, data)
F = FisherMCLinearOperator(model, loss_function, params, data).to_scipy()

D = GGN.shape[0]
identity = numpy.eye(D)
Expand Down Expand Up @@ -109,7 +109,9 @@
residual_norms = []

for mc in mc_samples:
F = FisherMCLinearOperator(model, loss_function, params, data, mc_samples=mc)
F = FisherMCLinearOperator(
model, loss_function, params, data, mc_samples=mc
).to_scipy()
F_mat = F @ identity
residual_norms.append(numpy.linalg.norm(GGN_mat - F_mat))

Expand All @@ -131,12 +133,16 @@
# first linear operator, we generate some random numbers to show that the
# global random number generator does not influence the Monte-Carlo estimator:

F1_mat = FisherMCLinearOperator(model, loss_function, params, data) @ identity
F1_mat = (
FisherMCLinearOperator(model, loss_function, params, data).to_scipy() @ identity
)

# draw some random numbers to modify the global random number generator's state
torch.rand(123)

F2_mat = FisherMCLinearOperator(model, loss_function, params, data) @ identity
F2_mat = (
FisherMCLinearOperator(model, loss_function, params, data).to_scipy() @ identity
)

# still, we get the same deterministic approximation
residual_norm = numpy.linalg.norm(F1_mat - F2_mat)
Expand Down Expand Up @@ -201,7 +207,7 @@
# linear operators indeed realize deterministic matrices.
F = FisherMCLinearOperator(
model, loss_function, params, data, seed=seed, check_deterministic=False
)
).to_scipy()
F_accumulated += F @ identity
if mc + 1 in mc_samples:
F_snapshots.append(F_accumulated / (mc + 1))
Expand Down

0 comments on commit 32523ad

Please sign in to comment.