-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introducing partial discharge rules and implementations for cond_p
As things stand you can partially discharge a jaxpr with `discharge_state(should_discharge=[...])` but each equation is discharges *all* its arguments. This means that primitives like `scan_p` and `cond_p` discharge all references they refer to (no pun intended) regardless of whether the user asked for it. We provide a special discharge rule that is preferred to the normal one when present that allows the op to discharge only some of the references. This feature is especially useful for pallas kernels because contrary to all other contexts where jaxprs are expected to eventually be fully discharged, pallas kernels lower references all the way to the runtime as pointers or MLIR memrefs. Here we implement the partial discharge rule for `cond_p` and will implement it for others in due course. PiperOrigin-RevId: 678665716
- Loading branch information
1 parent
b3fca90
commit f8c8630
Showing
3 changed files
with
89 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters