Skip to content

Latest commit

 

History

History
7 lines (4 loc) · 614 Bytes

210510 GSPMD.md

File metadata and controls

7 lines (4 loc) · 614 Bytes

https://arxiv.org/abs/2105.04663

GSPMD: General and Scalable Parallelization for ML Computation Graphs (Yuanzhong Xu, HyoukJoong Lee, Dehao Chen, Blake Hechtman, Yanping Huang, Rahul Joshi, Maxim Krikun, Dmitry Lepikhin, Andy Ly, Marcello Maggioni, Ruoming Pang, Noam Shazeer, Shibo Wang, Tao Wang, Yonghui Wu, Zhifeng Chen)

모델을 알아서 쪼개서 분산시켜주는 컴파일러. https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html jax의 인터페이스를 참고해볼만 하네요. 이제는 jax를 여러모로 고려해볼 필요가 있지 않을까 싶습니다.

#distributed_training