From 4c470eb3804e2d9cb75467493fc3c2372cd8b44a Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Fri, 1 Sep 2023 12:46:54 -0700 Subject: [PATCH 1/3] Remove GPU sum() rule --- src/lib/broadcast.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 504ef614d..02d839ec3 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -364,11 +364,6 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve @adjoint (::Type{T})(xs::Array) where {T <: AbstractGPUArray} = T(xs), Δ -> (convert(Array, Δ), ) - @adjoint function sum(xs::AbstractGPUArray; dims = :) - placeholder = similar(xs) - sum(xs, dims = dims), Δ -> (placeholder .= Δ,) - end - # Make sure sum(f, ::CuArray) uses broadcast through forward-mode defined above # Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible function _pullback(cx::AContext, ::typeof(sum), f, xs::AbstractGPUArray) From 33946f3722391866a6a0dfdd9b997501c45893af Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Mon, 4 Sep 2023 18:20:54 -0700 Subject: [PATCH 2/3] Try removing Fill sum rule too --- src/lib/array.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lib/array.jl b/src/lib/array.jl index 37884cded..8577852ad 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -329,6 +329,7 @@ end end # Reductions +#= @adjoint function sum(xs::AbstractArray; dims = :) if dims === (:) sum(xs), Δ -> (Fill(Δ, size(xs)),) @@ -336,6 +337,7 @@ end sum(xs, dims = dims), Δ -> (similar(xs) .= Δ,) end end +=# @adjoint function sum(xs::AbstractArray{Bool}; dims = :) sum(xs, dims = dims), Δ -> (nothing,) From a32f0394cf955794bc83223c09d5c1a13a9f6214 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Tue, 5 Sep 2023 20:40:33 -0700 Subject: [PATCH 3/3] Remove bool rule too and correct test --- src/lib/array.jl | 13 ------------- test/lib/array.jl | 4 ++-- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 8577852ad..489ee2fab 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -329,19 +329,6 @@ end end # Reductions -#= -@adjoint function sum(xs::AbstractArray; dims = :) - if dims === (:) - sum(xs), Δ -> (Fill(Δ, size(xs)),) - else - sum(xs, dims = dims), Δ -> (similar(xs) .= Δ,) - end -end -=# - -@adjoint function sum(xs::AbstractArray{Bool}; dims = :) - sum(xs, dims = dims), Δ -> (nothing,) -end function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray) return _pullback(cx, (f, xs) -> prod(f.(xs)), f, xs) diff --git a/test/lib/array.jl b/test/lib/array.jl index a3b73aff9..9afe43673 100644 --- a/test/lib/array.jl +++ b/test/lib/array.jl @@ -50,8 +50,8 @@ end @testset "dictionary comprehension" begin d = Dict(1 => 5, 2 => 6) g = gradient(d -> sum([v^2 for (_,v) in d]), d)[1] - @test g isa Dict{Int, Int} - @test g == Dict(1 => 10, 2 => 12) + @test g isa Dict{Int, Float64} + @test g == Dict(1 => 10.0, 2 => 12.0) w = randn(5)