Modify clipping implementation to avoid `jnp.moveaxis`, which causes undesirable all-to-all collectives in distributed environments. #4629
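The idea behind the change can be illustrated with a small sketch. It assumes a per-slice norm clip along one axis; the function names `clip_with_moveaxis` and `clip_without_moveaxis` are hypothetical and are not the code from this PR. The first version moves the target axis to the front before reducing and moves it back afterwards. The second reduces over every other axis in place with `keepdims=True` and lets broadcasting apply the scale, so no transpose is ever requested. Under a sharded layout, avoiding that transpose is what may spare the compiler from resharding the array with an all-to-all.

```python
import jax.numpy as jnp


def clip_with_moveaxis(x, max_norm, axis):
    # Hypothetical "before" version: transpose the clip axis to the front.
    moved = jnp.moveaxis(x, axis, 0)
    flat = moved.reshape(moved.shape[0], -1)
    norms = jnp.linalg.norm(flat, axis=1)
    # Scale factor <= 1; guard against division by zero.
    scale = jnp.minimum(1.0, max_norm / jnp.maximum(norms, 1e-12))
    clipped = moved * scale.reshape((-1,) + (1,) * (moved.ndim - 1))
    # Transpose back to the original layout.
    return jnp.moveaxis(clipped, 0, axis)


def clip_without_moveaxis(x, max_norm, axis):
    # Hypothetical "after" version: reduce over all other axes in place.
    axis = axis % x.ndim
    reduce_axes = tuple(i for i in range(x.ndim) if i != axis)
    # keepdims=True leaves a size-1 dim everywhere except `axis`,
    # so the scale broadcasts against x without any transpose.
    norms = jnp.sqrt(jnp.sum(jnp.square(x), axis=reduce_axes, keepdims=True))
    scale = jnp.minimum(1.0, max_norm / jnp.maximum(norms, 1e-12))
    return x * scale
```

Both functions compute the same result. Whether `jnp.moveaxis` actually triggers an all-to-all depends on how the input is sharded and on what the compiler decides, so treat the second form as a layout-preserving alternative, not a guaranteed fix.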