From ac40452c85fae2f50ec3c627292c267260969017 Mon Sep 17 00:00:00 2001
From: yisheng
Date: Sun, 25 Aug 2024 20:20:00 -0700
Subject: [PATCH] enable empty cache on XPU device

---
 megatron/training.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/megatron/training.py b/megatron/training.py
index 6ba26f3944..9bd68bc90a 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -724,7 +724,7 @@ def train_step(forward_step_func, data_iterator,
 
     # Empty unused memory.
     if args.empty_unused_memory_level >= 1:
-        torch.cuda.empty_cache()
+        get_accelerator().empty_cache()
 
     # Reduce gradients.
     if not args.deepspeed:
@@ -781,7 +781,7 @@ def train_step(forward_step_func, data_iterator,
 
     # Empty unused memory.
     if args.empty_unused_memory_level >= 2:
-        torch.cuda.empty_cache()
+        get_accelerator().empty_cache()
 
     if mpu.is_pipeline_last_stage(ignore_virtual=True):
         # Average loss across microbatches.
@@ -1437,7 +1437,7 @@ def evaluate(forward_step_func,
 
             # Empty unused memory
             if args.empty_unused_memory_level >= 1:
-                torch.cuda.empty_cache()
+                get_accelerator().empty_cache()
 
             if mpu.is_pipeline_last_stage(ignore_virtual=True):
                 # Reduce across processes.
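
Reviewer note (commentary only, not part of the diff): a minimal sketch, assuming
only the stock DeepSpeed accelerator API (deepspeed.accelerator.get_accelerator),
of why this one-line substitution is portable. The helper maybe_empty_cache and
its parameters are hypothetical names introduced here for illustration; the patch
presumably relies on megatron/training.py already importing get_accelerator,
since it does not touch the import block.

# Sketch only -- not part of the patch. DeepSpeed's get_accelerator() returns
# a backend-specific accelerator object (CUDA on NVIDIA GPUs, XPU on Intel
# GPUs), so a single call site can flush the allocator cache on either device.
from deepspeed.accelerator import get_accelerator

def maybe_empty_cache(args, threshold):
    # Hypothetical helper mirroring the guard at all three call sites in
    # megatron/training.py: release cached, currently unused device memory
    # only when the configured level meets the threshold, since emptying the
    # cache has a runtime cost of its own.
    if args.empty_unused_memory_level >= threshold:
        get_accelerator().empty_cache()

# Usage in the spirit of the patched code, e.g. after the optimizer step:
#   maybe_empty_cache(args, 1)

Routing the call through the accelerator object, rather than branching on the
device type at each call site, keeps backend-specific logic out of the training
loop and covers any backend DeepSpeed's accelerator abstraction supports.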