JAX FP8 matmul (fusion) Jupyter Notebook tutorial.
In this tutorial notebook, we investigate how the JAX + XLA ML stack handles the specifics of FP8 matmuls
while still generating a single, optimally fused kernel call, including:
* FP8 inputs scaling;
* FP8 output scaling & clamping;
* Non-linearity & bias fusing;
* Abs-max output capture.

Note: some open questions remain on bias and GELU fusing (see the sketch below).
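
As a rough illustration of the pattern the notebook studies, here is a minimal JAX sketch of a scaled FP8 matmul with fused bias, GELU, abs-max capture, and output rescaling/clamping. The function and parameter names (fp8_matmul, a_scale, out_scale, ...) are illustrative assumptions, not the notebook's actual code.

import jax
import jax.numpy as jnp

E4M3_MAX = 448.0  # Largest finite value representable in float8_e4m3fn.

def fp8_matmul(a_fp8, b_fp8, a_scale, b_scale, out_scale, bias):
    # Dequantize the FP8 inputs with their per-tensor scales.
    a = a_fp8.astype(jnp.bfloat16) * a_scale
    b = b_fp8.astype(jnp.bfloat16) * b_scale
    # Matmul with bias and non-linearity: candidates for fusion.
    out = jax.nn.gelu(a @ b + bias)
    # Capture the abs-max of the output (used by delayed-scaling schemes
    # to update the output scale for the next step).
    out_absmax = jnp.max(jnp.abs(out))
    # Rescale, clamp to the FP8 representable range, and cast back to FP8.
    out = jnp.clip(out / out_scale, -E4M3_MAX, E4M3_MAX)
    return out.astype(jnp.float8_e4m3fn), out_absmax

Under jax.jit, XLA can in principle pattern-match this dequantize -> matmul -> requantize chain into a single fused FP8 GEMM kernel call on supporting hardware; whether the bias and GELU actually end up fused is one of the open questions noted above.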
balancap committed Sep 25, 2024
1 parent b9b5c57 commit a203df7
Showing 3 changed files with 752 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -80,6 +80,7 @@ target/
 
 # Jupyter Notebook
 .ipynb_checkpoints
+.jupyter_ystore.db
 
 # IPython
 profile_default/
