JAX FP8 matmul (fusion) Jupyter Notebook tutorial. (#133)
In this tutorial notebook, we investigate how the JAX + XLA ML stack handles the specifics of FP8 matmuls while still generating an optimal fused kernel call, including:
* FP8 input scaling;
* FP8 output scaling & clamping;
* Non-linearity & bias fusing;
* Abs-max output capture.

Note: some open questions remain on bias and GeLU fusing (a minimal sketch of the overall pattern follows below).
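
For reference, here is a minimal sketch (not taken from the notebook; the function name, per-tensor scales, and dtypes are illustrative assumptions) of the computation pattern the tutorial investigates: dequantize the scaled FP8 inputs, matmul in higher precision, capture the output abs-max, then scale, clamp, and cast back to FP8.

```python
import jax
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def fp8_matmul(a_fp8, b_fp8, a_scale, b_scale, out_scale):
    # Dequantize the FP8 inputs: upcast and multiply by per-tensor scales.
    a = a_fp8.astype(jnp.bfloat16) * a_scale
    b = b_fp8.astype(jnp.bfloat16) * b_scale
    # Matmul in higher precision.
    out = jnp.dot(a, b)
    # Capture the abs-max of the raw output (delayed-scaling schemes use
    # this statistic to update the output scale for the next step).
    out_amax = jnp.max(jnp.abs(out))
    # Scale, clamp to the FP8 dynamic range, and cast back to FP8.
    out = jnp.clip(out * out_scale, -E4M3_MAX, E4M3_MAX)
    return out.astype(jnp.float8_e4m3fn), out_amax
```

Under `jax.jit`, XLA can pattern-match a graph like this into a single fused FP8 matmul call on hardware with native FP8 support, rather than emitting separate scaling, clamping, and reduction kernels.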
balancap committed Sep 25, 2024
1 parent b9b5c57 commit 4dda0a3
Showing 3 changed files with 752 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -80,6 +80,7 @@ target/

# Jupyter Notebook
.ipynb_checkpoints
+.jupyter_ystore.db

# IPython
profile_default/

0 comments on commit 4dda0a3

Please sign in to comment.