jaxdecomp proto #21

Open · wants to merge 92 commits into base: main

Commits (92)
a742065
adding example of distributed solution
EiffL Jul 9, 2024
6408aff
put back old function
EiffL Jul 9, 2024
319942a
update formatting
EiffL Jul 9, 2024
ac86468
add halo exchange and slice pad
ASKabalan Jul 9, 2024
e62cd84
apply formatting
ASKabalan Jul 18, 2024
5775a37
implement distributed optimized cic_paint
ASKabalan Jul 18, 2024
7501b5b
Use new cic_paint with halo
ASKabalan Jul 18, 2024
7f48cfa
Fix seed for distributed normal
ASKabalan Jul 18, 2024
c81d4d2
Wrap interpolation function to avoid all gather
ASKabalan Jul 18, 2024
abde543
Return normal order frequencies for single GPU
ASKabalan Jul 18, 2024
82be568
add example
ASKabalan Jul 18, 2024
4f508b7
format
ASKabalan Jul 18, 2024
5f6d42e
add optimised bench script
ASKabalan Jul 18, 2024
1f6b9c3
times in ms
ASKabalan Jul 18, 2024
ed8cf8e
add lpt2
ASKabalan Jul 18, 2024
0216837
update benchmark and add slurm
ASKabalan Jul 18, 2024
5b7f595
Visualize only final field
ASKabalan Jul 18, 2024
1f20351
Update scripts/distributed_pm.py
ASKabalan Jul 19, 2024
f25eb7d
Adjust pencil type for frequencies
ASKabalan Jul 28, 2024
8c5bd76
fix painting issue with slabs
ASKabalan Aug 2, 2024
75604d2
Shared operations in Fourier space now take inverted sharding axis for
ASKabalan Aug 2, 2024
ccbfee3
add assert to make pyright happy
ASKabalan Aug 2, 2024
aebc3e7
adjust test for hpc-plotter
ASKabalan Aug 2, 2024
9af4659
add PMWD test
ASKabalan Aug 2, 2024
831291c
bench
Aug 2, 2024
ece8c93
format
ASKabalan Aug 2, 2024
783a974
Merge remote-tracking branch 'upstream/main' into ASKabalan/jaxdecomp…
ASKabalan Aug 2, 2024
02754cf
added github workflow
ASKabalan Aug 2, 2024
8da3149
fix formatting from main
ASKabalan Aug 2, 2024
2ea05a1
Update for jaxDecomp pure JAX
Aug 7, 2024
ab86699
merge with JZ
ASKabalan Oct 18, 2024
afecb13
revert single halo extent change
ASKabalan Oct 20, 2024
01b9527
update for latest jaxDecomp
ASKabalan Oct 21, 2024
ff1c5e8
remove fourrier_space in autoshmap
ASKabalan Oct 21, 2024
0ce7219
make normal_field work with single controller
ASKabalan Oct 21, 2024
9c94f99
format
ASKabalan Oct 21, 2024
375f204
make distributed pm work in single controller
ASKabalan Oct 21, 2024
5a587fd
merge bench_pm
ASKabalan Oct 21, 2024
a160a3f
update to leapfrog
ASKabalan Oct 22, 2024
38714cf
add a strict dependency on jaxdecomp
ASKabalan Oct 22, 2024
591ee32
global mesh no longer needed
ASKabalan Oct 22, 2024
a5b267b
kernels.py no longer uses global mesh
ASKabalan Oct 22, 2024
56ffd26
quick fix in distributed
ASKabalan Oct 22, 2024
80c56dc
pm.py no longer uses global mesh
ASKabalan Oct 22, 2024
105568e
painting.py no longer uses global mesh
ASKabalan Oct 22, 2024
4d944f0
update demo script
ASKabalan Oct 22, 2024
a8b194f
quick fix in kernels
ASKabalan Oct 22, 2024
0433c61
quick fix in distributed
ASKabalan Oct 22, 2024
2f50993
update demo
ASKabalan Oct 22, 2024
8623308
merge Hugo's LPT2 code
ASKabalan Oct 22, 2024
82b8f56
format
ASKabalan Oct 22, 2024
85cca44
Merge remote-tracking branch 'upstream/main' into ASKabalan/jaxdecomp…
ASKabalan Oct 22, 2024
d28982e
Small fix
ASKabalan Oct 22, 2024
82f2987
format
ASKabalan Oct 22, 2024
31ca41b
remove duplicate get_ode_fn
ASKabalan Oct 22, 2024
cf799b6
update visualizer
ASKabalan Oct 22, 2024
0bb992f
update compensate CIC
ASKabalan Oct 22, 2024
45b2c7f
By default check_rep is false for shard_map
ASKabalan Oct 22, 2024
505f2ec
remove experimental distributed code
ASKabalan Oct 25, 2024
5d4f438
update PGDCorrection and neural ode to use new fft3d
ASKabalan Oct 25, 2024
8e8e896
jaxDecomp pfft3d promotes to complex automatically
ASKabalan Oct 25, 2024
69c35d1
Merge remote-tracking branch 'upstream/main' into ASKabalan/jaxdecomp…
ASKabalan Oct 25, 2024
0f833f0
remove deprecated stuff
EiffL Oct 24, 2024
d2f1eb2
fix painting issue with read_cic
ASKabalan Oct 26, 2024
ff8856d
use jnp interp instead of jc interp
ASKabalan Oct 26, 2024
0c96a4d
delete old slurms
ASKabalan Oct 26, 2024
49dd18a
add notebook examples
ASKabalan Oct 26, 2024
11f7e90
Merge remote-tracking branch 'upstream/ASKabalan/jaxdecomp_proto' int…
ASKabalan Oct 26, 2024
4342279
apply formatting
ASKabalan Oct 26, 2024
cc4f310
add distributed zeros
ASKabalan Oct 27, 2024
d62c38f
fix code in LPT2
ASKabalan Oct 27, 2024
b4fdb74
jit cic_paint
ASKabalan Oct 27, 2024
c93894f
update notebooks
ASKabalan Oct 27, 2024
19011d0
apply formatting
ASKabalan Oct 27, 2024
a757b62
get local shape and zeros can be used by users
ASKabalan Oct 30, 2024
f3b431a
add a user facing function to create uniform particle grid
ASKabalan Oct 30, 2024
2ad035a
use jax interp instead of jax_cosmo
ASKabalan Oct 30, 2024
b3a264a
use float64 for enmeshing
ASKabalan Oct 30, 2024
b09580d
Allow applying weights with relative cic paint
ASKabalan Oct 30, 2024
e9529d3
Weights can be traced
ASKabalan Oct 30, 2024
4da4c66
remove script folder
ASKabalan Oct 30, 2024
72457d6
update example notebooks
ASKabalan Oct 30, 2024
a030ec4
delete outdated design file
ASKabalan Oct 30, 2024
f0c43f8
add readme for tutorials
ASKabalan Oct 30, 2024
a067954
update readme
ASKabalan Oct 30, 2024
42d8e89
fix small error
ASKabalan Oct 30, 2024
6256fba
forgot particles in multi host
ASKabalan Oct 30, 2024
2472a5d
clarifying why cic_paint_dx is slower
ASKabalan Nov 10, 2024
ad45666
clarifying the halo size dependence on the box size
ASKabalan Nov 10, 2024
12c74e2
ability to choose snapshots number with MultiHost script
ASKabalan Nov 10, 2024
0946842
Adding animation notebook
ASKabalan Nov 10, 2024
435c7c8
Put plotting in package
ASKabalan Nov 14, 2024
21 changes: 21 additions & 0 deletions .github/workflows/formatting.yml
@@ -0,0 +1,21 @@
name: Code Formatting

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v3
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip isort
          python -m pip install pre-commit
      - name: Run pre-commit
        run: python -m pre_commit run --all-files
5 changes: 5 additions & 0 deletions .gitignore
@@ -98,6 +98,11 @@ __pypackages__/
celerybeat-schedule
celerybeat.pid


out
traces
*.npy
*.out
# SageMath parsed files
*.sage.py

2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -14,4 +14,4 @@ repos:
    rev: 5.13.2
    hooks:
      - id: isort
        name: isort (python)
        name: isort (python)
9 changes: 6 additions & 3 deletions README.md
@@ -4,16 +4,15 @@
<!-- ALL-CONTRIBUTORS-BADGE:END -->
JAX-powered Cosmological Particle-Mesh N-body Solver

**This project is currently in an early design phase. All inputs are welcome on the [design document](https://github.com/DifferentiableUniverseInitiative/JaxPM/blob/main/design.md)**

## Goals

Provide a modern infrastructure to support differentiable PM N-body simulations using JAX:
- Keep implementation simple and readable, in pure NumPy API
- Transparent distribution using builtin `xmap`
- Any order forward and backward automatic differentiation
- Support automated batching using `vmap`
- Compatibility with external optimizer libraries like `optax`
- Now fully distributable on **multi-GPU and multi-node** systems using [jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp) working with the latex `JAX v0.4.35`

Suggested change
- Now fully distributable on **multi-GPU and multi-node** systems using [jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp) working with the latex `JAX v0.4.35`
- Now fully distributable on **multi-GPU and multi-node** systems using [jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp) working with `JAX v0.4.35`


This line should not be in the Goals section, it's a feature now.



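To make the Goals above concrete, here is a minimal single-device sketch built from the functions exercised by `benchmarks/bench_pm.py` later in this diff (`interpolate_power_spectrum`, `linear_field`, `lpt`, `cic_paint_dx`); the signatures are taken from that script, and the default `halo_size` is an assumption.

```python
import jax
import jax.numpy as jnp
import jax_cosmo as jc

from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt

mesh_shape = [64, 64, 64]
box_size = [64.0, 64.0, 64.0]


def lpt_field(omega_c, sigma8):
    # Interpolate a linear matter power spectrum onto the simulation modes
    k = jnp.logspace(-4, 1, 128)
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
    pk = jc.power.linear_matter_power(cosmo, k)
    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk)

    # Gaussian initial conditions and first-order LPT displacement at a = 0.1
    initial_conditions = linear_field(mesh_shape, box_size, pk_fn,
                                      seed=jax.random.PRNGKey(0))
    dx, p, _ = lpt(cosmo, initial_conditions, 0.1)
    return cic_paint_dx(dx)


# jit, vmap and any-order autodiff come for free from JAX
field = jax.jit(lpt_field)(0.25, 0.8)
dfield_dsigma8 = jax.grad(lambda s8: lpt_field(0.25, s8).sum())(0.8)
```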
## Open development and use

@@ -23,6 +22,10 @@ Current expectations are:
- Everyone is welcome to contribute, and can join the JOSS publication (until it is submitted to the journal).
- Anyone (including main contributors) can use this code as a framework to build and publish their own applications, with no expectation that they *need* to extend authorship to all jaxpm developers.

## Getting Started

To dive into JaxPM’s capabilities, please explore the **notebook section** for detailed tutorials and examples on various setups, from single-device simulations to multi-host configurations. You can find the notebooks' [README here](notebooks/README.md) for a structured guide through each tutorial.


I would put the link to the README around notebook section.


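For the multi-host configurations mentioned above, the basic setup used by `benchmarks/bench_pm.py` in this PR looks roughly like the sketch below; the `('x', 'y')` axis names and the `pdims` layout are copied from that script and may differ in the notebooks.

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# One process per GPU (e.g. launched through SLURM); JAX discovers the topology
jax.distributed.initialize()

pdims = (1, jax.device_count())  # processor grid handed down to jaxDecomp
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y'))

with mesh:
    # distributed JaxPM calls (linear_field, lpt, cic_paint_dx, ...) go here
    pass
```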

## Contributors ✨

270 changes: 270 additions & 0 deletions benchmarks/bench_pm.py
@@ -0,0 +1,270 @@
import os

os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
import jax

jax.distributed.initialize()

rank = jax.process_index()
size = jax.process_count()

import argparse
import time

import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from cupy.cuda.nvtx import RangePop, RangePush
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, Tsit5, diffeqsolve)
from hpc_plotter.timer import Timer
from jax.experimental import mesh_utils
from jax.experimental.multihost_utils import sync_global_devices
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn


def run_simulation(mesh_shape,
                   box_size,
                   halo_size,
                   solver_choice,
                   iterations,
                   pdims=None,
                   output_path="."):

    @jax.jit
    def simulate(omega_c, sigma8):
        # Create a small function to generate the matter power spectrum
        k = jnp.logspace(-4, 1, 128)
        pk = jc.power.linear_matter_power(
            jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
        pk_fn = lambda x: interpolate_power_spectrum(x, k, pk)

        # Create initial conditions
        initial_conditions = linear_field(mesh_shape,
                                          box_size,
                                          pk_fn,
                                          seed=jax.random.PRNGKey(0))

        # Create particles
        cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
        dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
        if solver_choice == "Dopri5":
            solver = Dopri5()
        elif solver_choice == "LeapfrogMidpoint":
            solver = LeapfrogMidpoint()
        elif solver_choice == "Tsit5":
            solver = Tsit5()
        elif solver_choice == "lpt":
            lpt_field = cic_paint_dx(dx, halo_size=halo_size)
            return lpt_field, {"num_steps": 0}
        else:
            raise ValueError(
                "Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt'.")
        # Evolve the simulation forward
        ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
        term = ODETerm(
            lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))

        if solver_choice == "Dopri5" or solver_choice == "Tsit5":
            stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
        elif solver_choice == "LeapfrogMidpoint" or solver_choice == "Euler":
            stepsize_controller = ConstantStepSize()
        res = diffeqsolve(term,
                          solver,
                          t0=0.1,
                          t1=1.,
                          dt0=0.01,
                          y0=jnp.stack([dx, p], axis=0),
                          args=cosmo,
                          saveat=SaveAt(t1=True),
                          stepsize_controller=stepsize_controller)

        # Return the simulation volume at the requested time
        state = res.ys[-1]
        final_field = cic_paint_dx(state[0], halo_size=halo_size)

        return final_field, res.stats

    def run():
        # Warm start
        chrono_fun = Timer()
        RangePush("warmup")
        final_field, stats = chrono_fun.chrono_jit(simulate,
                                                   0.32,
                                                   0.8,
                                                   ndarray_arg=0)
        RangePop()
        sync_global_devices("warmup")
        for i in range(iterations):
            RangePush(f"sim iter {i}")
            final_field, stats = chrono_fun.chrono_fun(simulate,
                                                       0.32,
                                                       0.8,
                                                       ndarray_arg=0)
            RangePop()
        return final_field, stats, chrono_fun

    if jax.device_count() > 1:
        devices = mesh_utils.create_device_mesh(pdims)
        mesh = Mesh(devices.T, axis_names=('x', 'y'))
        with mesh:
            # Warm start
            final_field, stats, chrono_fun = run()
    else:
        final_field, stats, chrono_fun = run()

    return final_field, stats, chrono_fun


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description='JAX Cosmo Simulation Benchmark')
    parser.add_argument('-m',
                        '--mesh_size',
                        type=int,
                        help='Mesh size',
                        required=True)
    parser.add_argument('-b',
                        '--box_size',
                        type=float,
                        help='Box size',
                        required=True)
    parser.add_argument('-p',
                        '--pdims',
                        type=str,
                        help='Processor dimensions',
                        default=None)
    parser.add_argument(
        '-pr',
        '--precision',
        type=str,
        help='Precision',
        choices=["float32", "float64"],
    )
    parser.add_argument('-hs',
                        '--halo_size',
                        type=int,
                        help='Halo size',
                        default=None)
    parser.add_argument('-s',
                        '--solver',
                        type=str,
                        help='Solver',
                        choices=[
                            "Dopri5", "dopri5", "d5", "Tsit5", "tsit5", "t5",
                            "LeapfrogMidpoint", "leapfrogmidpoint", "lfm",
                            "lpt"
                        ],
                        default="lpt")
    parser.add_argument('-i',
                        '--iterations',
                        type=int,
                        help='Number of timed iterations',
                        default=10)
    parser.add_argument('-o',
                        '--output_path',
                        type=str,
                        help='Output path',
                        default=".")
    parser.add_argument('-f',
                        '--save_fields',
                        action='store_true',
                        help='Save fields')
    parser.add_argument('-n',
                        '--nodes',
                        type=int,
                        help='Number of nodes',
                        default=1)
    args = parser.parse_args()
    mesh_size = args.mesh_size
    box_size = [args.box_size] * 3
    halo_size = args.mesh_size // 8 if args.halo_size is None else args.halo_size
    solver_choice = args.solver
    iterations = args.iterations
    output_path = args.output_path
    os.makedirs(output_path, exist_ok=True)

    print(f"solver choice: {solver_choice}")
    match solver_choice:
        case "Dopri5" | "dopri5" | "d5":
            solver_choice = "Dopri5"
        case "Tsit5" | "tsit5" | "t5":
            solver_choice = "Tsit5"
        case "LeapfrogMidpoint" | "leapfrogmidpoint" | "lfm":
            solver_choice = "LeapfrogMidpoint"
        case "lpt":
            solver_choice = "lpt"
        case _:
            raise ValueError(
                "Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt'."
            )
    if args.precision == "float32":
        jax.config.update("jax_enable_x64", False)
    elif args.precision == "float64":
        jax.config.update("jax_enable_x64", True)

    if args.pdims:
        pdims = tuple(map(int, args.pdims.split("x")))
    else:
        pdims = (1, jax.device_count())
    pdm_str = f"{pdims[0]}x{pdims[1]}"

    mesh_shape = [mesh_size] * 3

    final_field, stats, chrono_fun = run_simulation(mesh_shape,
                                                    box_size,
                                                    halo_size,
                                                    solver_choice,
                                                    iterations,
                                                    pdims=pdims)

    print(
        f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}"
    )

    metadata = {
        'rank': rank,
        'function_name': f'JAXPM-{solver_choice}',
        'precision': args.precision,
        'x': str(mesh_size),
        'y': str(mesh_size),
        'z': str(stats["num_steps"]),
        'px': str(pdims[0]),
        'py': str(pdims[1]),
        'backend': 'NCCL',
        'nodes': str(args.nodes)
    }
    # Print the results to a CSV file
    chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata)

    # Save the final field
    nb_gpus = jax.device_count()
    pdm_str = f"{pdims[0]}x{pdims[1]}"
    field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}"
    os.makedirs(field_folder, exist_ok=True)
    with open(f'{field_folder}/jaxpm.log', 'w') as f:
        f.write(f"Args: {args}\n")
        f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
        for i, time in enumerate(chrono_fun.times):
            f.write(f"Time {i}: {time:.4f} ms\n")
        f.write(f"Stats: {stats}\n")
    if args.save_fields:
        np.save(f'{field_folder}/final_field_0_{rank}.npy',
                final_field.addressable_data(0))


    print("Finished!")
    print(f"Stats {stats}")
    print(f"Saving to {output_path}/jaxpm_benchmark.csv")
    print(f"Saving field and logs in {field_folder}")
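As a rough usage sketch (assuming the benchmark was run with `--save_fields` and the default `--output_path .`), the per-process field slices and the timing CSV written above can be inspected afterwards like this; the directory layout simply mirrors the `field_folder` f-string in the script.

```python
import glob

import numpy as np

# {output_path}/final_field/jaxpm/{gpus}/{mesh}_{box}/{pdims}/{solver}/halo_{halo}/
field_files = sorted(
    glob.glob("final_field/**/final_field_0_*.npy", recursive=True))
slices = [np.load(f) for f in field_files]
print(f"loaded {len(slices)} per-process slices: {[s.shape for s in slices]}")

# Raw timing rows appended by hpc_plotter's Timer.print_to_csv
with open("jaxpm_benchmark.csv") as f:
    print(f.read())
```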