Loss disagreement between TP=1 and TP=2 #631

Draft · wants to merge 4 commits into main
29 changes: 18 additions & 11 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
@@ -21,7 +21,6 @@

 import torch
 import torch.distributed
-from megatron.core import tensor_parallel
 from megatron.core.models.bert.bert_lm_head import BertLMHead
 from megatron.core.models.bert.pooler import Pooler
 from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
@@ -178,16 +177,21 @@ def __init__(
                 config,
             )

-            self.output_layer = tensor_parallel.ColumnParallelLinear(
-                config.hidden_size,
-                self.vocab_size,
-                config=config,
-                init_method=config.init_method,
-                bias=True,
-                skip_bias_add=False,
-                gather_output=not self.parallel_output,
-                skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
-            )
+            # TODO replace tensor_parallel.ColumnParallelLinear with torch.nn.Linear to debug; remove once complete
+            # self.output_layer = tensor_parallel.ColumnParallelLinear(
+            #     config.hidden_size,
+            #     self.vocab_size,
+            #     config=config,
+            #     init_method=config.init_method,
+            #     is_expert=False,
+            #     bias=True,
+            #     skip_bias_add=False,
+            #     gather_output=not self.parallel_output,
+            #     skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
+            #     embedding_activation_buffer=self.embedding_activation_buffer,
+            #     grad_output_buffer=self.grad_output_buffer,
+            # )
+            self.output_layer = torch.nn.Linear(config.hidden_size, self.vocab_size, bias=True)

Collaborator:
FYI, this probably has bugs with the bias getting out of sync between TP ranks during training. See https://jirasw.nvidia.com/browse/BIONEMO-668, https://jirasw.nvidia.com/browse/BIONEMO-666, https://nvidia.slack.com/archives/C074Z808N05/p1737508003987919, and https://nvidia.slack.com/archives/C0434FDLPQV/p1733963545314469.

Also, if your concern is that the logit dim is halved when you run TP=2, that is because ColumnParallelLinear splits along the vocab (logit) dimension, and vocab-parallel cross entropy should know how to reduce across that split.
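
For context on that reduction, here is a minimal single-process sketch of what vocab-parallel cross entropy does with ColumnParallelLinear's vocab shards. The TP=2 split, the tensor shapes, and the explicit sums standing in for allreduces are illustrative assumptions, not the Megatron implementation.

```python
# Sketch: recover the full-vocab token loss from per-"rank" logit shards, the way
# tensor_parallel.vocab_parallel_cross_entropy would across a real TP group.
import torch

torch.manual_seed(0)
seq, batch, vocab, tp = 4, 2, 128, 2  # illustrative sizes
full_logits = torch.randn(seq, batch, vocab)
labels = torch.randint(0, 33, (seq, batch))  # ESM2 has 33 real tokens; the rest is padding

# Reference: per-token cross entropy over the full (gathered) vocab dimension.
ref = torch.nn.functional.cross_entropy(
    full_logits.reshape(-1, vocab), labels.reshape(-1), reduction="none"
).reshape(seq, batch)

# With gather_output=False, each TP rank only holds a slice of the vocab dimension.
shards = full_logits.chunk(tp, dim=-1)
shard_size = vocab // tp

# 1) global max over the vocab dim (allreduce MAX across TP ranks in the real thing)
global_max = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values
# 2) global sum of exp(logit - max) (allreduce SUM across TP ranks)
sum_exp = sum((s - global_max.unsqueeze(-1)).exp().sum(dim=-1) for s in shards)
# 3) the target logit, contributed only by the rank whose shard contains the label (allreduce SUM)
target_logit = torch.zeros(seq, batch)
for rank, s in enumerate(shards):
    lo = rank * shard_size
    in_shard = (labels >= lo) & (labels < lo + shard_size)
    local = torch.gather(s, -1, (labels - lo).clamp(0, shard_size - 1).unsqueeze(-1)).squeeze(-1)
    target_logit = target_logit + torch.where(in_shard, local, torch.zeros_like(local))

vocab_parallel_loss = sum_exp.log() + global_max - target_logit
print(torch.allclose(ref, vocab_parallel_loss, atol=1e-5))  # True: the shard-wise reduction matches
```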


Collaborator Author:
Yup. torch.nn.Linear is only a temporary swap for debugging; will revert to ColumnParallelLinear afterward.


Collaborator Author (@sichu2023), Jan 22, 2025:
I am less concerned about the "1/2 logits dim" (128) and more concerned about torch.nn.Linear giving a 256-dim output on TP=2. 128 should be the correct dim (33 tokens + padding).
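
The 128-vs-256 discrepancy is consistent with the padded vocab size depending on the TP size. A short sketch of the usual Megatron-style padding arithmetic; the divisor of 128 and the TP-dependent multiple are assumptions here.

```python
def padded_vocab_size(vocab_size: int, make_vocab_size_divisible_by: int, tp_size: int) -> int:
    """Assumed Megatron-style padding: round the vocab up to a multiple of divisor * tp_size."""
    multiple = make_vocab_size_divisible_by * tp_size
    return ((vocab_size + multiple - 1) // multiple) * multiple


print(padded_vocab_size(33, 128, tp_size=1))  # 128: what torch.nn.Linear and TP=1 agree on
print(padded_vocab_size(33, 128, tp_size=2))  # 256: torch.nn.Linear emits 256; ColumnParallelLinear emits 128 per rank
```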

             self.binary_head = None
             if self.add_binary_head:
@@ -362,3 +366,6 @@ class ESM2Config(ESM2GenericConfig, iom.IOMixinWithGettersSetters):
     model_cls: Type[ESM2Model] = ESM2Model
     num_layers: int = 33  # 650M
     hidden_size: int = 1280  # 650M
+    output_layer_init_method: Callable = (
+        torch.nn.init.zeros_
+    )  # TODO make param init reproducible; remove after debugging
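
The zero init above sidesteps per-rank initialization: with ColumnParallelLinear, each TP rank draws its weight shard from its own RNG stream, so a random init differs between TP=1 and TP=2 even under the same nominal seed. A minimal sketch of that effect; the per-rank seed offsets are illustrative, not Megatron's actual offsets.

```python
import torch

hidden, vocab = 8, 16

# TP=1: one rank initializes the whole [vocab, hidden] weight.
torch.manual_seed(42)
w_tp1 = torch.nn.init.normal_(torch.empty(vocab, hidden))

# TP=2: each rank initializes its own shard from its own RNG stream (illustrative seed offsets).
torch.manual_seed(42 + 0)
shard0 = torch.nn.init.normal_(torch.empty(vocab // 2, hidden))
torch.manual_seed(42 + 1)
shard1 = torch.nn.init.normal_(torch.empty(vocab // 2, hidden))
w_tp2 = torch.cat([shard0, shard1], dim=0)  # the weight the two ranks jointly represent

print(torch.equal(w_tp1, w_tp2))  # False: same nominal seed, different parameters
# A deterministic init such as torch.nn.init.zeros_ is identical in both layouts.
print(torch.equal(torch.zeros_like(w_tp1), torch.cat([torch.zeros_like(shard0), torch.zeros_like(shard1)], dim=0)))  # True
```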

@@ -31,7 +31,6 @@
 from bionemo.esm2.data.datamodule import ESMDataModule
 from bionemo.esm2.data.dataset import RandomMaskStrategy
 from bionemo.esm2.data.tokenizer import get_tokenizer
-from bionemo.llm.lightning import PerplexityLoggingCallback
 from bionemo.llm.model.biobert.lightning import biobert_lightning_module
 from bionemo.llm.model.biobert.model import BiobertSpecOption
 from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHoldScheduler

@@ -211,7 +210,6 @@ def main(
     )

     callbacks = [
-        PerplexityLoggingCallback(log_train=False, log_val=True),
         RichModelSummary(max_depth=4),
         LearningRateMonitor(),
         nl_callbacks.PreemptionCallback(),

@@ -444,12 +444,14 @@ def forward(
         # logits and loss
         output_weight = None
         if self.share_embeddings_and_output_weights:
-            output_weight = self.shared_embedding_or_output_weight()
+            output_weight = self.shared_embedding_or_output_weight()  # noqa: F841

         hidden_states_after_lm_head = self.lm_head(hidden_states=hidden_states)
         if not self.skip_logits:
             # TODO add , runtime_gather_output=runtime_gather_output once supported in ColumnParallelLinear
-            logits, _ = self.output_layer(hidden_states_after_lm_head, weight=output_weight)
+            # TODO replace tensor_parallel.ColumnParallelLinear with torch.nn.Linear to debug; remove once complete
+            # logits, _ = self.output_layer(hidden_states_after_lm_head, weight=output_weight)
+            logits = self.output_layer(hidden_states_after_lm_head)
         else:
             logits = None
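
The dropped `, _` above reflects a call-convention difference: ColumnParallelLinear returns an `(output, bias)` tuple, and with `gather_output=False` the output is sharded along the vocab dim, whereas torch.nn.Linear returns one full-size tensor. A hypothetical shim that tolerates either convention during the debugging swap could look like this.

```python
import torch


def call_output_layer(output_layer: torch.nn.Module, hidden: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: accept either a plain tensor or a ColumnParallelLinear-style return."""
    out = output_layer(hidden)
    if isinstance(out, tuple):  # ColumnParallelLinear-style (output, output_bias)
        out = out[0]
    return out
```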
2 changes: 2 additions & 0 deletions sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py
@@ -197,6 +197,8 @@ def forward(
         forward_out_report = {}

         # NOTE: token_logits is [sequence, batch] but labels and other fields, including the loss, are [batch, sequence]
+        # TODO: logits always match on tp=1 and tp=2, and among tp ranks;
+        # however, unreduced_token_loss does not match between tp=1 and tp=2 (it only matches among tp ranks at tp=2)
         unreduced_token_loss = unreduced_token_loss_fn(forward_out["token_logits"], batch["labels"])  # [b s]

         # TODO(@jstjohn) also handle different output keys, like the sequence loss.
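
A sketch of the dump-and-compare workflow implied by the comment above: save `token_logits` and `unreduced_token_loss` from a TP=1 run and a TP=2 run, then diff them offline. The file paths, naming scheme, and rank-0-only comparison are assumptions.

```python
import torch


def dump_debug_tensor(name: str, tensor: torch.Tensor, tp_size: int, rank: int) -> None:
    """Save an intermediate tensor (e.g. token_logits or unreduced_token_loss) for offline comparison."""
    torch.save(tensor.detach().float().cpu(), f"/tmp/{name}_tp{tp_size}_rank{rank}.pt")


def compare_runs(name: str) -> None:
    """Compare the TP=1 dump against rank 0 of the TP=2 dump."""
    a = torch.load(f"/tmp/{name}_tp1_rank0.pt")
    b = torch.load(f"/tmp/{name}_tp2_rank0.pt")
    print(name, tuple(a.shape), tuple(b.shape), "max |diff| =", (a - b).abs().max().item())


# After both runs have written their dumps:
# compare_runs("token_logits")          # matches between TP=1 and TP=2 per the comment above
# compare_runs("unreduced_token_loss")  # observed to differ between TP=1 and TP=2
```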