From dc7cf023b08e4a65f2dbb61abdc7aebdcbec293f Mon Sep 17 00:00:00 2001
From: Bruce Fontaine
Date: Wed, 3 Jan 2024 12:19:07 -0800
Subject: [PATCH 1/4] Fix initialization and microbatching for TPUs

---
 composer/trainer/trainer.py | 5 +++++
 composer/utils/dist.py      | 6 +++++-
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index c8c6d325e0..5cbb9d043a 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -2480,6 +2480,11 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,
             microbatch_loss.mul_(microbatch_num_samples / current_batch_size)
             microbatch_loss.backward(create_graph=self._backwards_create_graph)
 
+            if isinstance(self.state.device, DeviceTPU):
+                # For TPUs, the program between any pair of mark_steps() calls is compiled. Without this, the
+                # microbatching loop is unrolled, drastically increasing compile time.
+                xm.mark_step()
+
             self.engine.run_event(Event.AFTER_BACKWARD)
 
             # Use microbatch outputs to update training metrics
diff --git a/composer/utils/dist.py b/composer/utils/dist.py
index 1b59bff1d4..b470f4b4c0 100644
--- a/composer/utils/dist.py
+++ b/composer/utils/dist.py
@@ -519,7 +519,11 @@ def initialize_dist(device: Union[str, Device], timeout: float = 300.0):
 
     dist_env_vars_match_defaults = all(os.environ.get(k, v) == v for (k, v) in dist_env_var_defaults.items())
 
-    if dist_env_vars_match_defaults:
+    if device_obj.name == 'tpu':
+        # TPU initialization requires the init_method to be set, so that it uses the special TPU
+        # initialization registered by pytorch xla.
+        dist.init_process_group(device_obj.dist_backend, init_method='xla://')
+    elif dist_env_vars_match_defaults:
         # Fill in the remaining single-rank variables
         os.environ.update(dist_env_var_defaults)
         dist.init_process_group(device_obj.dist_backend, store=dist.HashStore(), world_size=1, rank=0)

From 23bc3c8d9308ad1a3f6f3ec35579697112bf901d Mon Sep 17 00:00:00 2001
From: Bruce Fontaine
Date: Mon, 5 Feb 2024 10:21:48 -0800
Subject: [PATCH 2/4] TPU -> XLA

---
 composer/devices/device_tpu.py | 1 +
 composer/trainer/trainer.py    | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/composer/devices/device_tpu.py b/composer/devices/device_tpu.py
index b91d1bc478..813fc49924 100644
--- a/composer/devices/device_tpu.py
+++ b/composer/devices/device_tpu.py
@@ -26,6 +26,7 @@ class DeviceTPU(Device):
     More details.
     """
 
+    dist_backend = 'xla'
     name = 'tpu'
 
     def __init__(self):
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 1e12862743..4493b6c18d 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -2543,8 +2543,8 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,
             microbatch_loss.mul_(microbatch_num_samples / current_batch_size)
             microbatch_loss.backward(create_graph=self._backwards_create_graph)
 
-            if isinstance(self.state.device, DeviceTPU):
-                # For TPUs, the program between any pair of mark_steps() calls is compiled. Without this, the
+            if self.state.device.dist_backend == 'xla'):
+                # For xla devices, the program between any pair of mark_steps() calls is compiled. Without this, the
                 # microbatching loop is unrolled, drastically increasing compile time.
                 xm.mark_step()
 

From 32c5e58e10b3f5ef3261429cd57f3a6214dcab45 Mon Sep 17 00:00:00 2001
From: Bruce Fontaine
Date: Mon, 5 Feb 2024 10:23:19 -0800
Subject: [PATCH 3/4] TPU -> XLA

---
 composer/trainer/trainer.py | 2 +-
 composer/utils/dist.py      | 5 ++---
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 4493b6c18d..4a8b5f14aa 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -2543,7 +2543,7 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,
             microbatch_loss.mul_(microbatch_num_samples / current_batch_size)
             microbatch_loss.backward(create_graph=self._backwards_create_graph)
 
-            if self.state.device.dist_backend == 'xla'):
+            if self.state.device.dist_backend == 'xla':
                 # For xla devices, the program between any pair of mark_steps() calls is compiled. Without this, the
                 # microbatching loop is unrolled, drastically increasing compile time.
                 xm.mark_step()
diff --git a/composer/utils/dist.py b/composer/utils/dist.py
index b470f4b4c0..399d73e6bd 100644
--- a/composer/utils/dist.py
+++ b/composer/utils/dist.py
@@ -519,9 +519,8 @@ def initialize_dist(device: Union[str, Device], timeout: float = 300.0):
 
     dist_env_vars_match_defaults = all(os.environ.get(k, v) == v for (k, v) in dist_env_var_defaults.items())
 
-    if device_obj.name == 'tpu':
-        # TPU initialization requires the init_method to be set, so that it uses the special TPU
-        # initialization registered by pytorch xla.
+    if device_obj.dist_backend == 'xla':
+        # XLA initialization requires the init_method to be set
         dist.init_process_group(device_obj.dist_backend, init_method='xla://')
     elif dist_env_vars_match_defaults:
         # Fill in the remaining single-rank variables

From b19f23b4d62c34a993dd461cdf077f803ee54911 Mon Sep 17 00:00:00 2001
From: Bruce Fontaine
Date: Wed, 14 Feb 2024 09:13:44 -0800
Subject: [PATCH 4/4] add version check for PyTorch XLA >= 2.1

---
 composer/utils/dist.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/composer/utils/dist.py b/composer/utils/dist.py
index 399d73e6bd..4913b12ba1 100644
--- a/composer/utils/dist.py
+++ b/composer/utils/dist.py
@@ -37,6 +37,7 @@
 import logging
 import os
 import pickle
+import sys
 import time
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union, cast
@@ -44,8 +45,12 @@
 import torch
 import torch.distributed as dist
 import torch.utils.data
+from packaging import version
 
-from composer.utils.device import get_device, is_hpu_installed
+from composer.utils.device import get_device, is_hpu_installed, is_tpu_installed
+
+if is_tpu_installed():
+    import torch_xla
 
 if TYPE_CHECKING:
     from composer.devices import Device
@@ -520,6 +525,11 @@ def initialize_dist(device: Union[str, Device], timeout: float = 300.0):
     dist_env_vars_match_defaults = all(os.environ.get(k, v) == v for (k, v) in dist_env_var_defaults.items())
 
     if device_obj.dist_backend == 'xla':
+        if not 'torch_xla' in sys.modules:
+            raise RuntimeError('PyTorch XLA package not found. In order to use XLA based devices '
+                               'PyTorch XLA must be installed.')
+        if version.parse(torch_xla.__version__) < version.parse('2.1.0'):
+            raise RuntimeError(f'PyTorch XLA version must be at least 2.1.0, found {torch_xla.__version__}.')
         # XLA initialization requires the init_method to be set
         dist.init_process_group(device_obj.dist_backend, init_method='xla://')
     elif dist_env_vars_match_defaults:
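
For reference, a minimal standalone sketch (not part of the patches above) of the microbatching pattern the trainer change implements: calling xm.mark_step() after each microbatch backward pass cuts the traced XLA graph there, so every microbatch reuses the same small compiled program instead of the whole loop being unrolled into one large one. The train_microbatches helper and its model/optimizer/loss_fn/microbatches arguments are illustrative placeholders, not Composer APIs.

    import torch_xla.core.xla_model as xm

    def train_microbatches(model, optimizer, loss_fn, microbatches):
        # Accumulate gradients over microbatches so the result matches one
        # full-batch backward pass (illustrative helper, not Composer code).
        optimizer.zero_grad()
        current_batch_size = sum(inputs.shape[0] for inputs, _ in microbatches)
        for inputs, targets in microbatches:
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            # Weight each microbatch loss by its share of the full batch.
            loss = loss * (inputs.shape[0] / current_batch_size)
            loss.backward()
            # Compile and execute the graph for this microbatch now; without
            # this, XLA would trace the entire loop as one unrolled program.
            xm.mark_step()
        optimizer.step()
        xm.mark_step()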
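
Similarly, a sketch (assuming PyTorch/XLA >= 2.1, per patch 4) of the XLA-aware process-group initialization that initialize_dist performs after this series: verify that torch_xla is importable and recent enough, then initialize the 'xla' backend with the xla:// init_method so rank and world size come from the XLA runtime. init_xla_process_group is a hypothetical standalone helper, not a Composer function.

    import torch.distributed as dist
    from packaging import version

    try:
        import torch_xla
        import torch_xla.distributed.xla_backend  # noqa: F401  registers the 'xla' backend
    except ImportError:
        torch_xla = None

    def init_xla_process_group():
        # Mirror the guards patch 4 adds to initialize_dist().
        if torch_xla is None:
            raise RuntimeError('PyTorch XLA must be installed to use XLA based devices.')
        if version.parse(torch_xla.__version__) < version.parse('2.1.0'):
            raise RuntimeError(f'PyTorch XLA version must be at least 2.1.0, found {torch_xla.__version__}.')
        # The xla:// init_method lets the XLA runtime supply rank and world size.
        dist.init_process_group('xla', init_method='xla://')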