# jax_demo

This repo is for people who already have deep learning experience with PyTorch (`torch`) and want to learn JAX, a less widely known but powerful library.

## Requirements

- Python 3.10+
- pip
- JAX: install with `pip install jax`, or with `pip install -U "jax[cuda12]"` if you are using a CUDA GPU
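After installing, a quick check confirms that JAX imports and can see your hardware. The sketch below also contrasts JAX with PyTorch on gradients. In PyTorch you would call `loss.backward()` and read `w.grad`. In JAX, `jax.grad` turns a function into a new function that returns its gradient, and `jax.jit` compiles it with XLA. The toy loss, the variable names, and the data are hypothetical and serve only as illustration.

```python
import jax
import jax.numpy as jnp

# List the devices JAX can use; a working CUDA install should show GPU devices.
print(jax.devices())

def loss(w, x, y):
    # Hypothetical toy example: mean squared error of a linear model x @ w.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates with respect to the first argument (w) by default;
# jax.jit compiles the resulting gradient function.
grad_loss = jax.jit(jax.grad(loss))

w = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # identity matrix, so pred == w
y = jnp.array([0.0, 0.0])

g = grad_loss(w, x, y)  # (2 / n) * x.T @ (pred - y) = [1.0, 2.0] here
print(loss(w, x, y), g)
```

Unlike PyTorch, nothing is accumulated in a `.grad` attribute: the gradient is an ordinary return value, which is why JAX code is usually written as pure functions of explicit parameters.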