Skip to content

Commit

Permalink
add device check when import ipex
Browse files Browse the repository at this point in the history
  • Loading branch information
ys950902 committed Aug 9, 2024
1 parent 61350c5 commit df9a440
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion megatron/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import torch
import intel_extension_for_pytorch
from deepspeed.accelerator.real_accelerator import get_accelerator

if get_accelerator().device_name() == 'xpu':
import intel_extension_for_pytorch
if get_accelerator().device_name() == 'cuda':
from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
from apex.normalization import MixedFusedRMSNorm as RMSNorm
Expand Down

0 comments on commit df9a440

Please sign in to comment.