High-level array sharding API for JAX
einshard is a Python library designed to simplify the process of sharding and replicating arrays in JAX. einshard enables integration of various parallelism techniques without modifying the model code. Whether working with simple models like MLPs or larger models like LLMs, einshard provides a solution for distributing computations across multiple devices. This library allows users to define sharding strategies with simple expressions.
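For context, the sketch below shows what sharding and replicating an array looks like in plain JAX, using `jax.sharding` directly. It is not einshard's API. The mesh axis name `data` and the array shape are assumptions chosen for illustration, and the example is written to run on any number of devices, including a single CPU.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D device mesh over all available devices.
# The axis name "data" is an arbitrary label chosen for this sketch.
devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices, axis_names=("data",))

# An array whose first dimension is divisible by the device count.
x = jnp.arange(n * 2 * 4, dtype=jnp.float32).reshape(n * 2, 4)

# Shard the first dimension across the "data" axis; keep the second whole.
sharded = jax.device_put(x, NamedSharding(mesh, PartitionSpec("data", None)))

# Replicate the full array on every device in the mesh.
replicated = jax.device_put(x, NamedSharding(mesh, PartitionSpec()))

# Each device holds a (2, 4) slice of `sharded` and a full copy of `replicated`.
print(sharded.addressable_shards[0].data.shape)
print(replicated.addressable_shards[0].data.shape)
```

Writing a mesh and a partition spec by hand for each array is the kind of boilerplate that a short sharding expression is intended to replace.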
This project originated as part of the Mistral 7B v0.2 JAX project and has since become an independent project.
This project is supported by Cloud TPUs from Google's TPU Research Cloud (TRC).
This library requires at least Python 3.12.
JAX must be installed before einshard. Install it first, choosing the installation method appropriate for your platform.
After JAX is installed, install einshard with this command:
pip install einshard
Please see the detailed documentation at https://einshard.readthedocs.io/en/latest/.
We have a draft PDF file for the theory.