diff --git a/test/chainrules.jl b/test/chainrules.jl index 00fd1b0af..3d5fcb035 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -422,11 +422,11 @@ end end @testset "ChainRules translation" begin - @test Zygote.wrap_chainrules_input(nothing) == ChainRules.ZeroTangent() - @test Zygote.wrap_chainrules_input((nothing,)) == ChainRules.ZeroTangent() - @test Zygote.wrap_chainrules_input([nothing]) == ChainRules.ZeroTangent() - @test Zygote.wrap_chainrules_input(((1.0, 2.0), 3.0)) == ChainRules.Tangent{Any}(ChainRules.Tangent{Any}(1.0, 2.0), 3.0) - @test Zygote.wrap_chainrules_input((; a = 1.0, b = 2.0)) == ChainRules.Tangent{Any}(a = 1.0, b = 2.0) + @test Zygote.wrap_chainrules_input(nothing) == ZeroTangent() + @test Zygote.wrap_chainrules_input((nothing,)) == ZeroTangent() + @test Zygote.wrap_chainrules_input([nothing]) == ZeroTangent() + @test Zygote.wrap_chainrules_input(((1.0, 2.0), 3.0)) == Tangent{Any}(Tangent{Any}(1.0, 2.0), 3.0) + @test Zygote.wrap_chainrules_input((; a = 1.0, b = 2.0)) == Tangent{Any}(a = 1.0, b = 2.0) @test Zygote.wrap_chainrules_input(Ref(1)) == 1 @test Zygote.wrap_chainrules_input([2.0; 4.0]) == [2.0; 4.0] @test Zygote.wrap_chainrules_input([[2.0; 4.0], [1.0; 3.0]]) == [[2.0; 4.0], [1.0; 3.0]]