-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introducing partial discharge rules and implementations for cond_p an…
…d for_p As things stand you can partially discharge a jaxpr with `discharge_state(should_discharge=[...])` but each equation is discharges *all* its arguments. This means that primitives like `scan_p` and `cond_p` discharge all references they refer to (no pun intended) regardless of whether the user asked for it. We provide a special discharge rule that is preferred to the normal one when present that allows the op to discharge only some of the references. This feature is especially useful for pallas kernels because contrary to all other contexts where jaxprs are expected to eventually be fully discharged, pallas kernels lower references all the way to the runtime as pointers or MLIR memrefs. PiperOrigin-RevId: 678665716
- Loading branch information
1 parent
5cef547
commit 961680c
Showing
4 changed files
with
130 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters