Update XLA support #2964

Merged (7 commits) on Feb 22, 2024
1 change: 1 addition & 0 deletions composer/devices/device_tpu.py
@@ -26,6 +26,7 @@ class DeviceTPU(Device):
More details.
"""

dist_backend = 'xla'
name = 'tpu'

def __init__(self):
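Note: the new dist_backend attribute is what the distributed utilities read when choosing a torch.distributed backend. A minimal sketch of that pattern follows; the base class and DeviceGPU shown here are illustrative stand-ins, not the real composer.devices implementations.

class Device:
    """Illustrative base class; the real one lives in composer.devices."""
    dist_backend: str = ''
    name: str = ''

class DeviceGPU(Device):
    dist_backend = 'nccl'  # assumed value, for illustration only
    name = 'gpu'

class DeviceTPU(Device):
    dist_backend = 'xla'  # the attribute added by this PR
    name = 'tpu'

def backend_for(device: Device) -> str:
    # initialize_dist() ultimately forwards this string to dist.init_process_group().
    return device.dist_backend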
5 changes: 5 additions & 0 deletions composer/trainer/trainer.py
@@ -2567,6 +2567,11 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,
microbatch_loss.mul_(microbatch_num_samples / current_batch_size)
microbatch_loss.backward(create_graph=self._backwards_create_graph)

if self.state.device.dist_backend == 'xla':
    # For XLA devices, the program between any pair of mark_step() calls is compiled. Without this, the
    # microbatching loop is unrolled, drastically increasing compile time.
    xm.mark_step()

self.engine.run_event(Event.AFTER_BACKWARD)

# Use microbatch outputs to update training metrics
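To see why the mark_step() call matters, here is a hedged, standalone sketch (not Composer code) of gradient accumulation on an XLA device. torch_xla traces lazily and compiles everything between consecutive mark_step() calls as a single program, so cutting the graph after each microbatch keeps one small compiled program instead of an unrolled loop.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(16, 1).to(device)

def accumulate_gradients(microbatches):
    for x, y in microbatches:
        loss = torch.nn.functional.mse_loss(model(x.to(device)), y.to(device))
        (loss / len(microbatches)).backward()
        xm.mark_step()  # cut the lazy graph here: one compiled program per microbatch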
17 changes: 15 additions & 2 deletions composer/utils/dist.py
@@ -37,15 +37,20 @@
import logging
import os
import pickle
import sys
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union, cast

import torch
import torch.distributed as dist
import torch.utils.data
from packaging import version

from composer.utils.device import get_device, is_hpu_installed
from composer.utils.device import get_device, is_hpu_installed, is_tpu_installed

if is_tpu_installed():
    import torch_xla

if TYPE_CHECKING:
    from composer.devices import Device
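The conditional torch_xla import above follows the usual optional-dependency pattern. A rough, self-contained equivalent of the guard is sketched below; the real is_tpu_installed() lives in composer.utils.device and may be implemented differently.

import importlib.util

def is_tpu_installed() -> bool:
    # Probe for torch_xla without importing it at module scope.
    return importlib.util.find_spec('torch_xla') is not None

if is_tpu_installed():
    import torch_xla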
@@ -534,7 +539,15 @@ def initialize_dist(device: Union[str, Device], timeout: float = 300.0):

dist_env_vars_match_defaults = all(os.environ.get(k, v) == v for (k, v) in dist_env_var_defaults.items())

if dist_env_vars_match_defaults:
if device_obj.dist_backend == 'xla':
    if 'torch_xla' not in sys.modules:
        raise RuntimeError('PyTorch XLA package not found. In order to use XLA-based devices, '
                           'PyTorch XLA must be installed.')
    if version.parse(torch_xla.__version__) < version.parse('2.1.0'):
        raise RuntimeError(f'PyTorch XLA version must be at least 2.1.0, found {torch_xla.__version__}.')
    # XLA initialization requires the init_method to be set
    dist.init_process_group(device_obj.dist_backend, init_method='xla://')
elif dist_env_vars_match_defaults:
    # Fill in the remaining single-rank variables
    os.environ.update(dist_env_var_defaults)
    dist.init_process_group(device_obj.dist_backend, store=dist.HashStore(), world_size=1, rank=0)
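Finally, a hedged usage sketch of what the new XLA branch amounts to, written directly against torch.distributed rather than through Composer. The explicit torch_xla.distributed.xla_backend import is an assumption here; it is what registers the 'xla' backend with torch.distributed in torch_xla 2.x. The 'xla://' init_method lets the XLA runtime supply rank and world size itself (e.g. when launched via torchrun or xmp.spawn).

import torch.distributed as dist
import torch_xla.distributed.xla_backend  # noqa: F401  registers the 'xla' backend (assumption, see note above)

def init_xla_process_group() -> None:
    # Mirrors dist.init_process_group(device_obj.dist_backend, init_method='xla://') from the diff above.
    dist.init_process_group('xla', init_method='xla://')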