Add initial support for neuron devices #3049

Merged · 10 commits · Feb 29, 2024
Changes from 2 commits
52 changes: 36 additions & 16 deletions composer/callbacks/speed_monitor.py
@@ -13,7 +13,10 @@
from composer.core import Callback, State
from composer.loggers import Logger
from composer.models.base import ComposerModel
from composer.utils import dist
from composer.utils import dist, is_xla_installed

if is_xla_installed():
import torch_xla.core.xla_model as xm

__all__ = ['SpeedMonitor']

@@ -83,30 +86,47 @@
'int8': 130e12,
'int4': 260e12,
},
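    # Assumption (not stated in this PR): Trainium advertises ~190 TFLOPS of
    # FP16/BF16 per chip, and each chip exposes two NeuronCores, so the
    # per-core peaks below are divided by 2.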
'trn1': {
'fp32': 190e12 / 2,
'tf32': 190e12 / 2,
'fp16': 190e12 / 2,
'amp_fp16': 190e12 / 2,
'bf16': 190e12 / 2,
'amp_bf16': 190e12 / 2,
'int8': 380e12 / 2,
}
}


def get_gpu_flops_available(state: State):
gpu_flops_available = None

if torch.cuda.is_available():
# torch.cuda.get_device_name() ex output: 'NVIDIA A100-SXM4-40GB'
device_name = torch.cuda.get_device_name().lower()
if 'h100' in device_name and 'hbm3' in device_name:
device_name = 'h100-sxm'
elif 'h100' in device_name and ('pcie' in device_name or 'hbm2e' in device_name):
device_name = 'h100-pcie'
elif 'a100' in device_name:
device_name = 'a100'
elif 'v100-sxm' in device_name:
device_name = 'v100-sxm'
elif 'v100-pcie' in device_name:
device_name = 'v100-pcie'
elif 't4' in device_name:
device_name = 't4'
elif is_xla_installed():
if xm.xla_device_hw(xm.xla_device()) == 'NEURON':
device_name = 'trn1'
else:
# For TPU return 0
return 0
else:
# When running on CPU, return 0 without warning
return 0

if device_name in GPU_AVAILABLE_FLOPS and state.precision.value in GPU_AVAILABLE_FLOPS[device_name]:
gpu_flops_available = int(GPU_AVAILABLE_FLOPS[device_name][state.precision.value])
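For context, SpeedMonitor compares measured throughput against these per-device peak FLOPS to report model FLOPS utilization (MFU). A minimal sketch of that arithmetic (all names here are illustrative, not the callback's actual variables):

def estimate_mfu(flops_per_batch: float, batches_per_sec: float, world_size: int,
                 flops_available_per_device: float) -> float:
    """Rough MFU: achieved FLOPS divided by the fleet's aggregate peak FLOPS."""
    achieved_flops_per_sec = flops_per_batch * batches_per_sec
    return achieved_flops_per_sec / (world_size * flops_available_per_device)

# e.g. 32 NeuronCores at the per-core bf16 peak from the table above (~46% MFU):
print(estimate_mfu(flops_per_batch=2.8e15, batches_per_sec=0.5,
                   world_size=32, flops_available_per_device=190e12 / 2))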
15 changes: 11 additions & 4 deletions composer/core/precision.py
@@ -10,7 +10,7 @@

import torch

from composer.utils import StringEnum, is_xla_installed

try:
import transformer_engine.pytorch as te
@@ -36,7 +36,6 @@ class Precision(StringEnum):
AMP_BF16 = 'amp_bf16'
AMP_FP8 = 'amp_fp8'


@contextlib.contextmanager
def get_precision_context(precision: Union[str, Precision],
precision_config: Optional[Dict[str, Any]] = None) -> Generator[None, None, None]:
@@ -57,14 +56,22 @@ def get_precision_context(precision: Union[str, Precision],
yield
elif precision == Precision.AMP_FP16:
# Retain compatibility with PyTorch < 1.10
if torch.cuda.is_available():
with torch.cuda.amp.autocast(True):
yield
elif is_xla_installed():
with torch.autocast('xla', dtype=torch.float16):
yield
else:
yield
elif precision == Precision.AMP_BF16:
if torch.cuda.is_available():
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
yield
elif is_xla_installed():
with torch.autocast('xla', dtype=torch.bfloat16):
yield
else:
os.environ['XLA_USE_BF16'] = '1'
yield
elif precision == Precision.AMP_FP8:
if te_installed and torch.cuda.get_device_capability() >= (8, 9):
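For illustration, a hypothetical caller of the new XLA branches (assumes torch_xla is installed; on Trainium the tensor would first be moved to xm.xla_device()):

import torch
from composer.core import Precision, get_precision_context

x = torch.randn(8, 8)
with get_precision_context(Precision.AMP_BF16):
    y = x @ x  # on an XLA host this runs under torch.autocast('xla', dtype=torch.bfloat16)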
3 changes: 2 additions & 1 deletion composer/devices/__init__.py
@@ -9,5 +9,6 @@
from composer.devices.device_hpu import DeviceHPU
from composer.devices.device_mps import DeviceMPS
from composer.devices.device_tpu import DeviceTPU
from composer.devices.device_neuron import DeviceNeuron

__all__ = ['Device', 'DeviceCPU', 'DeviceGPU', 'DeviceMPS', 'DeviceNeuron', 'DeviceTPU', 'DeviceHPU']
48 changes: 48 additions & 0 deletions composer/devices/device_neuron.py
@@ -0,0 +1,48 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""The Neuron device used for training."""

from __future__ import annotations

import logging
import os
from typing import Any, Dict, TypeVar

import torch

from composer.devices.device import Device

logger = logging.getLogger(__name__)

__all__ = ['DeviceNeuron']

T_nnModule = TypeVar('T_nnModule', bound=torch.nn.Module)


class DeviceNeuron(Device):
"""An extension of :class:`~composer.devices.device.Device` for Neuron devices (Trn, Inf).
When running on Trn, we automatically set `export PJRT_DEVICE=NEURON`.
More details.
"""

name = 'neuron'
dist_backend = 'xla'

    def __init__(self):
        import torch_xla.core.xla_model as xm

        # Disable the Neuron compiler's automatic down-casting so Composer's
        # precision context stays in control of dtypes.
        os.environ['NEURON_CC_FLAGS'] = '--auto-cast=none'
        # Point the PJRT runtime at Neuron hardware.
        os.environ['PJRT_DEVICE'] = 'NEURON'
        self._device = xm.xla_device()

def module_to_device(self, module: T_nnModule) -> T_nnModule:
return module.to(self._device)

def tensor_to_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self._device)

def state_dict(self) -> Dict[str, Any]:
return {}

def load_state_dict(self, state: Dict[str, Any]) -> None:
if len(state) != 0:
raise ValueError('Neuron device has no state.')
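A minimal end-to-end sketch of the new device (assumes a trn1 instance with torch_xla and the AWS Neuron SDK installed; model and train_dataloader are placeholders, not part of this PR):

from composer import Trainer

trainer = Trainer(
    model=model,                        # a ComposerModel (placeholder)
    train_dataloader=train_dataloader,  # any torch DataLoader (placeholder)
    max_duration='1ep',
    device='neuron',                    # resolved to DeviceNeuron by composer.utils.get_device
    precision='amp_bf16',               # exercises the XLA autocast branch above
)
trainer.fit()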
15 changes: 6 additions & 9 deletions composer/trainer/trainer.py
@@ -42,7 +42,7 @@
from composer.core import (Algorithm, AlgorithmPass, Batch, Callback, DataSpec, Engine, Evaluator, Event, Precision,
State, Time, Timestamp, TimeUnit, TrainerMode, ensure_data_spec, ensure_evaluator,
ensure_time, get_precision_context, validate_eval_automicrobatching)
from composer.devices import Device, DeviceCPU, DeviceGPU, DeviceMPS, DeviceNeuron, DeviceTPU
from composer.loggers import (ConsoleLogger, Logger, LoggerDestination, MLFlowLogger, MosaicMLLogger, ProgressBarLogger,
RemoteUploaderDownloader, WandBLogger)
from composer.loggers.mosaicml_logger import MOSAICML_ACCESS_TOKEN_ENV_VAR, MOSAICML_PLATFORM_ENV_VAR
@@ -56,13 +56,13 @@
set_fsdp_default)
from composer.utils import (ExportFormat, MissingConditionalImportError, ObjectStore, Transform, checkpoint, dist,
ensure_tuple, export_with_logger, extract_hparams, format_name_with_dist,
get_composer_env_dict, get_device, get_file, is_xla_installed, map_collection,
maybe_create_object_store_from_uri, maybe_create_remote_uploader_downloader_from_uri,
model_eval_mode, parse_uri, partial_format, reproducibility)
from composer.utils.misc import is_model_deepspeed
from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY

if is_xla_installed():
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

@@ -1320,7 +1320,7 @@ def __init__(
if self._train_data_spec is not None:
self.state.set_dataloader(self._train_data_spec.dataloader, train_dataloader_label,
train_subset_num_batches)
if self.state.device.dist_backend == 'xla':
self.state.train_dataloader = pl.MpDeviceLoader(self.state.dataloader, xm.xla_device())
else:
self.state.train_dataloader = self.state.dataloader
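Keying on dist_backend == 'xla' rather than isinstance(..., DeviceTPU) lets both TPU and Neuron take this path. Roughly what the wrapping does, as a standalone sketch (train_dataloader is a placeholder):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()                              # TPU or NEURON, whichever is present
loader = pl.MpDeviceLoader(train_dataloader, device)  # moves each batch to the device
for batch in loader:
    ...                                               # batch tensors already live on `device`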
@@ -2342,10 +2342,7 @@ def _train_batch(self, use_grad_scaling: bool) -> Dict[str, torch.Tensor]:
if use_grad_scaling:
self.state.scaler.step(optimizer)
else:
optimizer.step()
except RuntimeError as e:
if self.state.auto_microbatching and _is_cuda_oom(e):
log.debug((f"Rank {dist.get_global_rank()} OOM'd."))
@@ -3162,7 +3159,7 @@ def _use_closures(self) -> bool:
if self.state.deepspeed_enabled:
return False

if self.state.device.dist_backend == 'xla':
return False

if self.state.precision != Precision.AMP_FP16:
4 changes: 2 additions & 2 deletions composer/utils/__init__.py
@@ -10,7 +10,7 @@
save_checkpoint)
from composer.utils.collect_env import (configure_excepthook, disable_env_report, enable_env_report,
get_composer_env_dict, print_env)
from composer.utils.device import get_device, is_hpu_installed, is_xla_installed
from composer.utils.eval_client import EvalClient, LambdaEvalClient, LocalEvalClient, MosaicMLLambdaEvalClient
from composer.utils.file_helpers import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, FORMAT_NAME_WITH_DIST_TABLE,
create_symlink_file, ensure_folder_has_no_conflicting_files,
@@ -78,7 +78,7 @@
'retry',
'model_eval_mode',
'get_device',
'is_xla_installed',
'is_hpu_installed',
'ExportFormat',
'Transform',
20 changes: 16 additions & 4 deletions composer/utils/device.py
@@ -10,8 +10,9 @@
if TYPE_CHECKING:
from composer.devices import Device

__all__ = ['get_device', 'is_xla_installed', 'is_hpu_installed']

_is_xla_installed = None

def get_device(device: Optional[Union[str, 'Device']]) -> 'Device':
"""Takes string or Device and returns the corresponding :class:`~composer.devices.Device`.
@@ -25,7 +26,7 @@ def get_device(device: Optional[Union[str, 'Device']]) -> 'Device':
Device. If no argument is passed, returns :class:`.DeviceGPU` if available,
or :class:`.DeviceCPU` if no GPU is available.
"""
from composer.devices import DeviceCPU, DeviceGPU, DeviceHPU, DeviceMPS, DeviceTPU, DeviceNeuron

if not device:
device = DeviceGPU() if torch.cuda.is_available() else DeviceCPU()
@@ -37,11 +38,17 @@ def get_device(device: Optional[Union[str, 'Device']]) -> 'Device':
elif device.lower() == 'mps':
device = DeviceMPS()
elif device.lower() == 'tpu':
if not is_xla_installed():
raise ImportError(
'Unable to import torch_xla. Please follow installation instructions at https://github.com/pytorch/xla'
)
device = DeviceTPU()
elif device.lower() == 'neuron':
if not is_xla_installed():
raise ImportError(
'Unable to import torch_xla. Please follow installation instructions at https://github.com/pytorch/xla'
)
device = DeviceNeuron()
elif device.lower() == 'hpu':
if not is_hpu_installed():
raise ImportError('Unable to import habana-torch-plugin.')
@@ -51,17 +58,22 @@ def get_device(device: Optional[Union[str, 'Device']]) -> 'Device':
return device


def is_xla_installed() -> bool:
"""Determines whether the module needed for training on TPUs—torch_xla—is installed.

Returns:
bool: Whether torch_xla is installed.
"""
global _is_xla_installed
    if _is_xla_installed is not None:  # cache both the installed and the missing result
        return _is_xla_installed
try:
import torch_xla
del torch_xla
_is_xla_installed = True
return True
except ModuleNotFoundError:
_is_xla_installed = False
return False


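For completeness, a hypothetical smoke test of the new device string (raises ImportError when torch_xla is absent):

from composer.utils import get_device

device = get_device('neuron')            # ImportError if torch_xla is not installed
print(device.name, device.dist_backend)  # -> neuron xla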
4 changes: 2 additions & 2 deletions composer/utils/dist.py
@@ -47,9 +47,9 @@
import torch.utils.data
from packaging import version

from composer.utils.device import get_device, is_hpu_installed, is_xla_installed

if is_xla_installed():
import torch_xla

if TYPE_CHECKING: