Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Disable more unsafe casting #162

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

willtebbutt
Copy link
Member

@willtebbutt willtebbutt commented May 16, 2024

I discovered more ways that casting can cause problems in seemingly innocuous code. In particular, if you ask Tapir to differentiate

function h(p)
    a = [1,2,3,4,5]
    buf_view = view(a,3:4)
    buf_view[1] = p
    return a[3]*2.0
end

at p = 2.0, the answer is currently zero. I was amazed that this works at all until I realised that Julia will convert p to an Int when it tries to write it to buf_view if it's the case that p happens to be integer valued. If this happens, gradients get dropped and the wrong answer is given.

As part of this PR, I went back over the intrinsics involved in casting to check if there are any more, and I don't believe there are any more which risk causing problems -- hopefully that will prove to be the case.

edit: I've converted this to a WIP because it's going to take a little bit of time to figure out what's going on in all of the cases where fptosi and fptoui are used in actually innocuous code that doesn't result in dropped gradient info. For the most part, it's just going to be declaring things non-differentiable, but there might be a couple of tricky cases.

This may also motivate a more general approach to this in which we have a macro / trait which can be applied to methods of functions which asserts that it's fine to "drop" gradients for all code inside the method, as we're confident that it's not doing it in a way which risks giving the wrong answer, but that they should otherwise be differentiated as usual. Writing the macro / trait / whatever winds up being a convenient approach to this ought really to be straightforward.

Copy link
Contributor

Performance Ratio:

┌────────────────────────────┬────────┬─────────┬─────────────┬─────────┐
│                      Label │  Tapir │  Zygote │ ReverseDiff │  Enzyme │
│                     String │ String │  String │      String │  String │
├────────────────────────────┼────────┼─────────┼─────────────┼─────────┤
│                        sum │   28.9 │   0.327 │        2.25 │   0.632 │
│                       _sum │    6.6 │   494.0 │        28.0 │   0.121 │
│                   kron_sum │   76.7 │    3.14 │       198.0 │    22.6 │
│              kron_view_sum │   90.6 │    11.1 │       221.0 │    7.69 │
│      naive_map_sin_cos_exp │   4.35 │ missing │        8.86 │    2.82 │
│            map_sin_cos_exp │    4.7 │     1.7 │        7.52 │    3.43 │
│      broadcast_sin_cos_exp │    4.4 │    2.62 │        1.68 │    2.87 │
│                 simple_mlp │   9.24 │    3.31 │        11.6 │    3.41 │
│                     gp_lml │   15.9 │    4.22 │     missing │ missing │
│ turing_broadcast_benchmark │   6.28 │ missing │        33.9 │ missing │
└────────────────────────────┴────────┴─────────┴─────────────┴─────────┘

@willtebbutt willtebbutt changed the title Disable more unsafe casting WIP: Disable more unsafe casting May 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant