Unable to differentiate a broadcasted constructor with CUDA array inputs #1528
Comments
It seems the CUDA behavior is not likely to change, so it would need to be addressed in Zygote. The issue occurs because
IIRC this is a fundamental problem with the

    make_normal(a, b) = Normal(a, b)
    # a = cu(...); b = cu(...)
    make_normal.(a, b)

In terms of things to try, one could look at whether specializing Zygote.jl/src/lib/broadcast.jl (lines 273 to 278 in 2517e67) on Type removes the capture and thus the (non-zero size) closure.
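To make the capture point concrete, here is a minimal CPU-only sketch. It is not Zygote's code: `Gaussian`, `wrap_generic`, and `wrap_type` are hypothetical names made up for illustration, and `Gaussian` stands in for `Normal` so that no packages are needed. It compares a closure that captures a constructor as an ordinary variable with one created in a method specialized on `Type`, where the type is a static parameter. The expectation (to be checked by running it) is that the second closure stores nothing.

```julia
# Hypothetical stand-in for Distributions.Normal, so the sketch has no dependencies.
struct Gaussian{T}
    mu::T
    sigma::T
end

# Generic wrapper: `f` is an ordinary captured variable of the closure.
wrap_generic(f) = (args...) -> f(args...)

# Wrapper specialized on `Type`: `T` is a static parameter of the method,
# so the closure can refer to it through its type rather than a stored field.
wrap_type(::Type{T}) where {T} = (args...) -> T(args...)

g1 = wrap_generic(Gaussian)
g2 = wrap_type(Gaussian)

# Both closures build the same value.
@assert g1(0.0, 1.0) == g2(0.0, 1.0) == Gaussian(0.0, 1.0)

# Inspect what each closure stores; GPU kernels require isbits arguments.
for (name, g) in (("generic", g1), ("Type-specialized", g2))
    println(name, ": fields = ", fieldnames(typeof(g)),
            ", sizeof = ", sizeof(g), ", isbits = ", isbits(g))
end
```

If the specialized closure reports no fields and a size of zero, that would support trying the same specialization in the referenced broadcast code. Whether it is enough to fix the GPU case is an open question here.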
Due to a CUDA bug: JuliaGPU/CUDA.jl#2514, crossposted for reference.
Encountered when differentiating over: