Skip to content

Commit

Permalink
Add initial support for neuron devices (#3049)
Browse files Browse the repository at this point in the history
* add initial support for neuron devices

* add missing device_neuron.py

* remove unused imports

* update neuron flops

* formatting

* formatting

* formatting

* documentation

* formatting
  • Loading branch information
bfontain authored and Chuck Tang committed May 16, 2024
1 parent 33db788 commit 29c21bd
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 38 deletions.
55 changes: 38 additions & 17 deletions composer/callbacks/speed_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from composer.core import Callback, State
from composer.loggers import Logger
from composer.models.base import ComposerModel
from composer.utils import dist
from composer.utils import dist, is_xla_installed

if is_xla_installed():
import torch_xla.core.xla_model as xm

__all__ = ['SpeedMonitor']

Expand Down Expand Up @@ -83,31 +86,49 @@
'int8': 130e12,
'int4': 260e12,
},
# source: https://aws.amazon.com/blogs/machine-learning/aws-inferentia2-builds-on-aws-inferentia1-by-delivering-4x-higher-throughput-and-10x-lower-latency/
# Numbers are halved as the above flops is per chip and each chip appears as 2 devices.
'trn1': {
'fp32': 47.5e12 / 2,
'tf32': 47.5e12 / 2,
'fp16': 190e12 / 2,
'amp_fp16': 190e12 / 2,
'bf16': 190e12 / 2,
'amp_bf16': 190e12 / 2,
'int8': 380e12 / 2,
}
}


def get_gpu_flops_available(state: State):
gpu_flops_available = None

# Return 0 if no CUDA device (e.g., when running with CPU only)
if not torch.cuda.is_available():
if torch.cuda.is_available():
# torch.cuda.get_device_name() ex output: 'NVIDIA A100-SXM4-40GB'
device_name = torch.cuda.get_device_name().lower()
if 'h100' in device_name and 'hbm3' in device_name:
device_name = 'h100-sxm'
elif 'h100' in device_name and ('pcie' in device_name or 'hbm2e' in device_name):
device_name = 'h100-pcie'
elif 'a100' in device_name:
device_name = 'a100'
elif 'v100-sxm' in device_name:
device_name = 'v100-sxm'
elif 'v100-pcie' in device_name:
device_name = 'v100-pcie'
elif 't4' in device_name:
device_name = 't4'
elif is_xla_installed():
if xm.xla_device_hw(xm.xla_device()) == 'NEURON':
device_name = 'trn1'
else:
# For TPU return 0
return 0
else:
# When running on CPU, return 0 without warning
return 0

# torch.cuda.get_device_name() ex output: 'NVIDIA A100-SXM4-40GB'
device_name = torch.cuda.get_device_name().lower()
if 'h100' in device_name and 'hbm3' in device_name:
device_name = 'h100-sxm'
elif 'h100' in device_name and ('pcie' in device_name or 'hbm2e' in device_name):
device_name = 'h100-pcie'
elif 'a100' in device_name:
device_name = 'a100'
elif 'v100-sxm' in device_name:
device_name = 'v100-sxm'
elif 'v100-pcie' in device_name:
device_name = 'v100-pcie'
elif 't4' in device_name:
device_name = 't4'

if device_name in GPU_AVAILABLE_FLOPS and state.precision.value in GPU_AVAILABLE_FLOPS[device_name]:
gpu_flops_available = int(GPU_AVAILABLE_FLOPS[device_name][state.precision.value])
else:
Expand Down
15 changes: 11 additions & 4 deletions composer/core/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
"""Enum class for the numerical precision to be used by the model."""

import contextlib
import os
import textwrap
from typing import Any, Dict, Generator, Optional, Union

import torch

from composer.utils import StringEnum
from composer.utils import StringEnum, is_xla_installed

