From 29c21bd4976c04299dfd223640d52025bd656ee0 Mon Sep 17 00:00:00 2001
From: Bruce Fontaine
Date: Thu, 29 Feb 2024 10:53:43 -0800
Subject: [PATCH] Add initial support for neuron devices (#3049)

* add initial support for neuron devices
* add missing device_neuron.py
* remove unused imports
* update neuron flops
* formatting
* formatting
* formatting
* documentation
* formatting
---
 composer/callbacks/speed_monitor.py | 55 ++++++++++++++++++++---------
 composer/core/precision.py          | 15 +++++---
 composer/devices/__init__.py        |  3 +-
 composer/devices/device_neuron.py   | 52 +++++++++++++++++++++++++++
 composer/trainer/trainer.py         | 13 +++----
 composer/utils/__init__.py          |  4 +--
 composer/utils/device.py            | 21 ++++++++---
 composer/utils/dist.py              |  4 +--
 8 files changed, 129 insertions(+), 38 deletions(-)
 create mode 100644 composer/devices/device_neuron.py

diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py
index e574b8e713..98478225ba 100644
--- a/composer/callbacks/speed_monitor.py
+++ b/composer/callbacks/speed_monitor.py
@@ -13,7 +13,10 @@
 from composer.core import Callback, State
 from composer.loggers import Logger
 from composer.models.base import ComposerModel
-from composer.utils import dist
+from composer.utils import dist, is_xla_installed
+
+if is_xla_installed():
+    import torch_xla.core.xla_model as xm

 __all__ = ['SpeedMonitor']

@@ -83,6 +86,17 @@
         'int8': 130e12,
         'int4': 260e12,
     },
+    # source: https://aws.amazon.com/blogs/machine-learning/aws-inferentia2-builds-on-aws-inferentia1-by-delivering-4x-higher-throughput-and-10x-lower-latency/
+    # Numbers are halved as the above flops is per chip and each chip appears as 2 devices.
+    'trn1': {
+        'fp32': 47.5e12 / 2,
+        'tf32': 47.5e12 / 2,
+        'fp16': 190e12 / 2,
+        'amp_fp16': 190e12 / 2,
+        'bf16': 190e12 / 2,
+        'amp_bf16': 190e12 / 2,
+        'int8': 380e12 / 2,
+    }
 }


@@ -90,24 +104,31 @@ def get_gpu_flops_available(state: State):
     gpu_flops_available = None

     # Return 0 if no CUDA device (e.g., when running with CPU only)
-    if not torch.cuda.is_available():
+    if torch.cuda.is_available():
+        # torch.cuda.get_device_name() ex output: 'NVIDIA A100-SXM4-40GB'
+        device_name = torch.cuda.get_device_name().lower()
+        if 'h100' in device_name and 'hbm3' in device_name:
+            device_name = 'h100-sxm'
+        elif 'h100' in device_name and ('pcie' in device_name or 'hbm2e' in device_name):
+            device_name = 'h100-pcie'
+        elif 'a100' in device_name:
+            device_name = 'a100'
+        elif 'v100-sxm' in device_name:
+            device_name = 'v100-sxm'
+        elif 'v100-pcie' in device_name:
+            device_name = 'v100-pcie'
+        elif 't4' in device_name:
+            device_name = 't4'
+    elif is_xla_installed():
+        if xm.xla_device_hw(xm.xla_device()) == 'NEURON':
+            device_name = 'trn1'
+        else:
+            # For TPU return 0
+            return 0
+    else:
+        # When running on CPU, return 0 without warning
         return 0

-    # torch.cuda.get_device_name() ex output: 'NVIDIA A100-SXM4-40GB'
-    device_name = torch.cuda.get_device_name().lower()
-    if 'h100' in device_name and 'hbm3' in device_name:
-        device_name = 'h100-sxm'
-    elif 'h100' in device_name and ('pcie' in device_name or 'hbm2e' in device_name):
-        device_name = 'h100-pcie'
-    elif 'a100' in device_name:
-        device_name = 'a100'
-    elif 'v100-sxm' in device_name:
-        device_name = 'v100-sxm'
-    elif 'v100-pcie' in device_name:
-        device_name = 'v100-pcie'
-    elif 't4' in device_name:
-        device_name = 't4'
-
     if device_name in GPU_AVAILABLE_FLOPS and state.precision.value in GPU_AVAILABLE_FLOPS[device_name]:
         gpu_flops_available = int(GPU_AVAILABLE_FLOPS[device_name][state.precision.value])
     else:
diff --git a/composer/core/precision.py b/composer/core/precision.py
index 4e641a07e9..e3ea05d67d 100644
--- a/composer/core/precision.py
+++ b/composer/core/precision.py
@@ -4,13 +4,12 @@
 """Enum class for the numerical precision to be used by the model."""

 import contextlib
-import os
 import textwrap
 from typing import Any, Dict, Generator, Optional, Union

 import torch

-from composer.utils import StringEnum
+from composer.utils import StringEnum, is_xla_installed

 try:
     import transformer_engine.pytorch as te
@@ -57,14 +56,22 @@ def get_precision_context(precision: Union[str, Precision],
             yield
     elif precision == Precision.AMP_FP16:
         # Retain compatibility with PyTorch < 1.10
-        with torch.cuda.amp.autocast(True):
+        if torch.cuda.is_available():
+            with torch.cuda.amp.autocast(True):
+                yield
+        elif is_xla_installed():
+            with torch.autocast('xla', dtype=torch.float16):
+                yield
+        else:
             yield
     elif precision == Precision.AMP_BF16:
         if torch.cuda.is_available():
             with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
                 yield
+        elif is_xla_installed():
+            with torch.autocast('xla', dtype=torch.bfloat16):
+                yield
         else:
-            os.environ['XLA_USE_BF16'] = '1'
             yield
     elif precision == Precision.AMP_FP8:
         if te_installed and torch.cuda.get_device_capability() >= (8, 9):
diff --git a/composer/devices/__init__.py b/composer/devices/__init__.py
index d3cba9d37a..8ed07936b3 100644
--- a/composer/devices/__init__.py
+++ b/composer/devices/__init__.py
@@ -8,6 +8,7 @@
 from composer.devices.device_gpu import DeviceGPU
 from composer.devices.device_hpu import DeviceHPU
 from composer.devices.device_mps import DeviceMPS
+from composer.devices.device_neuron import DeviceNeuron
 from composer.devices.device_tpu import DeviceTPU

-__all__ = ['Device', 'DeviceCPU', 'DeviceGPU', 'DeviceMPS', 'DeviceTPU', 'DeviceHPU']
+__all__ = ['Device', 'DeviceCPU', 'DeviceGPU', 'DeviceMPS', 'DeviceNeuron', 'DeviceTPU', 'DeviceHPU']
diff --git a/composer/devices/device_neuron.py b/composer/devices/device_neuron.py
new file mode 100644
index 0000000000..840d5985ef
--- /dev/null
+++ b/composer/devices/device_neuron.py
@@ -0,0 +1,52 @@
+# Copyright 2024 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""The Neuron device used for training."""
+
+from __future__ import annotations
+
+import logging
+import os
+from typing import Any, Dict, TypeVar
+
+import torch
+
+from composer.devices.device import Device
+
+logger = logging.getLogger(__name__)
+
+__all__ = ['DeviceNeuron']
+
+T_nnModule = TypeVar('T_nnModule', bound=torch.nn.Module)
+
+
+class DeviceNeuron(Device):
+    """An extension of :class:`~composer.devices.device.Device` for Neuron devices (Trn, Inf).
+
+    When running on Trn, we automatically set `export PJRT_DEVICE=NEURON`.
+    """
+
+    name = 'neuron'
+    dist_backend = 'xla'
+
+    def __init__(self):
+        import torch_xla.core.xla_model as xm
+
+        # Turn off compiler based mixed precision (we use torch's amp support)
+        # https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/appnotes/neuronx-cc/neuronx-cc-training-mixed-precision.html
+        os.environ['NEURON_CC_FLAGS'] = '--auto-cast=none'
+        os.environ['PJRT_DEVICE'] = 'NEURON'
+        self._device = xm.xla_device()
+
+    def module_to_device(self, module: T_nnModule) -> T_nnModule:
+        return module.to(self._device)
+
+    def tensor_to_device(self, tensor: torch.Tensor) -> torch.Tensor:
+        return tensor.to(self._device)
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {}
+
+    def load_state_dict(self, state: Dict[str, Any]) -> None:
+        if len(state) != 0:
+            raise ValueError('Neuron device has no state.')
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 0d2349bf93..4932c5dea6 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -56,13 +56,13 @@
                                         set_fsdp_default)
 from composer.utils import (ExportFormat, MissingConditionalImportError, ObjectStore, Transform, checkpoint, dist,
                             ensure_tuple, export_with_logger, extract_hparams, format_name_with_dist,
-                            get_composer_env_dict, get_device, get_file, is_tpu_installed, map_collection,
+                            get_composer_env_dict, get_device, get_file, is_xla_installed, map_collection,
                             maybe_create_object_store_from_uri, maybe_create_remote_uploader_downloader_from_uri,
                             model_eval_mode, parse_uri, partial_format, reproducibility)
 from composer.utils.misc import is_model_deepspeed
 from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY

-if is_tpu_installed():
+if is_xla_installed():
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.parallel_loader as pl

@@ -1320,7 +1320,7 @@ def __init__(
         if self._train_data_spec is not None:
             self.state.set_dataloader(self._train_data_spec.dataloader, train_dataloader_label,
                                       train_subset_num_batches)
-            if isinstance(self.state.device, DeviceTPU):
+            if self.state.device.dist_backend == 'xla':
                 self.state.train_dataloader = pl.MpDeviceLoader(self.state.dataloader, xm.xla_device())
             else:
                 self.state.train_dataloader = self.state.dataloader
@@ -2342,10 +2342,7 @@ def _train_batch(self, use_grad_scaling: bool) -> Dict[str, torch.Tensor]:
                     if use_grad_scaling:
                         self.state.scaler.step(optimizer)
                     else:
-                        if isinstance(self.state.device, DeviceTPU):
-                            xm.optimizer_step(optimizer, barrier=True)
-                        else:
-                            optimizer.step()
+                        optimizer.step()
             except RuntimeError as e:
                 if self.state.auto_microbatching and _is_cuda_oom(e):
                     log.debug((f"Rank {dist.get_global_rank()} OOM'd."))
@@ -3162,7 +3159,7 @@ def _use_closures(self) -> bool:
         if self.state.deepspeed_enabled:
             return False

-        if isinstance(self.state.device, DeviceTPU):
+        if self.state.device.dist_backend == 'xla':
             return False

         if self.state.precision != Precision.AMP_FP16:
diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py
index 67ed33cdd3..3c9a69eff6 100644
--- a/composer/utils/__init__.py
+++ b/composer/utils/__init__.py
@@ -10,7 +10,7 @@
                                        save_checkpoint)
 from composer.utils.collect_env import (configure_excepthook, disable_env_report, enable_env_report,
                                         get_composer_env_dict, print_env)
-from composer.utils.device import get_device, is_hpu_installed, is_tpu_installed
+from composer.utils.device import get_device, is_hpu_installed, is_xla_installed
 from composer.utils.eval_client import EvalClient, LambdaEvalClient, LocalEvalClient, MosaicMLLambdaEvalClient
 from composer.utils.file_helpers import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, FORMAT_NAME_WITH_DIST_TABLE,
                                          create_symlink_file, ensure_folder_has_no_conflicting_files,
@@ -78,7 +78,7 @@
     'retry',
     'model_eval_mode',
     'get_device',
-    'is_tpu_installed',
+    'is_xla_installed',
     'is_hpu_installed',
     'ExportFormat',
     'Transform',
diff --git a/composer/utils/device.py b/composer/utils/device.py
index 6ad303f6c9..73b4e57ab4 100644
--- a/composer/utils/device.py
+++ b/composer/utils/device.py
@@ -10,7 +10,9 @@
 if TYPE_CHECKING:
     from composer.devices import Device

-__all__ = ['get_device', 'is_tpu_installed', 'is_hpu_installed']
+__all__ = ['get_device', 'is_hpu_installed', 'is_xla_installed']
+
+_is_xla_installed = None


 def get_device(device: Optional[Union[str, 'Device']]) -> 'Device':
@@ -25,7 +27,7 @@ def get_device(device: Optional[Union[str, 'Device']]) -> 'Device':
             Device. If no argument is passed, returns :class:`.DeviceGPU` if available, or
             :class:`.DeviceCPU` if no GPU is available.
     """
-    from composer.devices import DeviceCPU, DeviceGPU, DeviceHPU, DeviceMPS, DeviceTPU
+    from composer.devices import DeviceCPU, DeviceGPU, DeviceHPU, DeviceMPS, DeviceNeuron, DeviceTPU

     if not device:
         device = DeviceGPU() if torch.cuda.is_available() else DeviceCPU()
@@ -37,11 +39,17 @@ def get_device(device: Optional[Union[str, 'Device']]) -> 'Device':
     elif device.lower() == 'mps':
         device = DeviceMPS()
     elif device.lower() == 'tpu':
-        if not is_tpu_installed():
+        if not is_xla_installed():
             raise ImportError(
                 'Unable to import torch_xla. Please follow installation instructions at https://github.com/pytorch/xla'
             )
         device = DeviceTPU()
+    elif device.lower() == 'neuron':
+        if not is_xla_installed():
+            raise ImportError(
+                'Unable to import torch_xla. Please follow installation instructions at https://github.com/pytorch/xla'
+            )
+        device = DeviceNeuron()
     elif device.lower() == 'hpu':
         if not is_hpu_installed():
             raise ImportError('Unable to import habana-torch-plugin.')
@@ -51,17 +59,22 @@ def get_device(device: Optional[Union[str, 'Device']]) -> 'Device':

     return device


-def is_tpu_installed() -> bool:
+def is_xla_installed() -> bool:
     """Determines whether the module needed for training on TPUs—torch_xla—is installed.

     Returns:
         bool: Whether torch_xla is installed.
     """
+    global _is_xla_installed
+    if _is_xla_installed:
+        return _is_xla_installed
     try:
         import torch_xla
         del torch_xla
+        _is_xla_installed = True
         return True
     except ModuleNotFoundError:
+        _is_xla_installed = False
         return False

diff --git a/composer/utils/dist.py b/composer/utils/dist.py
index 5b8dd5df68..e732d4532a 100644
--- a/composer/utils/dist.py
+++ b/composer/utils/dist.py
@@ -47,9 +47,9 @@
 import torch.utils.data
 from packaging import version

-from composer.utils.device import get_device, is_hpu_installed, is_tpu_installed
+from composer.utils.device import get_device, is_hpu_installed, is_xla_installed

-if is_tpu_installed():
+if is_xla_installed():
     import torch_xla

 if TYPE_CHECKING:
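
Usage note (not part of the patch): a minimal sketch of how the new device string might be selected after this change, assuming the Trainer resolves device names through composer.utils.get_device as shown in the device.py hunk above; `model` and `train_dataloader` are placeholders, not objects defined in this patch.

    from composer import Trainer

    # Hypothetical example: 'neuron' is expected to resolve to DeviceNeuron via
    # get_device(), and 'amp_bf16' to route through torch.autocast('xla',
    # dtype=torch.bfloat16) per the precision.py hunk above.
    trainer = Trainer(
        model=model,                        # placeholder ComposerModel
        train_dataloader=train_dataloader,  # placeholder dataloader
        max_duration='1ep',
        device='neuron',
        precision='amp_bf16',
    )
    trainer.fit()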