Skip to content

Commit

Permalink
acquire device when required (#420)
Browse files Browse the repository at this point in the history
* Update module.py

* Update preprocess_data.py

* add copyrights

* add copyrights

* Update tokenizer.py

* add copyrights
  • Loading branch information
polisettyvarma authored Jul 19, 2024
1 parent 13f2673 commit 3af2e25
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
16 changes: 12 additions & 4 deletions megatron/model/module.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron Module"""
Expand All @@ -10,10 +11,9 @@
from megatron.core import mpu, tensor_parallel


_FLOAT_TYPES = [get_accelerator().FloatTensor(0).dtype]
_HALF_TYPES = [get_accelerator().HalfTensor(0).dtype]
_BF16_TYPES = [get_accelerator().BFloat16Tensor(0).dtype]

_FLOAT_TYPES = None
_HALF_TYPES = None
_BF16_TYPES = None


def param_is_not_shared(param):
Expand Down Expand Up @@ -131,6 +131,9 @@ def conversion_helper(val, conversion):

def fp32_to_float16(val, float16_convertor):
"""Convert fp32 `val` to fp16/bf16"""
global _FLOAT_TYPES
if _FLOAT_TYPES is None:
_FLOAT_TYPES = [get_accelerator().FloatTensor(0).dtype]
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
Expand All @@ -143,6 +146,11 @@ def half_conversion(val):

def float16_to_fp32(val):
"""Convert fp16/bf16 `val` to fp32"""
global _HALF_TYPES, _BF16_TYPES
if _HALF_TYPES is None:
_HALF_TYPES = [get_accelerator().HalfTensor(0).dtype]
if _BF16_TYPES is None:
_BF16_TYPES = [get_accelerator().BFloat16Tensor(0).dtype]
def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
Expand Down
2 changes: 1 addition & 1 deletion tools/preprocess_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def get_args():
print("Are you sure you don't want to split sentences?")

# some default/dummy values for the tokenizer
args.rank = 1
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.tensor_model_parallel_size = 1
args.vocab_extra_ids = 0
Expand Down

0 comments on commit 3af2e25

Please sign in to comment.