try:
import transformer_engine.pytorch as te
Expand Down Expand Up @@ -57,14 +56,22 @@ def get_precision_context(precision: Union[str, Precision],
yield
elif precision == Precision.AMP_FP16:
# Retain compatibility with PyTorch < 1.10
with torch.cuda.amp.autocast(True):
if torch.cuda.is_available():
with torch.cuda.amp.autocast(True):
yield
elif is_xla_installed():
with torch.autocast('xla', dtype=torch.float16):
yield
else:
yield
elif precision == Precision.AMP_BF16:
if torch.cuda.is_available():
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
yield
elif is_xla_installed():
with torch.autocast('xla', dtype=torch.bfloat16):
yield
else:
os.environ['XLA_USE_BF16'] = '1'
yield
elif precision == Precision.AMP_FP8:
if te_installed and torch.cuda.get_device_capability() >= (8, 9):
Expand Down
3 changes: 2 additions & 1 deletion composer/devices/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from composer.devices.device_gpu import DeviceGPU
from composer.devices.device_hpu import DeviceHPU
from composer.devices.device_mps import DeviceMPS
from composer.devices.device_neuron import DeviceNeuron
from composer.devices.device_tpu import DeviceTPU

__all__ = ['Device', 'DeviceCPU', 'DeviceGPU', 'DeviceMPS', 'DeviceTPU', 'DeviceHPU']
__all__ = ['Device', 'DeviceCPU', 'DeviceGPU', 'DeviceMPS', 'DeviceNeuron', 'DeviceTPU', 'DeviceHPU']
52 changes: 52 additions & 0 deletions composer/devices/device_neuron.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""The Neuron device used for training."""

from __future__ import annotations

import logging
import os
from typing import Any, Dict, TypeVar

import torch

from composer.devices.device import Device

logger = logging.getLogger(__name__)

__all__ = ['DeviceNeuron']

T_nnModule = TypeVar('T_nnModule', bound=torch.nn.Module)


class DeviceNeuron(Device):
"""An extension of :class:`~composer.devices.device.Device` for Neuron devices (Trn, Inf).
When running on Trn, we automatically set `export PJRT_DEVICE=NEURON`.
"""

name = 'neuron'
dist_backend = 'xla'

def __init__(self):
import torch_xla.core.xla_model as xm

# Turn off compiler based mixed precision (we use torch's amp support)
# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/appnotes/neuronx-cc/neuronx-cc-training-mixed-precision.html
os.environ['NEURON_CC_FLAGS'] = '--auto-cast=none'
os.environ['PJRT_DEVICE'] = 'NEURON'
self._device = xm.xla_device()

def module_to_device(self, module: T_nnModule) -> T_nnModule:
return module.to(self._device)

def tensor_to_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self._device)

def state_dict(self) -> Dict[str, Any]:
return {}

def load_state_dict(self, state: Dict[str, Any]) -> None:
if len(state) != 0:
raise ValueError('Neuron device has no state.')
13 changes: 5 additions & 8 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@
set_fsdp_default)
from composer.utils import (ExportFormat, MissingConditionalImportError, ObjectStore, Transform, checkpoint, dist,
ensure_tuple, export_with_logger, extract_hparams, format_name_with_dist,
get_composer_env_dict, get_device, get_file, is_tpu_installed, map_collection,
get_composer_env_dict, get_device, get_file, is_xla_installed, map_collection,
maybe_create_object_store_from_uri, maybe_create_remote_uploader_downloader_from_uri,
model_eval_mode, parse_uri, partial_format, reproducibility)
from composer.utils.misc import is_model_deepspeed
from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY

if is_tpu_installed():
if is_xla_installed():
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

Expand Down Expand Up @@ -1320,7 +1320,7 @@ def __init__(
if self._train_data_spec is not None:
self.state.set_dataloader(self._train_data_spec.dataloader, train_dataloader_label,
train_subset_num_batches)
if isinstance(self.state.device, DeviceTPU):
if self.state.device.dist_backend == 'xla':
self.state.train_dataloader = pl.MpDeviceLoader(self.state.dataloader, xm.xla_device())
else:
self.state.train_dataloader = self.state.dataloader
Expand Down Expand Up @@ -2342,10 +2342,7 @@ def _train_batch(self, use_grad_scaling: bool) -> Dict[str, torch.Tensor]:
if use_grad_scaling:
self.state.scaler.step(optimizer)
else:
if isinstance(self.state.device, DeviceTPU):
xm.optimizer_step(optimizer, barrier=True)
else:
optimizer.step()
optimizer.step()
except RuntimeError as e:
if self.state.auto_microbatching and _is_cuda_oom(e):
log.debug((f"Rank {dist.get_global_rank()} OOM'd."))
Expand Down Expand Up @@ -3162,7 +3159,7 @@ def _use_closures(self) -> bool:
if self.state.deepspeed_enabled:
return False

if isinstance(self.state.device, DeviceTPU):
if self.state.device.dist_backend == 'xla':
return False

if self.state.precision != Precision.AMP_FP16:
Expand Down
4 changes: 2 additions & 2 deletions composer/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
save_checkpoint)
from composer.utils.collect_env import (configure_excepthook, disable_env_report, enable_env_report,
get_composer_env_dict, print_env)
from composer.utils.device import get_device, is_hpu_installed, is_tpu_installed
from composer.utils.device import get_device, is_hpu_installed, is_xla_installed
from composer.utils.eval_client import EvalClient, LambdaEvalClient, LocalEvalClient, MosaicMLLambdaEvalClient
from composer.utils.file_helpers import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, FORMAT_NAME_WITH_DIST_TABLE,
create_symlink_file, ensure_folder_has_no_conflicting_files,
Expand Down Expand Up @@ -78,7 +78,7 @@
'retry',
'model_eval_mode',
'get_device',
'is_tpu_installed',
'is_xla_installed',
'is_hpu_installed',
'ExportFormat',
'Transform',
Expand Down
21 changes: 17 additions & 4 deletions composer/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
if TYPE_CHECKING:
from composer.devices import Device

__all__ = ['get_device', 'is_tpu_installed', 'is_hpu_installed']
__all__ = ['get_device', 'is_hpu_installed', 'is_xla_installed']

_is_xla_installed = None


def get_device(device: Optional[Union[str, 'Device']]) -> 'Device':
Expand All @@ -25,7 +27,7 @@ def get_device(device: Optional[Union[str, 'Device']]) -> 'Device':
Device. If no argument is passed, returns :class:`.DeviceGPU` if available,
or :class:`.DeviceCPU` if no GPU is available.
"""
from composer.devices import DeviceCPU, DeviceGPU, DeviceHPU, DeviceMPS, DeviceTPU
from composer.devices import DeviceCPU, DeviceGPU, DeviceHPU, DeviceMPS, DeviceNeuron, DeviceTPU

if not device:
device = DeviceGPU() if torch.cuda.is_available() else DeviceCPU()
Expand All @@ -37,11 +39,17 @@ def get_device(device: Optional[Union[str, 'Device']]) -> 'Device':
elif device.lower() == 'mps':
device = DeviceMPS()
elif device.lower() == 'tpu':
if not is_tpu_installed():
if not is_xla_installed():
raise ImportError(
'Unable to import torch_xla. Please follow installation instructions at https://github.com/pytorch/xla'
)
device = DeviceTPU()
elif device.lower() == 'neuron':
if not is_xla_installed():
raise ImportError(
'Unable to import torch_xla. Please follow installation instructions at https://github.com/pytorch/xla'
)
device = DeviceNeuron()
elif device.lower() == 'hpu':
if not is_hpu_installed():
raise ImportError('Unable to import habana-torch-plugin.')
Expand All @@ -51,17 +59,22 @@ def get_device(device: Optional[Union[str, 'Device']]) -> 'Device':
return device


def is_tpu_installed() -> bool:
def is_xla_installed() -> bool:
"""Determines whether the module needed for training on TPUs—torch_xla—is installed.
Returns:
bool: Whether torch_xla is installed.
"""
global _is_xla_installed
if _is_xla_installed:
return _is_xla_installed
try:
import torch_xla
del torch_xla
_is_xla_installed = True
return True
except ModuleNotFoundError:
_is_xla_installed = False
return False


Expand Down
4 changes: 2 additions & 2 deletions composer/utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@
import torch.utils.data
from packaging import version

from composer.utils.device import get_device, is_hpu_installed, is_tpu_installed
from composer.utils.device import get_device, is_hpu_installed, is_xla_installed

if is_tpu_installed():
if is_xla_installed():
import torch_xla

if TYPE_CHECKING:
Expand Down

0 comments on commit 29c21bd

Please sign in to comment.