Remove ZygoteDistancesExt and associated tests #1460

Closed
wants to merge 1 commit into from

simsurace
Contributor

This extension should probably be removed because:

  • It seems to be broken with ChainRules 1.53.0
  • Distances.jl now has a ChainRulesCore extension

@devmotion
Collaborator

To keep this from being completely breaking, I think one has to keep these rules for all Distances versions that do not contain the ChainRules definitions. On Julia < 1.9 at least, users would no longer be able to use Distances with an upcoming Zygote release, even if they update Distances, since Distances only defines these ChainRules rules on Julia >= 1.9.

@ToucheSir
Member

ToucheSir commented Oct 6, 2023

Zygote still has to support Julia 1.6+ and extensions only work on 1.9+, so unless Distances.jl adds a fallback with a direct dep/Requires.jl I don't see how we could remove the Zygote rules? Edit: @devmotion beat me by a second! Would it be so painful to have one of the aforementioned fallbacks on the Distances side?
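
For reference, the kind of fallback mentioned here usually follows the pattern sketched below. This is only an illustration under stated assumptions, not code from Distances.jl or Zygote: the module name `FallbackSketch` is hypothetical, Requires.jl is assumed to be a direct dependency on Julia < 1.9, and the include path is illustrative.

```julia
module FallbackSketch  # hypothetical stand-in for the package that owns the rules

# Base.get_extension exists only on Julia >= 1.9, so its presence tells us
# whether the package manager loads extension modules automatically.
const HAS_EXTENSIONS = isdefined(Base, :get_extension)

@static if !isdefined(Base, :get_extension)
    # Only reached on Julia < 1.9, where Requires.jl is assumed to be installed.
    using Requires

    function __init__()
        # Once ChainRulesCore is loaded, include the same file that acts as
        # the package extension on Julia >= 1.9 (path is illustrative).
        @require ChainRulesCore="d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include("../ext/DistancesChainRulesCoreExt.jl")
    end
end

end # module
```

On Julia >= 1.9 the `@static` branch is dropped at macro-expansion time, so the module never references Requires.jl and the extension mechanism takes over.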

@devmotion
Collaborator

There's a long history of discussions with the Distances maintainers about adding ChainRulesCore, which started before extensions were a thing, but they did not approve the idea, so I don't think it's likely to be added as a direct dependency. I'm not sure about Requires, but I assume there might be similar concerns about increasing dependencies and loading times.

@ToucheSir
Member

My recollection is that those discussions happened when CRC was a much heavier dep than it is now, and the topic hasn't really been revisited since. The problem with us gating the Zygote rules behind a version check is what @simsurace mentioned up top: they've bitrotted to the point where they're no longer fully functional. Unless someone wants to put in the effort of fixing them, there's not much we can do on the Zygote side to ensure users have access to functional rules for Distances.jl on Julia < 1.9.

@simsurace
Contributor Author

OK, maybe I haven't diagnosed the problem correctly.
Here is an MRE:

Zygote#master, ChainRules@1.55.0, Distances@0.10.10:

julia> using Distances, Zygote

julia> x = rand(10);

julia> f(x) = iszero(x) ? zero(x) : x;

julia> Zygote.gradient(_x -> sum(f, pairwise(Euclidean(), reshape(_x, :, 1); dims=1)), x)
ERROR: MethodError: no method matching *(::Nothing, ::Float64)

Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...)
   @ Base operators.jl:578
  *(::T, ::T) where T<:Union{Float16, Float32, Float64}
   @ Base float.jl:410
  *(::StridedArray{P}, ::Real) where P<:Dates.Period
   @ Dates ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/Dates/src/deprecated.jl:44
  ...

Stacktrace:
  [1] (::Zygote.var"#1412#1416"{Int64})(y1::Nothing, o1::ForwardDiff.Dual{Nothing, Float64, 2})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/lib/broadcast.jl:298
  [2] _broadcast_getindex_evalf
    @ ./broadcast.jl:683 [inlined]
  [3] _broadcast_getindex
    @ ./broadcast.jl:656 [inlined]
  [4] getindex
    @ ./broadcast.jl:610 [inlined]
  [5] macro expansion
    @ ./broadcast.jl:974 [inlined]
  [6] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [7] copyto!
    @ ./broadcast.jl:973 [inlined]
  [8] copyto!
    @ ./broadcast.jl:926 [inlined]
  [9] copy
    @ ./broadcast.jl:898 [inlined]
 [10] materialize
    @ ./broadcast.jl:873 [inlined]
 [11] broadcast(::Zygote.var"#1412#1416"{Int64}, ::Matrix{Union{Nothing, Float64}}, ::Matrix{ForwardDiff.Dual{Nothing, Float64, 2}})
    @ Base.Broadcast ./broadcast.jl:811
 [12] #1411
    @ ~/.julia/packages/Zygote/XJ8pP/src/lib/broadcast.jl:298 [inlined]
 [13] ntuple
    @ ./ntuple.jl:49 [inlined]
 [14] bc_fwd_back
    @ ~/.julia/packages/Zygote/XJ8pP/src/lib/broadcast.jl:297 [inlined]
 [15] #4155#back
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
 [16] #291
    @ ~/.julia/packages/Zygote/XJ8pP/src/lib/lib.jl:206 [inlined]
 [17] #2173#back
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
 [18] Pullback
    @ ./broadcast.jl:1317 [inlined]
 [19] Pullback
    @ ~/.julia/packages/Zygote/XJ8pP/ext/ZygoteDistancesExt.jl:104 [inlined]
 [20] (::ZygoteDistancesExt.var"#pairwise_Euclidean_pullback#52"{Zygote.Pullback{Tuple{ZygoteDistancesExt.var"#_pairwise_euclidean#51"{Int64}, SqEuclidean, Matrix{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#eps_pullback#396"{Tuple{DataType}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(eltype), Matrix{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#eltype_pullback#385"}, Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(ZygoteDistancesExt._sqrt_if_positive), Matrix{Float64}, Float64}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{NTuple{4, Nothing}, Tuple{}}, Zygote.var"#4155#back#1376"{Zygote.var"#bc_fwd_back#1414"{Matrix{ForwardDiff.Dual{Nothing, Float64, 2}}, Tuple{Matrix{Float64}, Float64}, Val{2}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Float64}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.var"#2881#back#688"{Zygote.var"#map_back#682"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1182"{Tuple{Nothing, Nothing, Nothing}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Matrix{Float64}}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, ZygoteDistancesExt.var"#63#back#30"{ZygoteDistancesExt.var"#32#33"{LinearAlgebra.Transpose{Float64, Matrix{Float64}}, typeof(transpose)}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float64}}, Tuple{}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:dims, Zygote.Context{false}, 
ZygoteDistancesExt.var"#_pairwise_euclidean#51"{Int64}, Int64}}}}})(Δ::Matrix{Union{Nothing, Float64}})
    @ ZygoteDistancesExt ~/.julia/packages/Zygote/XJ8pP/ext/ZygoteDistancesExt.jl:107
 [21] Pullback
    @ ./REPL[86]:1 [inlined]
 [22] (::Zygote.Pullback{Tuple{var"#89#90", Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#sum_pullback_f2#1665"{typeof(f), Colon, Matrix{Tuple{Float64, Zygote.var"#ad_pullback#58"{Tuple{typeof(f), Float64}, Zygote.Pullback{Tuple{typeof(f), Float64}, Any}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.Pullback{Tuple{Type{Euclidean}}, Tuple{}}, ZygoteDistancesExt.var"#pairwise_Euclidean_pullback#52"{Zygote.Pullback{Tuple{ZygoteDistancesExt.var"#_pairwise_euclidean#51"{Int64}, SqEuclidean, Matrix{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#eps_pullback#396"{Tuple{DataType}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(eltype), Matrix{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#eltype_pullback#385"}, Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(ZygoteDistancesExt._sqrt_if_positive), Matrix{Float64}, Float64}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{NTuple{4, Nothing}, Tuple{}}, Zygote.var"#4155#back#1376"{Zygote.var"#bc_fwd_back#1414"{Matrix{ForwardDiff.Dual{Nothing, Float64, 2}}, Tuple{Matrix{Float64}, Float64}, Val{2}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Float64}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.var"#2881#back#688"{Zygote.var"#map_back#682"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1182"{Tuple{Nothing, Nothing, Nothing}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Matrix{Float64}}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, ZygoteDistancesExt.var"#63#back#30"{ZygoteDistancesExt.var"#32#33"{LinearAlgebra.Transpose{Float64, 
Matrix{Float64}}, typeof(transpose)}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float64}}, Tuple{}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:dims, Zygote.Context{false}, ZygoteDistancesExt.var"#_pairwise_euclidean#51"{Int64}, Int64}}}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.var"#2799#back#625"{Zygote.var"#619#623"{Vector{Float64}, Tuple{Colon, Int64}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
 [23] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#89#90", Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#sum_pullback_f2#1665"{typeof(f), Colon, Matrix{Tuple{Float64, Zygote.var"#ad_pullback#58"{Tuple{typeof(f), Float64}, Zygote.Pullback{Tuple{typeof(f), Float64}, Any}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.Pullback{Tuple{Type{Euclidean}}, Tuple{}}, ZygoteDistancesExt.var"#pairwise_Euclidean_pullback#52"{Zygote.Pullback{Tuple{ZygoteDistancesExt.var"#_pairwise_euclidean#51"{Int64}, SqEuclidean, Matrix{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#eps_pullback#396"{Tuple{DataType}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(eltype), Matrix{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#eltype_pullback#385"}, Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(ZygoteDistancesExt._sqrt_if_positive), Matrix{Float64}, Float64}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{NTuple{4, Nothing}, Tuple{}}, Zygote.var"#4155#back#1376"{Zygote.var"#bc_fwd_back#1414"{Matrix{ForwardDiff.Dual{Nothing, Float64, 2}}, Tuple{Matrix{Float64}, Float64}, Val{2}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Float64}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.var"#2881#back#688"{Zygote.var"#map_back#682"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1182"{Tuple{Nothing, Nothing, Nothing}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Matrix{Float64}}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, 
ZygoteDistancesExt.var"#63#back#30"{ZygoteDistancesExt.var"#32#33"{LinearAlgebra.Transpose{Float64, Matrix{Float64}}, typeof(transpose)}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float64}}, Tuple{}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:dims, Zygote.Context{false}, ZygoteDistancesExt.var"#_pairwise_euclidean#51"{Int64}, Int64}}}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.var"#2799#back#625"{Zygote.var"#619#623"{Vector{Float64}, Tuple{Colon, Int64}}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface.jl:45
 [24] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface.jl:97
 [25] top-level scope
    @ REPL[86]:1

Zygote#master, ChainRules@1.52.1, Distances@0.10.10:

julia> using Distances, Zygote

julia> x = rand(10);

julia> f(x) = iszero(x) ? zero(x) : x;

julia> Zygote.gradient(_x -> sum(f, pairwise(Euclidean(), reshape(_x, :, 1); dims=1)), x)
([-9.999999999999956, 6.0, -17.999999999999996, 1.9999999999999984, -14.000000000000028, 10.000000000000766, -2.000000000000006, 18.000000000000004, 13.999999999999257, -6.000000000000013],)

Zygote#simsurace:remove-distances, ChainRules@1.55.0, Distances@0.10.10: maybe another bug in Zygote, see JuliaStats/Distances.jl#256

julia> using Distances, Zygote

julia> x = rand(10);

julia> f(x) = iszero(x) ? zero(x) : x;

julia> Zygote.gradient(_x -> sum(f, pairwise(Euclidean(), reshape(_x, :, 1); dims=1)), x)
ERROR: MethodError: no method matching _normalize(::ChainRulesCore.ZeroTangent, ::Float64)

Closest candidates are:
  _normalize(::Real, ::Real)
   @ DistancesChainRulesCoreExt ~/.julia/packages/Distances/PvoXa/ext/DistancesChainRulesCoreExt.jl:83

Stacktrace:
  [1] _broadcast_getindex_evalf
    @ ./broadcast.jl:683 [inlined]
  [2] _broadcast_getindex
    @ ./broadcast.jl:656 [inlined]
  [3] getindex
    @ ./broadcast.jl:610 [inlined]
  [4] copy
    @ ./broadcast.jl:912 [inlined]
  [5] materialize
    @ ./broadcast.jl:873 [inlined]
  [6] (::DistancesChainRulesCoreExt.var"#pairwise_Euclidean_X_pullback#10"{Int64, Matrix{Float64}, Matrix{Float64}})(ΔΩ::Matrix{Any})
    @ DistancesChainRulesCoreExt ~/.julia/packages/Distances/PvoXa/ext/DistancesChainRulesCoreExt.jl:114
  [7] ZBack
    @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:211 [inlined]
  [8] (::Zygote.var"#kw_zpullback#53"{DistancesChainRulesCoreExt.var"#pairwise_Euclidean_X_pullback#10"{Int64, Matrix{Float64}, Matrix{Float64}}})(dy::Matrix{Union{Nothing, Float64}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/chainrules.jl:237
  [9] Pullback
    @ ./REPL[7]:1 [inlined]
 [10] (::Zygote.Pullback{Tuple{var"#3#4", Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#sum_pullback_f2#1665"{typeof(f), Colon, Matrix{Tuple{Float64, Zygote.var"#ad_pullback#58"{Tuple{typeof(f), Float64}, Zygote.Pullback{Tuple{typeof(f), Float64}, Any}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.var"#2799#back#625"{Zygote.var"#619#623"{Vector{Float64}, Tuple{Colon, Int64}}}, Zygote.Pullback{Tuple{Type{Euclidean}}, Tuple{}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#kw_zpullback#53"{DistancesChainRulesCoreExt.var"#pairwise_Euclidean_X_pullback#10"{Int64, Matrix{Float64}, Matrix{Float64}}}, Zygote.var"#2017#back#204"{typeof(identity)}}})(Δ::Float64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [11] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#3#4", Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#sum_pullback_f2#1665"{typeof(f), Colon, Matrix{Tuple{Float64, Zygote.var"#ad_pullback#58"{Tuple{typeof(f), Float64}, Zygote.Pullback{Tuple{typeof(f), Float64}, Any}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.var"#2799#back#625"{Zygote.var"#619#623"{Vector{Float64}, Tuple{Colon, Int64}}}, Zygote.Pullback{Tuple{Type{Euclidean}}, Tuple{}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#kw_zpullback#53"{DistancesChainRulesCoreExt.var"#pairwise_Euclidean_X_pullback#10"{Int64, Matrix{Float64}, Matrix{Float64}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}})(Δ::Float64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:45
 [12] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:97
 [13] top-level scope
    @ REPL[7]:1

@simsurace
Contributor Author

Closing this for now as it seems to be considered too breaking. I opened #1464 to track the issue that has been unmasked.

@simsurace simsurace closed this Oct 12, 2023