Unable to differentiate a broadcasted constructor with CUDA array inputs #1528

Open
BioTurboNick opened this issue Oct 3, 2024 · 2 comments
Labels
CUDA (All things GPU), help wanted (Extra attention is needed)

Comments

@BioTurboNick
Contributor

This is caused by a CUDA.jl bug, JuliaGPU/CUDA.jl#2514; cross-posted here for reference.

Encountered when differentiating over:

using Distributions
using CUDA
a = cu(ones(5)); b = cu(zeros(5));

Normal.(a, b)
@BioTurboNick
Contributor Author

It seems the CUDA behavior is not likely to change, so it would need to be addressed in Zygote.

The issue occurs because broadcast_forward wraps a function around the broadcasted constructor to dualize it.

@ToucheSir
Member

IIRC this is a fundamental problem with the Type/constructor duality (UnionAlls) in Julia, and would be quite tricky for Zygote to address on its own. The Turing folks ran into a similar issue with broadcasting Distribution constructors, and they solved it by using an intermediate function which doesn't capture the constructor type. e.g.:

make_normal(a, b) = Normal(a, b)
# a = cu(...); b = cu(...)
make_normal.(a, b)
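
To see why the intermediate function helps, a minimal CPU-only sketch can compare what a wrapping closure captures in each case. It assumes a hypothetical Point struct standing in for Normal and a hypothetical wrap_call helper that mimics only the closure-capturing shape of a wrapper, not Zygote's actual code. isbits is used here as a rough proxy for what GPU kernels generally require of their arguments.

```julia
# Hypothetical stand-in for a distribution type; not part of any library.
struct Point{T}
    x::T
    y::T
end

# Plain function that calls the constructor, like make_normal above.
make_point(x, y) = Point(x, y)

# Hypothetical wrapper: returns a closure that captures `f` as a field.
wrap_call(f::F) where {F} = (args...) -> f(args...)

# Capturing a named function: its type is a singleton, so the closure has no data.
println(isbits(wrap_call(make_point)))  # true

# Capturing the constructor: the closure stores a reference to the type object.
println(isbits(wrap_call(Point)))       # false
```

Both closures build the same value on the CPU; only the second carries a non-bits field, which is the kind of capture the workaround avoids.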

As for things to try, one could check whether specializing

@inline function dual_function(f::F) where F
  function (args::Vararg{Any,N}) where N
    ds = dualize(args...)
    return f(ds...)
  end
end

on Type removes the capture, and with it the non-zero-size closure.
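
A rough sketch of what such a specialization could look like, outside of Zygote: the constructor is moved into a type parameter of a field-less callable, so nothing is stored at runtime. Point, TypeCall, and wrap_call are hypothetical names introduced for illustration, and the dualization step is left out entirely. Whether this pattern carries over to dual_function and to GPU compilation is untested here.

```julia
# Hypothetical stand-in for a distribution type; not part of any library.
struct Point{T}
    x::T
    y::T
end

# Field-less callable that carries the constructor in its type parameter.
struct TypeCall{T} end
(::TypeCall{T})(args...) where {T} = T(args...)

# Generic path: a closure that captures `f` as a field.
wrap_call(f::F) where {F} = (args...) -> f(args...)

# Specialization on Type: no closure, so nothing is captured.
wrap_call(::Type{T}) where {T} = TypeCall{T}()

w = wrap_call(Point)
println(isbits(w))    # true
println(sizeof(w))    # 0
println(w(1.0, 2.0))  # a Point{Float64} with x = 1.0, y = 2.0
```

The Type{T} method is more specific than the generic one, so dispatch selects it automatically whenever the argument is a type, and ordinary functions still go through the closure path.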

@ToucheSir added the "help wanted" and "CUDA" labels on Oct 10, 2024