diff --git a/composer/devices/device_tpu.py b/composer/devices/device_tpu.py
index b91d1bc478..813fc49924 100644
--- a/composer/devices/device_tpu.py
+++ b/composer/devices/device_tpu.py
@@ -26,6 +26,7 @@ class DeviceTPU(Device):
     More details.
     """
 
+    dist_backend = 'xla'
     name = 'tpu'
 
     def __init__(self):
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 7411dc4393..0d2349bf93 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -2567,6 +2567,11 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,
                 microbatch_loss.mul_(microbatch_num_samples / current_batch_size)
                 microbatch_loss.backward(create_graph=self._backwards_create_graph)
 
+            if self.state.device.dist_backend == 'xla':
+                # For xla devices, the program between any pair of mark_steps() calls is compiled. Without this, the
+                # microbatching loop is unrolled, drastically increasing compile time.
+                xm.mark_step()
+
             self.engine.run_event(Event.AFTER_BACKWARD)
 
             # Use microbatch outputs to update training metrics
diff --git a/composer/utils/dist.py b/composer/utils/dist.py
index 65edb5e80c..5b8dd5df68 100644
--- a/composer/utils/dist.py
+++ b/composer/utils/dist.py
@@ -37,6 +37,7 @@
 import logging
 import os
 import pickle
+import sys
 import time
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union, cast
@@ -44,8 +45,12 @@
 import torch
 import torch.distributed as dist
 import torch.utils.data
+from packaging import version
 
-from composer.utils.device import get_device, is_hpu_installed
+from composer.utils.device import get_device, is_hpu_installed, is_tpu_installed
+
+if is_tpu_installed():
+    import torch_xla
 
 if TYPE_CHECKING:
     from composer.devices import Device
@@ -534,7 +539,15 @@ def initialize_dist(device: Union[str, Device], timeout: float = 300.0):
 
     dist_env_vars_match_defaults = all(os.environ.get(k, v) == v for (k, v) in dist_env_var_defaults.items())
 
-    if dist_env_vars_match_defaults:
+    if device_obj.dist_backend == 'xla':
+        if not 'torch_xla' in sys.modules:
+            raise RuntimeError('PyTorch XLA package not found. In order to use XLA based devices, '
+                               'PyTorch XLA must be installed.')
+        if version.parse(torch_xla.__version__) < version.parse('2.1.0'):
+            raise RuntimeError(f'PyTorch XLA version must be at least 2.1.0, found {torch_xla.__version__}.')
+        # XLA initialization requires the init_method to be set
+        dist.init_process_group(device_obj.dist_backend, init_method='xla://')
+    elif dist_env_vars_match_defaults:
         # Fill in the remaining single-rank variables
         os.environ.update(dist_env_var_defaults)
         dist.init_process_group(device_obj.dist_backend, store=dist.HashStore(), world_size=1, rank=0)
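
Context for reviewers: on XLA devices execution is lazily traced and compiled between `mark_step()` calls, which is why the trainer change cuts the graph after each microbatch's backward pass. Below is a minimal sketch (not Composer code) of that pattern; the model, shapes, and optimizer are placeholder assumptions, and it assumes `torch_xla >= 2.1` with an XLA device available.

```python
# Minimal sketch of per-microbatch graph cutting on an XLA device.
# Assumes torch_xla >= 2.1 is installed and an XLA device (e.g. a TPU VM) is reachable.
import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(8, 2).to(device)      # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs = torch.randn(16, 8, device=device)    # placeholder batch
targets = torch.randn(16, 2, device=device)
num_microbatches = 4

optimizer.zero_grad()
for micro_x, micro_y in zip(inputs.chunk(num_microbatches), targets.chunk(num_microbatches)):
    loss = F.mse_loss(model(micro_x), micro_y) / num_microbatches
    loss.backward()
    # Cut the lazily traced program here so each microbatch reuses one small
    # compiled graph instead of the whole loop being unrolled into a single graph.
    xm.mark_step()

optimizer.step()
xm.mark_step()
```

The `dist.py` change relies on the same torch_xla release: under the PJRT runtime, `dist.init_process_group('xla', init_method='xla://')` lets the XLA backend discover rank and world size from the runtime rather than from the usual torch.distributed environment variables, which is why the xla branch bypasses the single-rank defaults.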