
Commit 3af2e25

acquire device when required (#420)
* Update module.py
* Update preprocess_data.py
* add copyrights
* add copyrights
* Update tokenizer.py
* add copyrights
1 parent 13f2673 commit 3af2e25


2 files changed: +13 -5 lines


megatron/model/module.py

+12 -4
@@ -1,3 +1,4 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
 """Megatron Module"""
@@ -10,10 +11,9 @@
 from megatron.core import mpu, tensor_parallel
 
 
-_FLOAT_TYPES = [get_accelerator().FloatTensor(0).dtype]
-_HALF_TYPES = [get_accelerator().HalfTensor(0).dtype]
-_BF16_TYPES = [get_accelerator().BFloat16Tensor(0).dtype]
-
+_FLOAT_TYPES = None
+_HALF_TYPES = None
+_BF16_TYPES = None
 
 
 def param_is_not_shared(param):
@@ -131,6 +131,9 @@ def conversion_helper(val, conversion):
 
 def fp32_to_float16(val, float16_convertor):
     """Convert fp32 `val` to fp16/bf16"""
+    global _FLOAT_TYPES
+    if _FLOAT_TYPES is None:
+        _FLOAT_TYPES = [get_accelerator().FloatTensor(0).dtype]
     def half_conversion(val):
         val_typecheck = val
         if isinstance(val_typecheck, (Parameter, Variable)):
@@ -143,6 +146,11 @@ def half_conversion(val):
 
 def float16_to_fp32(val):
     """Convert fp16/bf16 `val` to fp32"""
+    global _HALF_TYPES, _BF16_TYPES
+    if _HALF_TYPES is None:
+        _HALF_TYPES = [get_accelerator().HalfTensor(0).dtype]
+    if _BF16_TYPES is None:
+        _BF16_TYPES = [get_accelerator().BFloat16Tensor(0).dtype]
     def float_conversion(val):
         val_typecheck = val
         if isinstance(val_typecheck, (Parameter, Variable)):
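
The module.py change swaps eager dtype probing, which forces the accelerator device to be acquired the moment the module is imported, for globals that start as None and are resolved the first time a conversion is actually requested. Below is a minimal, self-contained sketch of that lazy-initialization pattern; it uses plain torch dtypes in place of the get_accelerator() calls, and the helper names are illustrative rather than taken from the repository.

import torch

# Resolved on first use instead of at import time (mirrors the diff above).
_FLOAT_TYPES = None
_HALF_TYPES = None
_BF16_TYPES = None


def _lazy_init_dtypes():
    """Populate the dtype globals the first time a conversion is requested."""
    global _FLOAT_TYPES, _HALF_TYPES, _BF16_TYPES
    if _FLOAT_TYPES is None:
        # In megatron/model/module.py these come from get_accelerator();
        # plain torch dtypes keep this sketch free of any device dependency.
        _FLOAT_TYPES = [torch.float32]
        _HALF_TYPES = [torch.float16]
        _BF16_TYPES = [torch.bfloat16]


def to_float16(tensor, target_dtype=torch.float16):
    """Convert an fp32 tensor to fp16/bf16, resolving the dtype lists lazily."""
    _lazy_init_dtypes()
    if tensor.dtype in _FLOAT_TYPES:
        return tensor.to(target_dtype)
    return tensor


if __name__ == "__main__":
    x = torch.randn(2, 2)            # fp32 input
    print(to_float16(x).dtype)       # torch.float16; globals were filled on this call

Importing a module written this way no longer touches the device; the first conversion call pays the one-time initialization cost instead.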

tools/preprocess_data.py

+1 -1
@@ -235,7 +235,7 @@ def get_args():
         print("Are you sure you don't want to split sentences?")
 
     # some default/dummy values for the tokenizer
-    args.rank = 1
+    args.rank = 0
     args.make_vocab_size_divisible_by = 128
     args.tensor_model_parallel_size = 1
     args.vocab_extra_ids = 0
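
The preprocess_data.py change makes the script's dummy distributed arguments claim rank 0 instead of rank 1. Megatron's tokenizer helpers gate their progress messages (and some rank-specific work) on rank 0, so a single-process tool should present itself as the primary rank or that output is silently skipped. Here is a small illustrative stub of that behavior; build_tokenizer_stub is hypothetical and only mimics the rank-0 guard, it is not the repository's build_tokenizer.

from types import SimpleNamespace


def build_tokenizer_stub(args):
    # Hypothetical stand-in: like Megatron's tokenizer setup, it only reports on rank 0.
    if args.rank == 0:
        print(f"> building {args.tokenizer_type} tokenizer ...")
    return object()  # placeholder for the real tokenizer object


args = SimpleNamespace(rank=0, tokenizer_type="GPT2BPETokenizer")
build_tokenizer_stub(args)   # prints: the process behaves like the primary rank
args.rank = 1
build_tokenizer_stub(args)   # silent: with rank = 1 the script looks like a non-primary rank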
