Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed May 5, 2024
1 parent 4a436d1 commit 98c9e9b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 17 deletions.
3 changes: 0 additions & 3 deletions curvlinops/examples/functorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,6 @@ def functorch_empirical_fisher(
Returns:
Square matrix containing the empirical Fisher.
Raises:
ValueError: If the loss function's reduction cannot be determined.
"""
(dev,) = {p.device for p in params}
X, y = _concatenate_batches(data, input_key, device=dev)
Expand Down
18 changes: 4 additions & 14 deletions test/test_submatrix_on_curvatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,9 @@ def setup_submatrix_linear_operator(case, operator_case, submatrix_case):
A = operator_case(model_func, loss_func, params, data, batch_size_fn=batch_size_fn)
A_sub = SubmatrixLinearOperator(A, row_idxs, col_idxs)

if operator_case == EFLinearOperator:
A_functorch = CURVATURE_IN_FUNCTORCH[operator_case](
model_func,
loss_func,
params,
data,
batch_size_fn=batch_size_fn,
input_key="x",
)
else:
A_functorch = CURVATURE_IN_FUNCTORCH[operator_case](
model_func, loss_func, params, data, "x"
)
A_functorch = CURVATURE_IN_FUNCTORCH[operator_case](
model_func, loss_func, params, data, "x"
)
A_sub_functorch = A_functorch[row_idxs, :][:, col_idxs].detach().cpu().numpy()

return A_sub, A_sub_functorch, row_idxs, col_idxs
Expand Down Expand Up @@ -111,4 +101,4 @@ def test_SubmatrixLinearOperator_on_curvatures_matmat(
A_sub_X = A_sub @ X

assert A_sub_X.shape == (len(row_idxs), num_vecs)
report_nonclose(A_sub_X, A_sub_functorch @ X, atol=6e-7)
report_nonclose(A_sub_X, A_sub_functorch @ X, atol=1e-6)

0 comments on commit 98c9e9b

Please sign in to comment.