Update Domino for Llama3 #959

Open · wants to merge 8 commits into base: master
21 changes: 10 additions & 11 deletions training/DeepSpeed-Domino/README.md
@@ -5,7 +5,6 @@
pip install -r requirements.txt
```

## Prepare the Dataset
Follow the instructions from [Megatron-DeepSpeed](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/universal_checkpointing#download-and-pre-process-training-dataset) to prepare the training dataset.

## Execute Domino Training
@@ -38,16 +37,16 @@ The output should look like this:

```
training ...
iteration: 1 | loss: 11.318 | iteration time (ms): 2174.0469932556152
iteration: 2 | loss: 11.307 | iteration time (ms): 1414.4024848937988
iteration: 3 | loss: 11.323 | iteration time (ms): 1385.9455585479736
iteration: 4 | loss: 11.310 | iteration time (ms): 1475.5175113677979
iteration: 5 | loss: 11.306 | iteration time (ms): 1395.7207202911377
iteration: 6 | loss: 11.315 | iteration time (ms): 1392.2104835510254
iteration: 7 | loss: 11.314 | iteration time (ms): 1402.6703834533691
iteration: 8 | loss: 11.309 | iteration time (ms): 1450.613260269165
iteration: 9 | loss: 11.305 | iteration time (ms): 1473.1688499450684
iteration: 10 | loss: 11.320 | iteration time (ms): 1398.4534740447998
[2024-11-04 15:32:30,918] [INFO] [launch.py:351:main] Process 73015 exits successfully.
[2024-11-04 15:32:30,918] [INFO] [launch.py:351:main] Process 73017 exits successfully.
[2024-11-04 15:32:30,919] [INFO] [launch.py:351:main] Process 73014 exits successfully.
```
24 changes: 24 additions & 0 deletions training/DeepSpeed-Domino/domino/arguments.py
@@ -92,10 +92,25 @@ def parse_args():
parser.add_argument('--position-embedding-type', type=str, default='learned_absolute',
choices=['learned_absolute', 'rope'],
help='Position embedding type.')
parser.add_argument('--use-rotary-position-embeddings', action='store_true',
help='Use rotary positional embeddings or not. '
'Deprecated: use --position-embedding-type')
parser.add_argument('--rotary-base', type=int, default=10000,
help='Base to use for rotary positional embeddings, default 10000')
parser.add_argument('--rotary-percent', type=float, default=1.0,
help='Percent of rotary dimension to use, default 100%')
parser.add_argument('--rotary-interleaved', action='store_true',
help='Use interleaved rotary embedding.')
parser.add_argument('--rotary-seq-len-interpolation-factor', type=int, default=None,
help='Sequence length interpolation factor for rotary embeddings.')
parser.add_argument('--use-rope-scaling', action='store_true',
help='Apply rope scaling as used in llama3.1')
parser.add_argument('--disable-bias-linear', action='store_false',
help='Disable bias in the linear layers',
dest='add_bias_linear')
parser.add_argument('--group-query-attention', action='store_true',
help='Use group-query attention.')
parser.add_argument('--num-query-groups', type=int, default=1,
                    help='Number of query groups (KV heads) for group-query attention.')
parser.add_argument('--hidden-dropout', type=float, default=0.1,
help='Dropout probability for hidden state transformer.')
parser.add_argument('--attention-dropout', type=float, default=0.1,
@@ -180,8 +195,11 @@ def parse_args():
'GPT2BPETokenizer',
'SentencePieceTokenizer',
'GPTSentencePieceTokenizer',
'HuggingFaceTokenizer',
'NullTokenizer'],
help='What type of tokenizer to use.')
parser.add_argument('--tokenizer-model', type=str, default=None,
help='Sentencepiece tokenizer model.')
parser.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
help='Pad the vocab size to be divisible by this value.'
'This is added for computational efficiency reasons.')
@@ -343,6 +361,12 @@ class TransformerConfig():
gated_linear_unit: bool = False
activation_func: Callable = F.gelu
bias_gelu_fusion = False
kv_channels: int = None
rotary_interleaved: bool = False
normalization: str = 'LayerNorm'
group_query_attention: bool = False
num_query_groups: int = 1
seq_length: int = 2048

# initialization
init_method: Callable = None
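As a quick orientation for reviewers, the new GQA and RoPE options above combine roughly as follows for a Llama-3-8B-style model. This is an illustrative sketch: the sizes are the published Llama-3-8B values, not settings taken from this PR, and the kv_channels fallback mirrors the per-head default used elsewhere in this change.

```python
# Back-of-the-envelope sizes for a Llama-3-8B-style configuration (assumed values,
# for illustration only; they are not defined anywhere in this PR).
hidden_size = 4096
num_attention_heads = 32

# --group-query-attention with --num-query-groups 8: 8 KV heads serve all 32 query heads.
num_query_groups = 8

# kv_channels falls back to the per-head dimension when it is not set explicitly.
kv_channels = hidden_size // num_attention_heads            # 128

query_projection_size = kv_channels * num_attention_heads   # 4096
kv_projection_size = kv_channels * num_query_groups         # 1024: K/V projections shrink 4x under GQA

# --position-embedding-type rope with --rotary-base 500000 (the Llama 3 value) replaces the
# learned absolute embeddings; --rotary-percent 1.0 rotates the full per-head dimension.
rotary_base = 500_000
rotary_dim = int(kv_channels * 1.0)                          # 128

print(query_projection_size, kv_projection_size, rotary_dim)
```

Such a run would typically also pass --position-embedding-type rope, --disable-bias-linear, and --tokenizer-type HuggingFaceTokenizer with --tokenizer-model pointing at the Llama 3 tokenizer; the exact launch script is not part of this diff.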
65 changes: 30 additions & 35 deletions training/DeepSpeed-Domino/domino/language_model.py
@@ -1,6 +1,8 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# This file is adapted from language_model.py in Megatron-LM

from typing import Literal, Optional

import torch
from torch import einsum, nn
from domino.arguments import get_args
@@ -14,6 +16,9 @@
from domino.tensor_parallel.partition import _initialize_affine_weight_gpu, set_tensor_model_parallel_attributes
from domino.tensor_parallel.partition import ColumnParallelLinear, RowParallelLinearNoComm

from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.model.utils import get_norm

from deepspeed.runtime.domino.transformer import DominoTransformer

def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
@@ -45,12 +50,18 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
def get_language_model(config, num_tokentypes,
encoder_attn_mask_type,
pre_process=True, post_process=True):
args = get_args()
language_model = TransformerLanguageModel(
config,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes,
pre_process=pre_process,
post_process=post_process
post_process=post_process,
position_embedding_type=args.position_embedding_type,
rotary_percent=args.rotary_percent,
rotary_base=args.rotary_base,
rope_scaling=args.use_rope_scaling,
seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor,
)

return language_model
@@ -85,38 +96,18 @@ def forward(self, input_ids, position_ids):
return combined_embeds


class RotaryEmbedding(nn.Module):
def __init__(self, dim, seq_len_interpolation_factor=None):
super().__init__()
self.seq_len_interpolation_factor = seq_len_interpolation_factor
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq, persistent=False)

def forward(self, max_seq_len, offset=0):
seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
if self.seq_len_interpolation_factor is not None:
seq = seq.type_as(self.inv_freq)
seq *= 1 / self.seq_len_interpolation_factor
freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
# first part even vector components, second part odd vector components,
# 2 * dim in dimension size
emb = torch.cat((freqs, freqs), dim=-1)
# emb [seq_length, .., dim]
return emb[:, None, None, :]

# def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
# state_dict.pop(f'{prefix}inv_freq', None)
# return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class TransformerLanguageModel(DominoModule):
def __init__(self,
config,
encoder_attn_mask_type,
num_tokentypes=0,
pre_process=True,
post_process=True):

post_process=True,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
rope_scaling: bool = False,
seq_len_interpolation_factor: Optional[float] = None,):
args = get_args()
super(TransformerLanguageModel, self).__init__(share_embeddings_and_output_weights=True)

@@ -127,6 +118,11 @@ def __init__(self,
self.init_method = config.init_method
self.encoder_attn_mask_type = encoder_attn_mask_type
self.encoder_hidden_state = None
self.position_embedding_type = position_embedding_type
self.rotary_percent = rotary_percent
self.rotary_base = rotary_base
self.rotary_scaling = rope_scaling
self.seq_length = config.seq_length

if self.pre_process:
self.embedding = Embedding(self.hidden_size,
@@ -138,19 +134,18 @@
self.use_rotary_position_embeddings = \
args.position_embedding_type == 'rope'
if self.use_rotary_position_embeddings:
self.seq_length = args.seq_length
rotary_dim = args.hidden_size // args.num_attention_heads \
if args.kv_channels is None else args.kv_channels
if args.rotary_percent < 1.0:
rotary_dim = int(rotary_dim * args.rotary_percent)
self.rotary_pos_emb = RotaryEmbedding(
rotary_dim,
seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor
kv_channels=config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
rope_scaling=rope_scaling,
)

self.encoder = DominoTransformer(
config, ModelType.encoder_or_decoder, mpu,
fused_layer_norm, _initialize_affine_weight_gpu,
get_norm, _initialize_affine_weight_gpu,
ColumnParallelLinear, RowParallelLinearNoComm, apply_rotary_pos_emb,
bias_dropout_add_fused_train, bias_dropout_add_fused_inference,
self_attn_mask_type=self.encoder_attn_mask_type,
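The local RotaryEmbedding class removed above is replaced by the RotaryEmbedding imported from megatron.core at the top of this file. For readers comparing the two, the sketch below restates the mechanism in plain PyTorch (non-interleaved layout, no rope scaling); it is an illustration of the math with assumed tensor sizes, not the megatron.core API.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Pair the first and second halves of the last dim: (x1, x2) -> (-x2, x1).
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rope_freqs(max_seq_len: int, dim: int, base: float = 10000.0) -> torch.Tensor:
    # Same frequency table the removed class built: inv_freq = 1 / base**(2i/dim),
    # outer product with positions, then the two halves concatenated.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    seq = torch.arange(max_seq_len).float()
    freqs = torch.einsum('i,j->ij', seq, inv_freq)   # [seq, dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)          # [seq, dim]
    return emb[:, None, None, :]                     # [seq, 1, 1, dim], broadcasts over batch/heads


def apply_rope(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # t is [seq, batch, heads, head_dim]; rotate each position by its angle table.
    return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())


# Assumed sizes for illustration: seq 2048, 32 heads, head_dim 128, Llama-3-style base.
q = torch.randn(2048, 1, 32, 128)
q_rot = apply_rope(q, rope_freqs(2048, 128, base=500_000.0))
```

In the actual model, the rotary table comes from the megatron.core RotaryEmbedding constructed above and is consumed via the apply_rotary_pos_emb hook passed to DominoTransformer.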
@@ -0,0 +1,116 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import json
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any

import numpy


class MegatronTokenizer(ABC):
    """Abstract class for tokenizer

    Absent a config or class-specific tracking of which objects are uniquely identifying, we must
    include all key word arguments as unique identifiers

    Args:
        tokenizer_paths (Tuple[str]): All tokenizer source paths or prefixes

        tokenizer_options (Dict[str, Any]): All tokenizer options
    """

    def __init__(self, *tokenizer_paths: str, **tokenizer_options: Any):
        self.unique_identifiers = OrderedDict()
        self.unique_identifiers["class"] = type(self).__name__
        self.unique_identifiers["tokenizer_path"] = list(tokenizer_paths)
        for option in tokenizer_options:
            self.unique_identifiers[option] = str(tokenizer_options[option])

        self.unique_description = json.dumps(self.unique_identifiers, indent=4)

        super().__init__()

    @abstractmethod
    def tokenize(self, text: str) -> numpy.ndarray:
        """Convert text to embedding ids

        Args:
            text (str): The text to convert

        Returns:
            numpy.ndarray: The converted embedding ids
        """
        pass

    def detokenize(self, ids: numpy.ndarray) -> str:
        """Convert embedding ids to text

        Args:
            ids (numpy.ndarray): The ids to convert

        Returns:
            str: The converted text

        Raises:
            NotImplementedError: Non-abstract, optional method
        """
        raise NotImplementedError("{} has no method 'detokenize'".format(type(self).__name__))

    def offsets(self, ids: list[int], text: str) -> list[int]:
        """Convert embedding ids to text offsets

        Args:
            ids (list[int]): The ids to convert

            text (str): The text to convert

        Returns:
            list[int]: The converted offsets

        Raises:
            NotImplementedError: Non-abstract, optional method
        """
        raise NotImplementedError("{} has no method 'offsets'".format(type(self).__name__))

    @property
    @abstractmethod
    def vocab(self):
        """Dictionary from vocab text token to id token"""
        pass

    @property
    @abstractmethod
    def inv_vocab(self):
        """Dictionary from vocab id token to text token"""
        pass

    @property
    @abstractmethod
    def vocab_size(self):
        """The vocabulary size"""
        pass

    @property
    def cls(self):
        """The CLS token id

        Raises:
            NotImplementedError: Non-abstract, optional attribute
        """
        raise NotImplementedError("{} has no attribute 'cls'".format(type(self).__name__))

    @property
    def sep(self):
        """The SEP token id

        Raises:
            NotImplementedError: Non-abstract, optional attribute
        """
        raise NotImplementedError("{} has no attribute 'sep'".format(type(self).__name__))

    @property
    def pad(self):
        """The PAD token id

        Raises:
            NotImplementedError: Non-abstract, optional attribute
        """
        raise NotImplementedError("{} has no attribute 'pad'".format(type(self).__name__))

    @property
    def eod(self):
        """The EOD token id

        Raises:
            NotImplementedError: Non-abstract, optional attribute
        """
        raise NotImplementedError("{} has no attribute 'eod'".format(type(self).__name__))

    @property
    def bos(self):
        """The BOS token id

        Raises:
            NotImplementedError: Non-abstract, optional attribute
        """
        raise NotImplementedError("{} has no attribute 'bos'".format(type(self).__name__))

    @property
    def eos(self):
        """The EOS token id

        Raises:
            NotImplementedError: Non-abstract, optional attribute
        """
        raise NotImplementedError("{} has no attribute 'eos'".format(type(self).__name__))

    @property
    def mask(self):
        """The MASK token id

        Raises:
            NotImplementedError: Non-abstract, optional attribute
        """
        raise NotImplementedError("{} has no attribute 'mask'".format(type(self).__name__))
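To make the contract of this abstract base concrete, a deliberately tiny subclass is sketched below. It is a toy illustration (a whitespace tokenizer over an in-memory vocabulary), not part of this PR; it shows which members are mandatory (tokenize, vocab, inv_vocab, vocab_size) and which optional methods, such as detokenize or eod, a subclass may override.

```python
import numpy

# MegatronTokenizer is the abstract base defined above.

class ToyWhitespaceTokenizer(MegatronTokenizer):
    """Toy example: splits on whitespace over a fixed, in-memory vocabulary."""

    def __init__(self, tokens: list[str]):
        # The base class records constructor options as unique identifiers.
        super().__init__("in-memory", tokens=",".join(tokens))
        self._vocab = {tok: i for i, tok in enumerate(tokens)}
        self._inv_vocab = {i: tok for tok, i in self._vocab.items()}

    def tokenize(self, text: str) -> numpy.ndarray:
        return numpy.array([self._vocab[tok] for tok in text.split()], dtype=numpy.int64)

    def detokenize(self, ids: numpy.ndarray) -> str:
        # Optional override; the base class would raise NotImplementedError.
        return " ".join(self._inv_vocab[int(i)] for i in ids)

    @property
    def vocab(self):
        return self._vocab

    @property
    def inv_vocab(self):
        return self._inv_vocab

    @property
    def vocab_size(self):
        return len(self._vocab)

    @property
    def eod(self):
        return self._vocab["<eod>"]  # assumes the toy vocabulary includes an <eod> token


tok = ToyWhitespaceTokenizer(["<eod>", "hello", "world"])
assert tok.detokenize(tok.tokenize("hello world")) == "hello world"
```

The HuggingFaceTokenizer choice added to --tokenizer-type in domino/arguments.py would be expected to implement this same interface.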