Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an error for broadcasting with CUDA + complex numbers, etc #1225

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,13 @@ _broadcast(f::F, x...) where F = materialize(broadcasted(f, x...))
collapse_nothings(xs::AbstractArray{Nothing}) = nothing
collapse_nothings(xs) = xs

_dual_purefun(::Type{F}) where {F<:Function} = Base.issingletontype(F)
_dual_purefun(::Type) = false
_dual_purefun(::Type{F}) where {F} = Base.issingletontype(F)
_dual_purefun(::Type{typeof(^)}) = false # avoid DomainError from negative powers

_dual_safearg(x::Numeric{<:Real}) = true
_dual_safearg(x::Ref{<:Numeric{<:Real}}) = true
_dual_safearg(x::Union{Type,Val,Symbol}) = true # non-differentiable types
_dual_safearg(x) = false
_dual_safearg(x::Union{Val, Symbol, Char, AbstractString}) = true # non-differentiable types
_dual_safearg(x::T) where {T} = Base.issingletontype(T) || Base.issingletontype(eltype(T))

@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
T = Broadcast.combine_eltypes(f, args)
Expand Down Expand Up @@ -226,9 +225,9 @@ end
import ForwardDiff
using ForwardDiff: Dual

dual(x, p) = x
dual(x::Real, p) = Dual(x, p)
dual(x::Bool, p) = x
dual(x::Bool, p) = x # must ignore
dual(x, p) = x # safe to ignore: trust _dual_safearg() elsewhere

function dual_function(f::F) where F
function (args::Vararg{Any,N}) where N
Expand All @@ -239,7 +238,14 @@ function dual_function(f::F) where F
end
end

@inline function broadcast_forward(f, args::Vararg{Any,N}) where N
@inline function broadcast_forward(f::F, args::Vararg{Any,N}) where {F,N}
Base.issingletontype(F) || @warn ("""Zygote's dual number broadcasting (as used on GPU arrays) cannot track gradients with respect to `f`,
and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
typeof(f) = $(F)""") maxlog=1 _id=hash(F)
for a in args
_dual_safearg(a) || error("""Zygote's dual number broadcasting (as used on GPU arrays) cannot handle this argument.
typeof(a) = $(typeof(a))""")
end
valN = Val(N)
out = dual_function(f).(args...)
eltype(out) <: Dual || return (out, _ -> nothing)
Expand Down
13 changes: 13 additions & 0 deletions test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ end
@test gradient(x -> sum(exp.(x)), Diagonal(a_gpu))[1] isa Diagonal
# non-differentiables
@test gradient((x,y) -> sum(x.^2 .+ y'), a_gpu, a_gpu .> 0)[2] === nothing

# Errors -- #1215
y = complex.([4,1]) |> cu
x = complex.([3,2]) |> cu
function f1215(x, y)
x = 2 .* x
return sum(abs2.(x .- y))
end
@test_throws ErrorException gradient(()-> f1215(x,y), Zygote.Params([x]))

# From #1018
@test gradient((x,y) -> sum((z->z^2+y[1]).(x)), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])
@test_skip gradient((x,y) -> sum((z->z^2+y[1]).(x)), cu([1,2,3]), cu([4,5])) # if not right, should ideally be an error
end

@testset "sum(f, x)" begin
Expand Down