jaxdecomp proto #21

Open · wants to merge 29 commits into main

Commits (29)
a742065
adding example of distributed solution
EiffL Jul 9, 2024
6408aff
put back old function
EiffL Jul 9, 2024
319942a
update formatting
EiffL Jul 9, 2024
ac86468
add halo exchange and slice pad
ASKabalan Jul 9, 2024
e62cd84
apply formatting
ASKabalan Jul 18, 2024
5775a37
implement distributed optimized cic_paint
ASKabalan Jul 18, 2024
7501b5b
Use new cic_paint with halo
ASKabalan Jul 18, 2024
7f48cfa
Fix seed for distributed normal
ASKabalan Jul 18, 2024
c81d4d2
Wrap interpolation function to avoid all gather
ASKabalan Jul 18, 2024
abde543
Return normal order frequencies for single GPU
ASKabalan Jul 18, 2024
82be568
add example
ASKabalan Jul 18, 2024
4f508b7
format
ASKabalan Jul 18, 2024
5f6d42e
add optimised bench script
ASKabalan Jul 18, 2024
1f6b9c3
times in ms
ASKabalan Jul 18, 2024
ed8cf8e
add lpt2
ASKabalan Jul 18, 2024
0216837
update benchmark and add slurm
ASKabalan Jul 18, 2024
5b7f595
Visualize only final field
ASKabalan Jul 18, 2024
1f20351
Update scripts/distributed_pm.py
ASKabalan Jul 19, 2024
f25eb7d
Adjust pencil type for frequencies
ASKabalan Jul 28, 2024
8c5bd76
fix painting issue with slabs
ASKabalan Aug 2, 2024
75604d2
Shared operations in Fourier space now take the inverted sharding axis for
ASKabalan Aug 2, 2024
ccbfee3
add assert to make pyright happy
ASKabalan Aug 2, 2024
aebc3e7
adjust test for hpc-plotter
ASKabalan Aug 2, 2024
9af4659
add PMWD test
ASKabalan Aug 2, 2024
831291c
bench
Aug 2, 2024
ece8c93
format
ASKabalan Aug 2, 2024
783a974
Merge remote-tracking branch 'upstream/main' into ASKabalan/jaxdecomp…
ASKabalan Aug 2, 2024
02754cf
added github workflow
ASKabalan Aug 2, 2024
8da3149
fix formatting from main
ASKabalan Aug 2, 2024
21 changes: 21 additions & 0 deletions .github/workflows/formatting.yml
@@ -0,0 +1,21 @@
name: Code Formatting

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
    - name: Set up Python
      uses: actions/setup-python@v3
      with:
        # No strategy matrix is defined in this workflow, so pin a version
        # explicitly ("3.10" is an illustrative choice, not from the PR).
        python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip isort
python -m pip install pre-commit
- name: Run pre-commit
run: python -m pre_commit run --all-files
5 changes: 5 additions & 0 deletions .gitignore
@@ -98,6 +98,11 @@ __pypackages__/
celerybeat-schedule
celerybeat.pid


out
traces
*.npy
*.out
# SageMath parsed files
*.sage.py

2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -14,4 +14,4 @@ repos:
rev: 5.13.2
hooks:
- id: isort
-        name: isort (python)
+        name: isort (python)
261 changes: 261 additions & 0 deletions benchmarks/bench_pm.py
@@ -0,0 +1,261 @@
import os

os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
import jax

jax.distributed.initialize()

rank = jax.process_index()
size = jax.process_count()
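# In a multi-process run, jax.distributed.initialize() must come before any
# other JAX call; process_index/process_count identify this host within the job.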

import argparse

import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from cupy.cuda.nvtx import RangePop, RangePush
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, Tsit5, diffeqsolve)
from hpc_plotter.timer import Timer
from jax.experimental import mesh_utils
from jax.experimental.multihost_utils import sync_global_devices
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn


def run_simulation(mesh_shape,
box_size,
halo_size,
solver_choice,
iterations,
pdims=None):

@jax.jit
def simulate(omega_c, sigma8):
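        # Everything below is traced once by jax.jit; with a device mesh active,
        # the intermediate fields stay sharded across processes.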
# Create a small function to generate the matter power spectrum
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk)

# Create initial conditions
initial_conditions = linear_field(mesh_shape,
box_size,
pk_fn,
seed=jax.random.PRNGKey(0))
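        # A fixed PRNGKey keeps the initial field reproducible across runs and
        # identical on every process.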

# Create particles
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
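        # lpt returns the particle displacements dx and momenta p at the initial
        # scale factor a=0.1 (matching t0 in the ODE solve below).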
if solver_choice == "Dopri5":
solver = Dopri5()
elif solver_choice == "LeapfrogMidpoint":
solver = LeapfrogMidpoint()
elif solver_choice == "Tsit5":
solver = Tsit5()
elif solver_choice == "lpt":
lpt_field = cic_paint_dx(dx, halo_size=halo_size)
print(f"TYPE of lpt_field: {type(lpt_field)}")
return lpt_field, {"num_steps": 0}
else:
            raise ValueError(
                "Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt'.")
# Evolve the simulation forward
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
term = ODETerm(
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
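        # diffrax expects a vector field f(t, y, args), while ode_fn follows the
        # (state, t, args) convention, so the lambda reorders the arguments and
        # stacks the position/velocity derivatives into one array.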

if solver_choice == "Dopri5" or solver_choice == "Tsit5":
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
elif solver_choice == "LeapfrogMidpoint" or solver_choice == "Euler":
stepsize_controller = ConstantStepSize()
res = diffeqsolve(term,
solver,
t0=0.1,
t1=1.,
dt0=0.01,
y0=jnp.stack([dx, p], axis=0),
args=cosmo,
saveat=SaveAt(t1=True),
stepsize_controller=stepsize_controller)
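        # SaveAt(t1=True) stores only the final state, so res.ys holds a single
        # snapshot at t1=1.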

        # Return the simulation volume at the requested time
state = res.ys[-1]
final_field = cic_paint_dx(state[0], halo_size=halo_size)

return final_field, res.stats

def run():
# Warm start
chrono_fun = Timer()
RangePush("warmup")
final_field, stats = chrono_fun.chrono_jit(simulate,
0.32,
0.8,
ndarray_arg=0)
RangePop()
sync_global_devices("warmup")
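        # Barrier: wait for all processes to finish compilation before timing.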
for i in range(iterations):
RangePush(f"sim iter {i}")
final_field, stats = chrono_fun.chrono_fun(simulate,
0.32,
0.8,
ndarray_arg=0)
RangePop()
return final_field, stats, chrono_fun

if jax.device_count() > 1:
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y'))
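        # Lay the processes out on a 2-D device mesh; the 'x'/'y' axis names are
        # what the sharded jaxpm ops shard over (the transpose presumably
        # matches the pdims ordering).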
with mesh:
# Warm start
final_field, stats, chrono_fun = run()
else:
final_field, stats, chrono_fun = run()

return final_field, stats, chrono_fun


if __name__ == "__main__":

parser = argparse.ArgumentParser(
description='JAX Cosmo Simulation Benchmark')
parser.add_argument('-m',
'--mesh_size',
type=int,
help='Mesh size',
required=True)
parser.add_argument('-b',
'--box_size',
type=float,
help='Box size',
required=True)
parser.add_argument('-p',
'--pdims',
type=str,
help='Processor dimensions',
default=None)
    parser.add_argument(
        '-pr',
        '--precision',
        type=str,
        help='Precision',
        choices=["float32", "float64"],
        default="float32",  # default added so the metadata never records None
    )
parser.add_argument('-hs',
'--halo_size',
type=int,
help='Halo size',
required=True)
parser.add_argument('-s',
'--solver',
type=str,
help='Solver',
choices=[
"Dopri5", "dopri5", "d5", "Tsit5", "tsit5", "t5",
"LeapfrogMidpoint", "leapfrogmidpoint", "lfm",
"lpt"
],
required=True)
parser.add_argument('-i',
'--iterations',
type=int,
help='Number of iterations',
default=10)
parser.add_argument('-o',
'--output_path',
type=str,
help='Output path',
default=".")
parser.add_argument('-f',
'--save_fields',
action='store_true',
help='Save fields')
parser.add_argument('-n',
'--nodes',
type=int,
help='Number of nodes',
default=1)

args = parser.parse_args()
mesh_size = args.mesh_size
box_size = [args.box_size] * 3
halo_size = args.halo_size
solver_choice = args.solver
iterations = args.iterations
output_path = args.output_path
os.makedirs(output_path, exist_ok=True)

print(f"solver choice: {solver_choice}")
match solver_choice:
case "Dopri5" | "dopri5" | "d5":
solver_choice = "Dopri5"
case "Tsit5" | "tsit5" | "t5":
solver_choice = "Tsit5"
case "LeapfrogMidpoint" | "leapfrogmidpoint" | "lfm":
solver_choice = "LeapfrogMidpoint"
case "lpt":
solver_choice = "lpt"
case _:
            raise ValueError(
                "Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt'."
            )
if args.precision == "float32":
jax.config.update("jax_enable_x64", False)
elif args.precision == "float64":
jax.config.update("jax_enable_x64", True)

if args.pdims:
pdims = tuple(map(int, args.pdims.split("x")))
else:
pdims = (1, 1)
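        # No --pdims given: fall back to a trivial 1x1 process grid.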

mesh_shape = [mesh_size] * 3

final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size,
halo_size, solver_choice,
iterations, pdims)

print(
f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}"
)

metadata = {
'rank': rank,
'function_name': f'JAXPM-{solver_choice}',
'precision': args.precision,
'x': str(mesh_size),
'y': str(mesh_size),
'z': str(stats["num_steps"]),
'px': str(pdims[0]),
'py': str(pdims[1]),
'backend': 'NCCL',
'nodes': str(args.nodes)
}
# Print the results to a CSV file
chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata)

# Save the final field
nb_gpus = jax.device_count()
pdm_str = f"{pdims[0]}x{pdims[1]}"
field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}"
os.makedirs(field_folder, exist_ok=True)
with open(f'{field_folder}/jaxpm.log', 'w') as f:
f.write(f"Args: {args}\n")
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
        for i, elapsed in enumerate(chrono_fun.times):
            f.write(f"Time {i}: {elapsed:.4f} ms\n")
f.write(f"Stats: {stats}\n")
if args.save_fields:
np.save(f'{field_folder}/final_field_0_{rank}.npy',
final_field.addressable_data(0))

    print("Finished!")
    print(f"Stats {stats}")
    print(f"Saving to {output_path}/jaxpm_benchmark.csv")
print(f"Saving field and logs in {field_folder}")
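
For reference, a hypothetical invocation of this benchmark (the flag values below are illustrative, not taken from the PR; mesh size, box size, halo size and solver are the required arguments defined by the parser):

    python benchmarks/bench_pm.py -m 256 -b 256.0 -hs 32 -s lpt -p 2x2 -i 5 -o out

Because the script calls jax.distributed.initialize(), multi-process runs would normally be launched with one process per GPU via a cluster launcher such as srun or mpirun.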