-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding fwd/bwd cast methods compatible with FP8.
Allowing cast applying only on forward or backward passes respectively. Making it easier to build explicit FP8 code.
- Loading branch information
Showing
8 changed files
with
136 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Copyright (c) 2024 Graphcore Ltd. All rights reserved. | ||
from functools import partial | ||
|
||
import jax | ||
|
||
|
||
@partial(jax.custom_vjp, nondiff_argnums=(0,)) | ||
def map_on_forward(f, arg): | ||
"""Map a function on a forward pass only. No-op/identity on backward pass.""" | ||
return f(arg) | ||
|
||
|
||
def map_on_forward_fwd(f, arg): | ||
return arg, None | ||
|
||
|
||
def map_on_forward_bwd(f, _, grad): | ||
return (grad,) | ||
|
||
|
||
map_on_forward.defvjp(map_on_forward_fwd, map_on_forward_bwd) | ||
|
||
|
||
@partial(jax.custom_vjp, nondiff_argnums=(0,)) | ||
def map_on_backward(f, arg): | ||
"""Map a function on the gradient/backward pass. No-op/identity on forward.""" | ||
return arg | ||
|
||
|
||
def map_on_backward_fwd(f, arg): | ||
return arg, None | ||
|
||
|
||
def map_on_backward_bwd(f, _, grad): | ||
return (f(grad),) | ||
|
||
|
||
map_on_backward.defvjp(map_on_backward_fwd, map_on_backward_bwd) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters