jaxdecomp proto #21
Open: ASKabalan wants to merge 92 commits into main from ASKabalan/jaxdecomp_proto
Commits
92 commits
a742065 adding example of distributed solution (EiffL)
6408aff put back old functgion (EiffL)
319942a update formatting (EiffL)
ac86468 add halo exchange and slice pad (ASKabalan)
e62cd84 apply formatting (ASKabalan)
5775a37 implement distributed optimized cic_paint (ASKabalan)
7501b5b Use new cic_paint with halo (ASKabalan)
7f48cfa Fix seed for distributed normal (ASKabalan)
c81d4d2 Wrap interpolation function to avoid all gather (ASKabalan)
abde543 Return normal order frequencies for single GPU (ASKabalan)
82be568 add example (ASKabalan)
4f508b7 format (ASKabalan)
5f6d42e add optimised bench script (ASKabalan)
1f6b9c3 times in ms (ASKabalan)
ed8cf8e add lpt2 (ASKabalan)
0216837 update benchmark and add slurm (ASKabalan)
5b7f595 Visualize only final field (ASKabalan)
1f20351 Update scripts/distributed_pm.py (ASKabalan)
f25eb7d Adjust pencil type for frequencies (ASKabalan)
8c5bd76 fix painting issue with slabs (ASKabalan)
75604d2 Shared operation in fourrier space now take inverted sharding axis for (ASKabalan)
ccbfee3 add assert to make pyright happy (ASKabalan)
aebc3e7 adjust test for hpc-plotter (ASKabalan)
9af4659 add PMWD test (ASKabalan)
831291c bench
ece8c93 format (ASKabalan)
783a974 Merge remote-tracking branch 'upstream/main' into ASKabalan/jaxdecomp… (ASKabalan)
02754cf added github workflow (ASKabalan)
8da3149 fix formatting from main (ASKabalan)
2ea05a1 Update for jaxDecomp pure JAX
ab86699 merge with JZ (ASKabalan)
afecb13 revert single halo extent change (ASKabalan)
01b9527 update for latest jaxDecomp (ASKabalan)
ff1c5e8 remove fourrier_space in autoshmap (ASKabalan)
0ce7219 make normal_field work with single controller (ASKabalan)
9c94f99 format (ASKabalan)
375f204 make distributed pm work in single controller (ASKabalan)
5a587fd merge bench_pm (ASKabalan)
a160a3f update to leapfrog (ASKabalan)
38714cf add a strict dependency on jaxdecomp (ASKabalan)
591ee32 global mesh no longer needed (ASKabalan)
a5b267b kernels.py no longer uses global mesh (ASKabalan)
56ffd26 quick fix in distributed (ASKabalan)
80c56dc pm.py no longer uses global mesh (ASKabalan)
105568e painting.py no longer uses global mesh (ASKabalan)
4d944f0 update demo script (ASKabalan)
a8b194f quick fix in kernels (ASKabalan)
0433c61 quick fix in distributed (ASKabalan)
2f50993 update demo (ASKabalan)
8623308 merge hugos LPT2 code (ASKabalan)
82b8f56 format (ASKabalan)
85cca44 Merge remote-tracking branch 'upstream/main' into ASKabalan/jaxdecomp… (ASKabalan)
d28982e Small fix (ASKabalan)
82f2987 format (ASKabalan)
31ca41b remove duplicate get_ode_fn (ASKabalan)
cf799b6 update visualizer (ASKabalan)
0bb992f update compensate CIC (ASKabalan)
45b2c7f By default check_rep is false for shard_map (ASKabalan)
505f2ec remove experimental distributed code (ASKabalan)
5d4f438 update PGDCorrection and neural ode to use new fft3d (ASKabalan)
8e8e896 jaxDecomp pfft3d promotes to complex automatically (ASKabalan)
69c35d1 Merge remote-tracking branch 'upstream/main' into ASKabalan/jaxdecomp… (ASKabalan)
0f833f0 remove deprecated stuff (EiffL)
d2f1eb2 fix painting issue with read_cic (ASKabalan)
ff8856d use jnp interp instead of jc interp (ASKabalan)
0c96a4d delete old slurms (ASKabalan)
49dd18a add notebook examples (ASKabalan)
11f7e90 Merge remote-tracking branch 'upstream/ASKabalan/jaxdecomp_proto' int… (ASKabalan)
4342279 apply formatting (ASKabalan)
cc4f310 add distributed zeros (ASKabalan)
d62c38f fix code in LPT2 (ASKabalan)
b4fdb74 jit cic_paint (ASKabalan)
c93894f update notebooks (ASKabalan)
19011d0 apply formating (ASKabalan)
a757b62 get local shape and zeros can be used by users (ASKabalan)
f3b431a add a user facing function to create uniform particle grid (ASKabalan)
2ad035a use jax interp instead of jax_cosmo (ASKabalan)
b3a264a use float64 for enmeshing (ASKabalan)
b09580d Allow applying weights with relative cic paint (ASKabalan)
e9529d3 Weights can be traced (ASKabalan)
4da4c66 remove script folder (ASKabalan)
72457d6 update example notebooks (ASKabalan)
a030ec4 delete outdated design file (ASKabalan)
f0c43f8 add readme for tutorials (ASKabalan)
a067954 update readme (ASKabalan)
42d8e89 fix small error (ASKabalan)
6256fba forgot particles in multi host (ASKabalan)
2472a5d clarifying why cic_paint_dx is slower (ASKabalan)
ad45666 clarifying the halo size dependence on the box size (ASKabalan)
12c74e2 ability to choose snapshots number with MultiHost script (ASKabalan)
0946842 Adding animation notebook (ASKabalan)
435c7c8 Put plotting in package (ASKabalan)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
name: Code Formatting | ||
|
||
on: | ||
push: | ||
branches: [ "main" ] | ||
pull_request: | ||
branches: [ "main" ] | ||
|
||
jobs: | ||
build: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Set up Python | ||
uses: actions/setup-python@v3 | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip isort | ||
python -m pip install pre-commit | ||
- name: Run pre-commit | ||
run: python -m pre_commit run --all-files |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,4 +14,4 @@ repos: | |
rev: 5.13.2 | ||
hooks: | ||
- id: isort | ||
name: isort (python) | ||
name: isort (python) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,16 +4,15 @@ | |
<!-- ALL-CONTRIBUTORS-BADGE:END --> | ||
JAX-powered Cosmological Particle-Mesh N-body Solver | ||
|
||
**This project is currently in an early design phase. All inputs are welcome on the [design document](https://github.com/DifferentiableUniverseInitiative/JaxPM/blob/main/design.md)** | ||
|
||
## Goals | ||
|
||
Provide a modern infrastructure to support differentiable PM N-body simulations using JAX: | ||
- Keep implementation simple and readable, in pure NumPy API | ||
- Transparent distribution using builtin `xmap` | ||
- Any order forward and backward automatic differentiation | ||
- Support automated batching using `vmap` | ||
- Compatibility with external optimizer libraries like `optax` | ||
- Now fully distributable on **multi-GPU and multi-node** systems using [jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp), working with the latest `JAX v0.4.35` | ||
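The differentiation and batching goals listed above can be illustrated on a toy function. This is a minimal sketch rather than JaxPM code: `toy_loss` is a hypothetical stand-in for a simulation summary statistic, and only core `jax.grad` and `jax.vmap` calls are used.

```python
import jax
import jax.numpy as jnp


def toy_loss(amplitude, positions):
    # Hypothetical stand-in for a simulation output: a smooth scalar
    # function of one parameter (not part of JaxPM).
    field = amplitude * jnp.sin(positions)
    return jnp.sum(field**2)


positions = jnp.linspace(0.0, 1.0, 8)

# Reverse-mode gradient with respect to the first argument (amplitude).
grad_fn = jax.grad(toy_loss)

# Higher-order derivatives are obtained by composing grad with itself.
curvature_fn = jax.grad(grad_fn)

# vmap batches the gradient over many amplitudes; positions are shared.
batched_grad = jax.vmap(grad_fn, in_axes=(0, None))

amplitudes = jnp.array([0.5, 1.0, 2.0])
grads = batched_grad(amplitudes, positions)
curvature = curvature_fn(1.0, positions)
```

Because `toy_loss` is quadratic in `amplitude`, each gradient equals `2 * amplitude * sum(sin(positions)**2)`, which gives a simple check on the batched result.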
|
||
|
||
## Open development and use | ||
|
||
|
@@ -23,6 +22,10 @@ Current expectations are: | |
- Everyone is welcome to contribute, and can join the JOSS publication (until it is submitted to the journal). | ||
- Anyone (including main contributors) can use this code as a framework to build and publish their own applications, with no expectation that they *need* to extend authorship to all jaxpm developers. | ||
|
||
## Getting Started | ||
|
||
To dive into JaxPM’s capabilities, please explore the **notebook section** for detailed tutorials and examples on various setups, from single-device simulations to multi-host configurations. You can find the notebooks' [README here](notebooks/README.md) for a structured guide through each tutorial. | ||
I would put the link to the README around |
||
|
||
|
||
## Contributors ✨ | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,270 @@ | ||
import os | ||
|
||
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax | ||
import jax | ||
|
||
jax.distributed.initialize() | ||
|
||
rank = jax.process_index() | ||
size = jax.process_count() | ||
|
||
import argparse | ||
import time | ||
|
||
import jax.numpy as jnp | ||
import jax_cosmo as jc | ||
import numpy as np | ||
from cupy.cuda.nvtx import RangePop, RangePush | ||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, | ||
PIDController, SaveAt, Tsit5, diffeqsolve) | ||
from hpc_plotter.timer import Timer | ||
from jax.experimental import mesh_utils | ||
from jax.experimental.multihost_utils import sync_global_devices | ||
from jax.sharding import Mesh, NamedSharding | ||
from jax.sharding import PartitionSpec as P | ||
|
||
from jaxpm.kernels import interpolate_power_spectrum | ||
from jaxpm.painting import cic_paint_dx | ||
from jaxpm.pm import linear_field, lpt, make_ode_fn | ||
|
||
|
||
def run_simulation(mesh_shape, | ||
box_size, | ||
halo_size, | ||
solver_choice, | ||
iterations, | ||
hlo_print, | ||
trace, | ||
pdims=None, | ||
output_path="."): | ||
|
||
@jax.jit | ||
def simulate(omega_c, sigma8): | ||
# Create a small function to generate the matter power spectrum | ||
k = jnp.logspace(-4, 1, 128) | ||
pk = jc.power.linear_matter_power( | ||
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) | ||
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk) | ||
|
||
# Create initial conditions | ||
initial_conditions = linear_field(mesh_shape, | ||
box_size, | ||
pk_fn, | ||
seed=jax.random.PRNGKey(0)) | ||
|
||
# Create particles | ||
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) | ||
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size) | ||
if solver_choice == "Dopri5": | ||
solver = Dopri5() | ||
elif solver_choice == "LeapfrogMidpoint": | ||
solver = LeapfrogMidpoint() | ||
elif solver_choice == "Tsit5": | ||
solver = Tsit5() | ||
elif solver_choice == "lpt": | ||
lpt_field = cic_paint_dx(dx, halo_size=halo_size) | ||
return lpt_field, {"num_steps": 0} | ||
else: | ||
raise ValueError( | ||
"Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt'.") | ||
# Evolve the simulation forward | ||
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size) | ||
term = ODETerm( | ||
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0)) | ||
|
||
if solver_choice == "Dopri5" or solver_choice == "Tsit5": | ||
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4) | ||
elif solver_choice == "LeapfrogMidpoint" or solver_choice == "Euler": | ||
stepsize_controller = ConstantStepSize() | ||
res = diffeqsolve(term, | ||
solver, | ||
t0=0.1, | ||
t1=1., | ||
dt0=0.01, | ||
y0=jnp.stack([dx, p], axis=0), | ||
args=cosmo, | ||
saveat=SaveAt(t1=True), | ||
stepsize_controller=stepsize_controller) | ||
|
||
# Return the simulation volume at the final time t1 | ||
state = res.ys[-1] | ||
final_field = cic_paint_dx(state[0], halo_size=halo_size) | ||
|
||
return final_field, res.stats | ||
|
||
def run(): | ||
# Warm start | ||
chrono_fun = Timer() | ||
RangePush("warmup") | ||
final_field, stats = chrono_fun.chrono_jit(simulate, | ||
0.32, | ||
0.8, | ||
ndarray_arg=0) | ||
RangePop() | ||
sync_global_devices("warmup") | ||
for i in range(iterations): | ||
RangePush(f"sim iter {i}") | ||
final_field, stats = chrono_fun.chrono_fun(simulate, | ||
0.32, | ||
0.8, | ||
ndarray_arg=0) | ||
RangePop() | ||
return final_field, stats, chrono_fun | ||
|
||
if jax.device_count() > 1: | ||
devices = mesh_utils.create_device_mesh(pdims) | ||
mesh = Mesh(devices.T, axis_names=('x', 'y')) | ||
with mesh: | ||
# Warm start | ||
final_field, stats, chrono_fun = run() | ||
else: | ||
final_field, stats, chrono_fun = run() | ||
|
||
return final_field, stats, chrono_fun | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
parser = argparse.ArgumentParser( | ||
description='JAX Cosmo Simulation Benchmark') | ||
parser.add_argument('-m', | ||
'--mesh_size', | ||
type=int, | ||
help='Mesh size', | ||
required=True) | ||
parser.add_argument('-b', | ||
'--box_size', | ||
type=float, | ||
help='Box size', | ||
required=True) | ||
parser.add_argument('-p', | ||
'--pdims', | ||
type=str, | ||
help='Processor dimensions', | ||
default=None) | ||
parser.add_argument( | ||
'-pr', | ||
'--precision', | ||
type=str, | ||
help='Precision', | ||
choices=["float32", "float64"], | ||
) | ||
parser.add_argument('-hs', | ||
'--halo_size', | ||
type=int, | ||
help='Halo size', | ||
default=None) | ||
parser.add_argument('-s', | ||
'--solver', | ||
type=str, | ||
help='Solver', | ||
choices=[ | ||
"Dopri5", "dopri5", "d5", "Tsit5", "tsit5", "t5", | ||
"LeapfrogMidpoint", "leapfrogmidpoint", "lfm", | ||
"lpt" | ||
], | ||
default="lpt") | ||
parser.add_argument('-o', | ||
'--output_path', | ||
type=str, | ||
help='Output path', | ||
default=".") | ||
parser.add_argument('-f', | ||
'--save_fields', | ||
action='store_true', | ||
help='Save fields') | ||
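parser.add_argument('-n', | ||
'--nodes', | ||
type=int, | ||
help='Number of nodes', | ||
default=1) | ||
# args.iterations is read below, so the flag must exist; the default is an assumed value | ||
parser.add_argument('-i', '--iterations', type=int, default=10, | ||
help='Number of timed iterations') | ||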
args = parser.parse_args() | ||
mesh_size = args.mesh_size | ||
box_size = [args.box_size] * 3 | ||
halo_size = args.mesh_size // 8 if args.halo_size is None else args.halo_size | ||
solver_choice = args.solver | ||
iterations = args.iterations | ||
output_path = args.output_path | ||
os.makedirs(output_path, exist_ok=True) | ||
|
||
print(f"solver choice: {solver_choice}") | ||
match solver_choice: | ||
case "Dopri5" | "dopri5" | "d5": | ||
solver_choice = "Dopri5" | ||
case "Tsit5" | "tsit5" | "t5": | ||
solver_choice = "Tsit5" | ||
case "LeapfrogMidpoint" | "leapfrogmidpoint" | "lfm": | ||
solver_choice = "LeapfrogMidpoint" | ||
case "lpt": | ||
solver_choice = "lpt" | ||
case _: | ||
raise ValueError( | ||
"Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt'" | ||
) | ||
if args.precision == "float32": | ||
jax.config.update("jax_enable_x64", False) | ||
elif args.precision == "float64": | ||
jax.config.update("jax_enable_x64", True) | ||
|
||
if args.pdims: | ||
pdims = tuple(map(int, args.pdims.split("x"))) | ||
else: | ||
pdims = (1, jax.device_count()) | ||
pdm_str = f"{pdims[0]}x{pdims[1]}" | ||
|
||
mesh_shape = [mesh_size] * 3 | ||
|
||
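# hlo_print and trace are required by the signature but unused in run_simulation | ||
final_field, stats, chrono_fun = run_simulation( | ||
mesh_shape, box_size, halo_size, solver_choice, iterations, | ||
hlo_print=False, trace=False, pdims=pdims, output_path=output_path) | ||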
|
||
print( | ||
f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}" | ||
) | ||
|
||
metadata = { | ||
'rank': rank, | ||
'function_name': f'JAXPM-{solver_choice}', | ||
'precision': args.precision, | ||
'x': str(mesh_size), | ||
'y': str(mesh_size), | ||
'z': str(stats["num_steps"]), | ||
'px': str(pdims[0]), | ||
'py': str(pdims[1]), | ||
'backend': 'NCCL', | ||
'nodes': str(args.nodes) | ||
} | ||
# Print the results to a CSV file | ||
chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata) | ||
|
||
# Save the final field | ||
nb_gpus = jax.device_count() | ||
pdm_str = f"{pdims[0]}x{pdims[1]}" | ||
field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" | ||
os.makedirs(field_folder, exist_ok=True) | ||
with open(f'{field_folder}/jaxpm.log', 'w') as f: | ||
f.write(f"Args: {args}\n") | ||
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") | ||
for i, elapsed in enumerate(chrono_fun.times): | ||
f.write(f"Time {i}: {elapsed:.4f} ms\n") | ||
f.write(f"Stats: {stats}\n") | ||
if args.save_fields: | ||
np.save(f'{field_folder}/final_field_0_{rank}.npy', | ||
final_field.addressable_data(0)) | ||
|
||
print("Finished!") | ||
print(f"Stats {stats}") | ||
print(f"Saving to {output_path}/jaxpm_benchmark.csv") | ||
print(f"Saving field and logs in {field_folder}") |
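The argument handling in the script above (solver aliases, the `AxB` form of `--pdims`, and the `mesh_size // 8` halo default) can be sketched in isolation. This is a minimal standard-library sketch, not the PR's code: the helper names `normalize_solver`, `parse_pdims`, and `default_halo_size` are hypothetical, and the check that the product of `pdims` equals the device count is an assumption about what a valid decomposition requires.

```python
from math import prod

# Alias table mirroring the choices accepted by the --solver flag.
SOLVER_ALIASES = {
    "dopri5": "Dopri5",
    "d5": "Dopri5",
    "tsit5": "Tsit5",
    "t5": "Tsit5",
    "leapfrogmidpoint": "LeapfrogMidpoint",
    "lfm": "LeapfrogMidpoint",
    "lpt": "lpt",
}


def normalize_solver(name):
    """Map any accepted alias to its canonical solver name."""
    try:
        return SOLVER_ALIASES[name.lower()]
    except KeyError:
        raise ValueError(
            "Invalid solver choice. Use 'Dopri5', 'Tsit5', "
            "'LeapfrogMidpoint' or 'lpt'") from None


def parse_pdims(text, device_count):
    """Parse a string such as '2x4'; fall back to (1, device_count)."""
    if not text:
        return (1, device_count)
    parts = text.lower().split("x")
    if len(parts) != 2:
        raise ValueError(f"Expected pdims of the form AxB, got {text!r}")
    pdims = tuple(int(p) for p in parts)
    # Assumption: the processor grid must cover every device exactly once.
    if prod(pdims) != device_count:
        raise ValueError(
            f"pdims {pdims} does not match {device_count} devices")
    return pdims


def default_halo_size(mesh_size, halo_size=None):
    """Use mesh_size // 8 unless a halo size is given explicitly."""
    return mesh_size // 8 if halo_size is None else halo_size
```

For instance, `parse_pdims("2x4", 8)` returns `(2, 4)` and `default_halo_size(256)` returns `32`. Validating the grid up front would surface a mismatched `--pdims` before any device mesh is built.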
This line should not be in the Goals section, it's a feature now.