Skip to content

Commit

Permalink
test: more recurrent testing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 22, 2024
1 parent f135943 commit c27fcfc
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions test/layers/recurrent_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

using MLDataDevices

MLDataDevices.get_device_type(::Function) = Nothing # FIXME: upstream maybe?
MLDataDevices.get_device_type(_) = Nothing # FIXME: upstream maybe?
MLDataDevices.Internal.get_device_type(::Function) = Nothing # FIXME: upstream maybe?
MLDataDevices.Internal.get_device_type(_) = Nothing # FIXME: upstream maybe?

function loss_loop(cell, x, p, st)
(y, carry), st_ = cell(x, p, st)
Expand Down Expand Up @@ -43,9 +43,9 @@ end
@jet rnncell((x, carry), ps, st)

if train_state
@test hasproperty(ps, :train_state)
@test hasproperty(ps, :hidden_state)
else
@test !hasproperty(ps, :train_state)
@test !hasproperty(ps, :hidden_state)
end

@test_gradients(loss_loop, rnncell, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
Expand Down Expand Up @@ -95,8 +95,8 @@ end
@jet lstmcell(x, ps, st)
@jet lstmcell((x, carry), ps, st)

@test !hasproperty(ps, :train_state)
@test !hasproperty(ps, :train_memory)
@test !hasproperty(ps, :hidden_state)
@test !hasproperty(ps, :memory)

@test_gradients(loss_loop, lstmcell, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
end
Expand Down Expand Up @@ -198,7 +198,7 @@ end
@jet grucell(x, ps, st)
@jet grucell((x, carry), ps, st)

@test !hasproperty(ps, :train_state)
@test !hasproperty(ps, :hidden_state)

@test_gradients(loss_loop, grucell, x, ps, st; atol=1e-3, rtol=1e-3)
end
Expand Down

0 comments on commit c27fcfc

Please sign in to comment.