[Draft] Add support for seq split in Domino #961

Draft · wants to merge 1 commit into master
5 changes: 5 additions & 0 deletions training/DeepSpeed-Domino/domino/arguments.py
@@ -206,6 +206,8 @@ def parse_args():
                         help='Report loss and timing interval.')
     parser.add_argument('--save-interval', type=int, default=None,
                         help='Number of iterations between checkpoint saves.')
+    parser.add_argument('--input-split-dim', type=str, default='batch',
+                        help='Dimension for input split.')

Review comment on the new --input-split-dim help string (suggested wording; a small argparse sketch follows this file's diff below):

Dimension for input split. ['batch', 'seq']


     args = parser.parse_args()

@@ -355,6 +357,8 @@ class TransformerConfig():
     no_sync_func: Callable = None
     # grad_sync_func: Callable = None
     # param_sync_func: Callable = None
+
+    input_split_dim: str = 'batch'

     def __post_init__(self):
         """ Python dataclass method that is used to modify attributes after initialization.
@@ -396,5 +400,6 @@ def core_transformer_config_from_args(args):
     kw_args['init_method'] = args.init_method
     kw_args['output_layer_init_method'] = args.init_method
     kw_args['params_dtype'] = args.params_dtype
+    kw_args['input_split_dim'] = args.input_split_dim

     return TransformerConfig(**kw_args)
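As a side note on the review suggestion above: a minimal sketch (not part of this PR) of how the accepted values could be both documented and enforced via argparse choices; only the flag name comes from the diff, the rest is illustrative.

    parser.add_argument('--input-split-dim', type=str, default='batch',
                        choices=['batch', 'seq'],
                        help="Dimension for input split. ['batch', 'seq']")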
22 changes: 18 additions & 4 deletions training/DeepSpeed-Domino/domino/language_model.py
@@ -127,6 +127,7 @@ def __init__(self,
         self.init_method = config.init_method
         self.encoder_attn_mask_type = encoder_attn_mask_type
         self.encoder_hidden_state = None
+        self.input_split_dim = config.input_split_dim

Review comment:
One thing we need to verify: if the split dim is 'seq', can we still initialize the rotary position embedding (RoPE, a very popular position-embedding scheme) the same way? Because for RoPE:

        if self.use_rotary_position_embeddings:
            self.seq_length = args.seq_length

our self.seq_length may no longer be correct (each partition would see only half of the original seq_length).
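One possible mitigation, sketched below (not part of this PR; the helper name is hypothetical): build the rotary table for the full sequence length and hand each partition its own slice, so the second half keeps positions [seq_length/2, seq_length) instead of restarting at 0.

    import torch

    def split_rotary_pos_emb(rotary_pos_emb: torch.Tensor, num_partitions: int = 2):
        # Assumes rotary_pos_emb is laid out as [seq_length, ...] for the full,
        # unsplit sequence; each partition then consumes its matching slice.
        return torch.tensor_split(rotary_pos_emb, num_partitions, dim=0)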


         if self.pre_process:
             self.embedding = Embedding(self.hidden_size,
@@ -177,17 +178,30 @@ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,

         encoder_out_size = encoder_input.shape
         p_batch_size = encoder_out_size[1] // 2
+        p_seq_size = encoder_out_size[0] // 2
         dtype = encoder_input.dtype
         encoder_output_t = torch.empty(encoder_out_size, dtype=dtype, device=torch.cuda.current_device())
         intra_partitions = 2
-        encoder_inputs = torch.tensor_split(encoder_input, intra_partitions, dim=1)
+        if self.input_split_dim == 'batch':
+            encoder_inputs = torch.tensor_split(encoder_input, intra_partitions, dim=1)
+        elif self.input_split_dim == 'seq':
+            encoder_inputs = torch.tensor_split(encoder_input, intra_partitions, dim=0)
+        else:
+            raise NotImplementedError
         encoder_outputs = self.encoder(
             encoder_inputs,
             enc_attn_mask,
             rotary_pos_emb=rotary_pos_emb)
-        encoder_output_t[:, 0:p_batch_size, :] = encoder_outputs[0]
-        encoder_output_t[:, p_batch_size:2*p_batch_size, :] = encoder_outputs[1]
+
+        if self.input_split_dim == 'batch':
+            encoder_output_t[:, 0:p_batch_size, :] = encoder_outputs[0]
+            encoder_output_t[:, p_batch_size:2*p_batch_size, :] = encoder_outputs[1]
+        elif self.input_split_dim == 'seq':
+            encoder_output_t[0:p_seq_size, :, :] = encoder_outputs[0]
+            encoder_output_t[p_seq_size:2*p_seq_size, :, :] = encoder_outputs[1]
+        else:
+            raise NotImplementedError
+
         encoder_output = encoder_output_t

         return encoder_output
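To make the split/recombine pattern in this hunk concrete, a small self-contained sketch (illustrative only): the code above treats dim 1 as the batch dimension and dim 0 as the sequence dimension of a [seq, batch, hidden] activation, and splitting then re-assembling along the same dim reproduces the original tensor.

    import torch

    def split_and_merge(x: torch.Tensor, split_dim: str = 'batch', partitions: int = 2):
        # x has shape [seq, batch, hidden]; dim=1 splits the micro-batch,
        # dim=0 splits the sequence, mirroring the branch added in this PR.
        dim = 1 if split_dim == 'batch' else 0
        chunks = torch.tensor_split(x, partitions, dim=dim)
        # Writing the chunks back along the same dim restores the original tensor.
        return torch.cat(chunks, dim=dim)

    x = torch.randn(8, 4, 16)  # [seq, batch, hidden]
    assert torch.equal(split_and_merge(x, 'batch'), x)
    assert torch.equal(split_and_merge(x, 'seq'), x)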