[FSDPv1] Only perform cat() during last microbatch backward() within FlattenParamsWrapper #1184
Original PR: #1180
If optimize_backward_concat is set to True, the backward() pass is only allowed to propagate to FSDP.flat_params (which invokes FSDP._post_backward_hook() and the cat() op) when FSDP._require_backward_grad_sync is True, e.g. on the last microbatch.
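A minimal sketch of the idea (the function and variable names below, such as `param_views` and `require_backward_grad_sync`, are illustrative assumptions, not the actual diff): per-parameter gradients are accumulated on every microbatch, and the flat gradient is only materialized via cat() when grad sync is required.

```python
import torch

def post_backward(param_views, flat_param, require_backward_grad_sync):
    # Per-microbatch backward() only accumulates gradients on the individual
    # parameter views; no SplitWithSizesBackward / cat() is launched here.
    if not require_backward_grad_sync:
        return
    # Last microbatch (grad sync required): concatenate the accumulated
    # per-parameter gradients into the single flat gradient that
    # _post_backward_hook() hands to the reduce-scatter.
    flat_param.grad = torch.cat([p.grad.reshape(-1) for p in param_views])
```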
Trace comparison
Trace before the change (SplitWithSizesBackward is triggered on every microbatch per FSDP module):
https://fburl.com/perfdoctor/qdt32ibh
Trace with the change applied (SplitWithSizesBackward is triggered only on the last microbatch per FSDP module):
https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.229652302632210.json.gz&bucket=acadia
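For context, a usage sketch of how the flag interacts with gradient accumulation; the kwarg name `optimize_backward_concat` comes from this PR, but its exposure as an FSDP constructor argument and the toy microbatch loop are assumptions for illustration:

```python
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Assumes a process group is already initialized, e.g.
# dist.init_process_group("nccl") with one rank per GPU.
model = FSDP(
    torch.nn.Linear(1024, 1024).cuda(),
    optimize_backward_concat=True,  # flag introduced by this PR (assumed ctor kwarg)
)

microbatches = [torch.randn(8, 1024, device="cuda") for _ in range(4)]

# All but the last microbatch: grads are accumulated without syncing, so
# _require_backward_grad_sync is False and no cat() runs in the hook.
with model.no_sync():
    for x in microbatches[:-1]:
        model(x).sum().backward()

# Last microbatch: _require_backward_grad_sync is True, so the flat-grad
# cat() and reduce-scatter run exactly once per FSDP module.
model(microbatches[-1]).sum().backward()
```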
Numerics verification
Local runs with deterministic mode:
TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size=2, DP=2, FP8 (no 1F1B) (loss bitwise on par)
baseline:
https://www.internalfb.com/intern/paste/P1363180533/
test:
https://www.internalfb.com/intern/paste/P1363177870/
TP=2, GPUs=8, DP=4, BF16, non-PP microbatching (loss bitwise on par)
baseline:
https://www.internalfb.com/intern/paste/P1322976356/
test:
https://www.internalfb.com/intern/paste/P1322871976/
TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size=2, DP=2, BF16 (no 1F1B) (loss bitwise on par)
baseline:
https://www.internalfb.com/intern/paste/P1358660231/
test:
https://www.internalfb.com/intern/paste/P1358659328/
TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size=4, DP=2, BF16 (1F1B) (loss bitwise on par)
baseline:
https://www.internalfb.com/intern/paste/P1358780690
test:
https://www.internalfb.com/intern/paste/P1358786994/
E2E MAST tests:
model=small, TP=2, PP=2, DP=2 (loss on par)
baseline:
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-tl66r0qd
test:
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-km46966
Perf evaluation
model=llama3_kv8_balance2_ffn12, n_layers=1, non-PP microbatching, bs=128, FP8, TP=4, CP=8
baseline:
e2e TFLOP/s: 339.53
comp TFLOP/s: 625.64
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-f7cdn9q
trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.27299292624533.json.gz&bucket=acadia
test:
e2e TFLOP/s: 387.98 (~15% improvement)
comp TFLOP/s: 817.5 (~30% improvement)
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-t56xpf
trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.71951644521316.json.gz&bucket=acadia