Skip to content

Commit

Permalink
[FIX] Use correct device in test
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Nov 8, 2023
1 parent 6106ad6 commit 5601596
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test/test_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ def test_kfac_inplace_activations(dev: device):
dev: The device to run the test on.
"""
manual_seed(0)
model = Sequential(Linear(6, 3), ReLU(inplace=True), Linear(3, 2))
loss_func = MSELoss()
model = Sequential(Linear(6, 3), ReLU(inplace=True), Linear(3, 2)).to(dev)
loss_func = MSELoss().to(dev)
batch_size = 1
data = [(rand(batch_size, 6), regression_targets((batch_size, 2)))]
params = list(model.parameters())
Expand Down

0 comments on commit 5601596

Please sign in to comment.