Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FSDPv1] Only perform cat() during last microbatch backward() within FlattenParamsWrapper #1184

Conversation

chrisxcai
Copy link

original PR #1180

If optimize_backward_concat is set to be True, only let the backward() pass propagate to FSDP.flat_params, which will
invoke the FSDP. _post_backward_hook() and concat() op, when FSDP._require_backward_grad_sync
is True (e.g. last microbatch)

Trace comparison

trace before change (SplitWithSizesBackward triggered every microbatch per FSDP module):
https://fburl.com/perfdoctor/qdt32ibh

trace with applied change (SplitWithSizesBackward triggered only in last microbatch per FSDP module):
https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.229652302632210.json.gz&bucket=acadia

numerics verification

local run with deterministic mode

TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size 2, DP = 2, fp8 (no 1F1B) (loss bitwise on par)

baseline

NVTE_DISABLE_NVRTC=1 CUDA_LAUNCH_BLOCKING=1 PYTHONPATH=~/benchmark/fairscale_repos/fairscale/ CRYPTOGRAPHY_OPENSSL_NO_LEGACY=1 torchrun --master_port 1024 --nproc_per_node=8 train.py --dump_dir /tmp/chriscai/xldumps --model_parallel_size 2 --pipeline_parallel_size 2 --num_layers_per_virtual_pipeline_stage=4  --seq_len=1024 --gpu_check_level=-1 --steps=10 --log_all_steps=True --profile_freq=10 --dump_profile_traces=True --profile_with_stack=True --model.n_layers=8 --reshard_after_forward=False --batch_size=4 --model.efficient_attn=cutlass --model.attn_bias_type=causal --model.layer_ckpt=none --model=small --model.sequence_parallel=True --mem_snapshot_stop_step 5 --log_all_steps=True --enable_deterministic_training=True --log_freq=1 --model.use_te_layers=True --optim.use_fp32_copy_optim=True --model.benchmark_perf=False --model.use_fp8=True --model.fp8_wgrad=True --optimize_backward_concat=False

https://www.internalfb.com/intern/paste/P1363180533/

test

NVTE_DISABLE_NVRTC=1 CUDA_LAUNCH_BLOCKING=1 PYTHONPATH=~/benchmark/fairscale_repos/fairscale/ CRYPTOGRAPHY_OPENSSL_NO_LEGACY=1 torchrun --master_port 1024 --nproc_per_node=8 train.py --dump_dir /tmp/chriscai/xldumps --model_parallel_size 2 --pipeline_parallel_size 2 --num_layers_per_virtual_pipeline_stage=4  --seq_len=1024 --gpu_check_level=-1 --steps=10 --log_all_steps=True --profile_freq=10 --dump_profile_traces=True --profile_with_stack=True --model.n_layers=8 --reshard_after_forward=False --batch_size=4 --model.efficient_attn=cutlass --model.attn_bias_type=causal --model.layer_ckpt=none --model=small --model.sequence_parallel=True --mem_snapshot_stop_step 5 --log_all_steps=True --enable_deterministic_training=True --log_freq=1 --model.use_te_layers=True --optim.use_fp32_copy_optim=True --model.benchmark_perf=False --model.use_fp8=True --model.fp8_wgrad=True --optimize_backward_concat=True

https://www.internalfb.com/intern/paste/P1363177870/

TP=2, GPU=8, DP = 4, BF16, non-PP microbatching (loss bitwise on par)

baseline:
https://www.internalfb.com/intern/paste/P1322976356/
test :
https://www.internalfb.com/intern/paste/P1322871976/

TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size 2, DP = 2, BF16 (no 1F1B) (loss bitwise on par)

baseline
https://www.internalfb.com/intern/paste/P1358660231/

test
https://www.internalfb.com/intern/paste/P1358659328/

TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size 4, DP = 2, BF16 (1F1B) (loss bitwise on par)

baseline
https://www.internalfb.com/intern/paste/P1358780690

test
https://www.internalfb.com/intern/paste/P1358786994/

E2E MAST tests:

model = small, TP = 2, PP = 2, DP = 2 (loss on par)

baseline:
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-tl66r0qd

test:
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-km46966

loss

Perf evaluation

model= llama3_kv8_balance2_ffn12, n_layers = 1, non-PP microbatching, bs = 128, fp8, TP 4, CP = 8

baseline:
e2e TFLOPS/s: 339.53
comp TFLOPS/s: 625.64

https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-f7cdn9q
trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.27299292624533.json.gz&bucket=acadia

test:
e2e TFLOPS/s: 387.98 (~15%)
comp TFLOPS/s: 817.5 (~30%)

https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-t56xpf
trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.71951644521316.json.gz&bucket=acadia

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 15, 2024
@chrisxcai chrisxcai marked this pull request as ready for review May 15, 2024 22:51
@chrisxcai chrisxcai requested a review from awgu May 15, 2024 22:51
@chrisxcai chrisxcai merged commit 9cbb4a7 into ngoyal_changes_for_pp_fp8_jiecaoyu_free_fp16_shard May 16, 2024
1 of 18 checks passed
param_index,
):
if self.fp32_grads[param_index] is None:
self.fp32_grads[param_index] = grad.to(torch.float32)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's been a while to look at FSDP code so can be a silly question, but wonder for TransformerEngine modules that keep their weight grads in .main_grad in fp32 precision this can be a duplication.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fp32 weight grads being in .main_grad would be unexpected to me. My understanding was that we are not using that functionality and instead relying on FSDP to do the fp32 gradient accumulation since getting the weights' .main_grad to backprop correctly into the FSDP FlatParameter is tricky.

I think that in the current approach, .main_grad should only be written to after the reduce-scatter:

param.main_grad = reduced_grad.data

However, since the memory usage is higher, maybe somehow there is indeed some duplication with this approach.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I just checked we indeed don't use main_grad: https://github.com/fairinternal/xlformers/blob/main/src/model/te_layers.py#L104

I see https://github.com/fairinternal/xlformers/pull/1418 and #1142 were not merged but something we may want to consider to avoid gradient accumulation overhead.

@chrisxcai chrisxcai deleted the chriscai_ngoyal_changes_for_pp_fp8_jiecaoyu_free_fp16_shard branch June 9, 2024 08:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants