Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
BioTurboNick committed Sep 16, 2024
1 parent ee22f3a commit 74b034d
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -422,11 +422,11 @@ end
end

@testset "ChainRules translation" begin
@test Zygote.wrap_chainrules_input(nothing) == ChainRules.ZeroTangent()
@test Zygote.wrap_chainrules_input((nothing,)) == ChainRules.ZeroTangent()
@test Zygote.wrap_chainrules_input([nothing]) == ChainRules.ZeroTangent()
@test Zygote.wrap_chainrules_input(((1.0, 2.0), 3.0)) == ChainRules.Tangent{Any}(ChainRules.Tangent{Any}(1.0, 2.0), 3.0)
@test Zygote.wrap_chainrules_input((; a = 1.0, b = 2.0)) == ChainRules.Tangent{Any}(a = 1.0, b = 2.0)
@test Zygote.wrap_chainrules_input(nothing) == ZeroTangent()
@test Zygote.wrap_chainrules_input((nothing,)) == ZeroTangent()
@test Zygote.wrap_chainrules_input([nothing]) == ZeroTangent()
@test Zygote.wrap_chainrules_input(((1.0, 2.0), 3.0)) == Tangent{Any}(Tangent{Any}(1.0, 2.0), 3.0)
@test Zygote.wrap_chainrules_input((; a = 1.0, b = 2.0)) == Tangent{Any}(a = 1.0, b = 2.0)
@test Zygote.wrap_chainrules_input(Ref(1)) == 1
@test Zygote.wrap_chainrules_input([2.0; 4.0]) == [2.0; 4.0]
@test Zygote.wrap_chainrules_input([[2.0; 4.0], [1.0; 3.0]]) == [[2.0; 4.0], [1.0; 3.0]]
Expand Down

0 comments on commit 74b034d

Please sign in to comment.