diff --git a/tests/algorithms/test_algorithm_resumption.py b/tests/algorithms/test_algorithm_resumption.py
index c00fdf2ade..f1c464e662 100644
--- a/tests/algorithms/test_algorithm_resumption.py
+++ b/tests/algorithms/test_algorithm_resumption.py
@@ -28,88 +28,99 @@ def test_algorithm_resumption(
     alg_cls: type[Algorithm],
     world_size,
 ):
-    folder1 = os.path.join(tmp_path, 'folder1')
-    folder2 = os.path.join(tmp_path, 'folder2')
-    os.makedirs(folder1, exist_ok=True)
-    os.makedirs(folder2, exist_ok=True)
-
-    model = get_alg_model(alg_cls)
-    alg_kwargs = get_alg_kwargs(alg_cls)
-
-    copied_model = copy.deepcopy(model)  # copy the model so the params will start from the same point
-
-    if alg_cls is LayerFreezing:
-        pytest.xfail('Known issues')
-
-    if alg_cls in (SAM, StochasticDepth):
-        pytest.xfail('Mismatch in weights when resuming from a checkpoint.')
-
-    if alg_cls is GyroDropout:
-        pytest.xfail('GyroDropoutLayer is not implemented in a way that allows correct resumption.')
-
-    if alg_cls is SWA and world_size > 1:
-        pytest.xfail('SWA is not implemented in a way that is compatible correct resumption on multiple devices.')
-
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)
-
-    shared_config = {
-        'max_duration': '2ba',
-        'save_filename': 'ba{batch}-rank{rank}',
-        'save_interval': '1ba',
-        'train_subset_num_batches': 2,
-        'precision': 'amp_bf16',
-    }
-    train_dataloader = get_alg_dataloader(alg_cls) if world_size == 1 else get_alg_dataloader(alg_cls, multigpu=True)
-    # train model once, saving checkpoints every batch
-    trainer1 = Trainer(
-        model=model,
-        train_dataloader=train_dataloader,
-        optimizers=optimizer,
-        schedulers=scheduler,
-        save_folder=folder1,
-        algorithms=alg_cls(**alg_kwargs),
-        **shared_config,
-    )
-    trainer1.fit()
-
-    # create second trainer, load from the first batch checkpoint, and continue training
-    optimizer = torch.optim.Adam(copied_model.parameters(), lr=0.01)
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)
-
-    alg = alg_cls(**alg_kwargs)
-    if alg_cls is SeqLengthWarmup:
-        alg._activated = True  # type: ignore
-
-    train_dataloader = get_alg_dataloader(alg_cls) if world_size == 1 else get_alg_dataloader(alg_cls, multigpu=True)
-    trainer2 = Trainer(
-        model=copied_model,
-        train_dataloader=train_dataloader,
-        load_path=os.path.join(folder1, 'ba1-rank{rank}'),
-        load_weights_only=False,
-        load_strict_model_weights=False,
-        optimizers=optimizer,
-        schedulers=scheduler,
-        save_folder=folder2,
-        algorithms=alg,
-        **shared_config,
-    )
-    trainer2.fit()
-    # check that the checkpoints after the second batch are equal
-    if world_size == 1 or dist.get_global_rank() == 0:
-        _assert_checkpoints_equal(
-            file1=os.path.join(folder1, 'ba2-rank0'),
-            file2=os.path.join(folder2, 'ba2-rank0'),
+    # Use RAM-based tmp directory instead of disk
+    from tempfile import TemporaryDirectory
+    with TemporaryDirectory() as tmpdir:
+        folder1 = os.path.join(tmpdir, 'folder1')
+        folder2 = os.path.join(tmpdir, 'folder2')
+        os.makedirs(folder1, exist_ok=True)
+        os.makedirs(folder2, exist_ok=True)
+
+        if alg_cls is LayerFreezing:
+            pytest.xfail('Known issues')
+
+        if alg_cls in (SAM, StochasticDepth):
+            pytest.xfail('Mismatch in weights when resuming from a checkpoint.')
+
+        if alg_cls is GyroDropout:
+            pytest.xfail('GyroDropoutLayer is not implemented in a way that allows correct resumption.')
+
+        if alg_cls is SWA and world_size > 1:
+            pytest.xfail('SWA is not implemented in a way that is compatible correct resumption on multiple devices.')
+
+        model = get_alg_model(alg_cls)
+        alg_kwargs = get_alg_kwargs(alg_cls)
+
+        copied_model = copy.deepcopy(model)  # copy the model so the params will start from the same point
+
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+
+        # Reduce training duration and data
+        shared_config = {
+            'max_duration': '2ba',
+            'save_filename': 'ba{batch}-rank{rank}',
+            'save_interval': '1ba',
+            'train_subset_num_batches': 2,
+            'precision': 'fp32',
+        }
+        train_dataloader = get_alg_dataloader(alg_cls) if world_size == 1 else get_alg_dataloader(alg_cls, multigpu=True)
+        # train model once, saving checkpoints every batch
+        trainer1 = Trainer(
+            model=model,
+            train_dataloader=train_dataloader,
+            optimizers=optimizer,
+            schedulers=scheduler,
+            save_folder=folder1,
+            algorithms=alg_cls(**alg_kwargs),
+            **shared_config,
         )
+        trainer1.fit()
+
+        # create second trainer, load an intermediate checkpoint,
+        # and continue training
+
+        optimizer = torch.optim.SGD(copied_model.parameters(), lr=0.1)
+        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+
+        alg = alg_cls(**alg_kwargs)
+        # SeqLengthWarmup has a call to ._activate_model() that happens on the first call to the algorithm.
+        # In order to get complete matching of the rng state, we have to cause that extra call to be skipped
+        # when reloading.
+        if alg_cls is SeqLengthWarmup:
+            alg._activated = True  # type: ignore
+        train_dataloader = get_alg_dataloader(alg_cls) if world_size == 1 else get_alg_dataloader(alg_cls, multigpu=True)
+
+        trainer2 = Trainer(
+            model=copied_model,
+            train_dataloader=train_dataloader,
+            load_path=os.path.join(folder1, 'ba1-rank{rank}'),
+            load_weights_only=False,
+            load_strict_model_weights=False,
+            optimizers=optimizer,
+            schedulers=scheduler,
+            save_folder=folder2,
+            algorithms=alg,
+            **shared_config,
+        )
+        trainer2.fit()
-    # ensure that the first and second batch checkpoints are not equal
-    if world_size == 1 or dist.get_global_rank() == 0:
-        with pytest.raises(AssertionError):
-            _assert_model_weights_equal(
-                file1=os.path.join(folder1, 'ba1-rank0'),
-                file2=os.path.join(folder1, 'ba2-rank0'),
+        # check that the checkpoints after the second batch are equal
+        if world_size == 1 or dist.get_global_rank() == 0:
+            _assert_checkpoints_equal(
+                file1=os.path.join(folder1, 'ba2-rank0'),
+                file2=os.path.join(folder2, 'ba2-rank0'),
             )
+        # check that different batch checkpoints are _not_ equal
+        # this ensures that the model weights are being updated.
+        if world_size == 1 or dist.get_global_rank() == 0:
+            with pytest.raises(AssertionError):
+                _assert_model_weights_equal(
+                    file1=os.path.join(folder1, 'ba1-rank0'),
+                    file2=os.path.join(folder1, 'ba2-rank0'),
+                )
+
 
 
 def _assert_checkpoints_equal(file1, file2):
     # TODO: consider merging with _assert_checkpoints_equivalent
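
Note: the hunk ends before the body of `_assert_checkpoints_equal`, and `_assert_model_weights_equal` is not shown at all. As a rough, hypothetical sketch only (not the repository's actual helpers, which the TODO suggests overlap with `_assert_checkpoints_equivalent`), comparison helpers of this shape could look like the following, assuming the saved checkpoints are plain `torch.save` dicts whose `['state']['model']` and `['state']['optimizers']` entries hold the model and optimizer state:

import torch


def _states_equal(a, b) -> bool:
    """Recursively compare nested dicts/lists, comparing tensors by value."""
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        return torch.equal(a, b)
    if isinstance(a, dict) and isinstance(b, dict):
        return a.keys() == b.keys() and all(_states_equal(a[k], b[k]) for k in a)
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        return len(a) == len(b) and all(_states_equal(x, y) for x, y in zip(a, b))
    return a == b


def _assert_checkpoints_equal(file1, file2):
    # Assumed checkpoint layout: a torch.save dict with a 'state' entry.
    ckpt1 = torch.load(file1, map_location='cpu')
    ckpt2 = torch.load(file2, map_location='cpu')
    assert _states_equal(ckpt1['state']['model'], ckpt2['state']['model'])
    assert _states_equal(ckpt1['state']['optimizers'], ckpt2['state']['optimizers'])


def _assert_model_weights_equal(file1, file2):
    # Compare only the model weights, as used with pytest.raises(AssertionError) above.
    weights1 = torch.load(file1, map_location='cpu')['state']['model']
    weights2 = torch.load(file2, map_location='cpu')['state']['model']
    assert _states_equal(weights1, weights2)

Restricting the comparison to the model and optimizer entries sidesteps fields such as run names or wall-clock timestamps, which legitimately differ between the two training runs.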