+ # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Megatron Module"""
from megatron.core import mpu, tensor_parallel
- _FLOAT_TYPES = [get_accelerator().FloatTensor(0).dtype]
- _HALF_TYPES = [get_accelerator().HalfTensor(0).dtype]
- _BF16_TYPES = [get_accelerator().BFloat16Tensor(0).dtype]
-
+ _FLOAT_TYPES = None
+ _HALF_TYPES = None
+ _BF16_TYPES = None
def param_is_not_shared(param):
@@ -131,6 +131,9 @@ def conversion_helper(val, conversion):
def fp32_to_float16(val, float16_convertor):
    """Convert fp32 `val` to fp16/bf16"""
+    global _FLOAT_TYPES
+    if _FLOAT_TYPES is None:
+        _FLOAT_TYPES = [get_accelerator().FloatTensor(0).dtype]
    def half_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
@@ -143,6 +146,11 @@ def half_conversion(val):
def float16_to_fp32(val):
    """Convert fp16/bf16 `val` to fp32"""
+    global _HALF_TYPES, _BF16_TYPES
+    if _HALF_TYPES is None:
+        _HALF_TYPES = [get_accelerator().HalfTensor(0).dtype]
+    if _BF16_TYPES is None:
+        _BF16_TYPES = [get_accelerator().BFloat16Tensor(0).dtype]
    def float_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
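
Taken together, the two hunks replace import-time accelerator queries with lazy initialization: building each dtype list requires constructing a zero-element tensor on the accelerator, so doing it at module scope forced the device to initialize as a side effect of importing the module. A minimal standalone sketch of the same pattern, assuming DeepSpeed's `get_accelerator` as used in the patched file (the `is_half` helper is hypothetical, for illustration only):

from deepspeed.accelerator import get_accelerator  # assumed import, as in Megatron-DeepSpeed

_HALF_TYPES = None  # resolved on first use; importing this module does not touch the device

def is_half(tensor):
    """Hypothetical helper: True if `tensor` uses the accelerator's fp16 dtype."""
    global _HALF_TYPES
    if _HALF_TYPES is None:
        # The zero-element HalfTensor is allocated only here, on first call,
        # so device initialization is deferred until it is actually needed.
        _HALF_TYPES = [get_accelerator().HalfTensor(0).dtype]
    return tensor.dtype in _HALF_TYPES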