[Feature] Support Distributed LogProb for GRPO Training #6247

Merged (22 commits) on Mar 18, 2025
2 changes: 0 additions & 2 deletions applications/ColossalChat/coati/distributed/consumer.py
@@ -73,8 +73,6 @@ def setup(self) -> None:
)
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
plugin_config["microbatch_size"] = self.microbatch_size
if self.plugin_config.get("tp_size", 1) > 1:
plugin_config["parallel_output"] = False
plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config)
self.booster = Booster(plugin=self.plugin)
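The two removed lines above had forced parallel_output=False whenever tensor parallelism was enabled; dropping them lets the LM head keep its vocab-sharded logits, which is what the new distributed log-prob path consumes. A minimal sketch of the resulting plugin setup (the kwargs shown are assumptions for illustration, not taken from this PR):

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# parallel_output is left at its default, so with tp_size > 1 each rank keeps its
# [B, S, vocab_size / tp_size] shard of the logits instead of gathering them.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
booster = Booster(plugin=plugin)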
@@ -120,14 +120,18 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
action_log_probs = calc_action_log_probs(
policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
)

with torch.no_grad():
reference_model_logits = self.reference_model(
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
reference_action_log_probs = calc_action_log_probs(
reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
)

per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
20 changes: 17 additions & 3 deletions applications/ColossalChat/coati/distributed/utils.py
@@ -2,6 +2,8 @@

import torch

from colossalai.shardformer.layer.loss import dist_log_prob


def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
batches = []
@@ -66,18 +68,30 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return per_label_logps.squeeze(-1)


def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
def calc_action_log_probs(
logits: torch.Tensor,
sequences: torch.LongTensor,
num_actions: int,
shard_config,
vocab_size: int = None,
) -> torch.Tensor:
"""Calculate action log probs.

Args:
output (torch.Tensor): Output tensor of Actor.forward.logits.
logits (torch.Tensor): Output tensor of Actor.forward.logits.
sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions.
shard_config (ShardConfig): Shardformer configuration; selects the distributed (TP) or local log-prob path.
vocab_size (int, optional): Global vocabulary size. Inferred from the logits shard size when None.

Returns:
torch.Tensor: Action log probs.
"""
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
# dist_log_prob takes labels of shape [B, S] and logits of shape [B, S, vocab_size]
log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
log_probs = log_probs.squeeze(-1)
return log_probs[:, -num_actions:]
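For reference, the quantity computed here is unchanged: the log probability of each next token, restricted to the final num_actions positions. A minimal single-device sketch of the equivalent computation (shapes and names are illustrative; the distributed path above delegates this to dist_log_prob):

import torch

def action_log_probs_reference(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
    # Logits at position t predict the token at position t + 1, hence the shift.
    log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)          # [B, S-1, V]
    per_token = log_probs.gather(-1, sequences[:, 1:].unsqueeze(-1))  # [B, S-1, 1]
    return per_token.squeeze(-1)[:, -num_actions:]                    # [B, num_actions]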


4 changes: 3 additions & 1 deletion colossalai/shardformer/layer/__init__.py
@@ -3,7 +3,7 @@
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d, dist_cross_entropy
from .loss import cross_entropy_1d, dist_cross_entropy, dist_log_prob, dist_log_prob_1d
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import (
@@ -28,6 +28,8 @@
"DropoutForReplicatedInput",
"cross_entropy_1d",
"dist_cross_entropy",
"dist_log_prob_1d",
"dist_log_prob",
"BaseLayerNorm",
"LayerNorm",
"RMSNorm",
150 changes: 149 additions & 1 deletion colossalai/shardformer/layer/loss.py
@@ -3,13 +3,21 @@
from torch.autograd import Function
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
from torch.nn.functional import log_softmax

from colossalai.shardformer.layer._operation import reduce_forward
from colossalai.shardformer.shard import ShardConfig

from .utils import is_share_sp_tp

__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
__all__ = [
"DistCrossEntropy",
"cross_entropy_1d",
"dist_cross_entropy",
"DistLogProb",
"dist_log_prob_1d",
"dist_log_prob",
]

_IGNORE_IDX = -100

@@ -137,6 +145,98 @@ def backward(ctx, grad_output):
return grad_logits, None, None, None, None, None, None


class DistLogProb(Function):
r"""
Override the forward and backward passes to compute log probs before gathering the full vocabulary dimension

Args:
Function (:class:`torch.autograd.Function`): default
"""

@staticmethod
def forward(
ctx,
vocab_logits: torch.Tensor,
target: torch.Tensor,
process_group: ProcessGroup,
vocab_size: int,
dtype=torch.float32,
):

##################
# Step 1: Find the global maximum value of the logits
##################
logits_max = torch.max(vocab_logits, dim=-1)[0]
handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True)

##################
# Step 2: Build the local mask; it will be used to select log_probs values in Step 4.
# For acceleration, we overlap Step 2 with Step 3
##################
rank = dist.get_rank(group=process_group)
world_size = dist.get_world_size(group=process_group)
if vocab_size is None:
partition_vocab_size = vocab_logits.size()[-1]
global_vocab_size = partition_vocab_size * world_size
else:
global_vocab_size = vocab_size
partition_vocab_size = global_vocab_size // world_size
# down and up threshold for local logits
delta = (global_vocab_size + world_size - 1) // world_size
down_threshold = rank * delta
up_threshold = down_threshold + delta
if up_threshold > global_vocab_size:
up_threshold = global_vocab_size
# mask
mask = (target < down_threshold) | (target >= up_threshold)
masked_target = target.clone() - down_threshold
masked_target[mask] = 0
masked_target_1d = masked_target.view(-1).contiguous()
handle.wait()

##################
# Step 3: Compute the global sum of exp(logits)
##################
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
exp_logits = torch.exp(vocab_logits)
sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) # local summation exp logits
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)

##################
# Step 4: Compute local log probs. First compute log_softmax, then select log probs via the local mask
##################
log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1)) # cal log_softmax
log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1))
log_probs[mask.unsqueeze(-1)] = 0 # set masked val to zero
dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group)

ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits)
ctx.dtype = dtype
return log_probs

@staticmethod
def backward(ctx, grad_output):
exp_logits, mask, masked_target_1d, sum_exp_logits = ctx.saved_tensors
##################
# Step 1: Compute the global softmax values
##################
softmax_logits = exp_logits / sum_exp_logits.unsqueeze(dim=-1)

##################
# Step 2: Update the softmax values at the local target indices
##################
partition_vocab_size = softmax_logits.shape[-1]
softmax_logits_2d = softmax_logits.view(-1, partition_vocab_size)
update = 1.0 - mask.view(-1).float().to(ctx.dtype)
softmax_logits_2d[torch.arange(0, softmax_logits_2d.shape[0]), masked_target_1d] -= update

##################
# Step 3: Compute grad_logits by scaling the softmax with grad_output (the gradient of the loss w.r.t. the log-softmax output)
##################
grad_logits = -softmax_logits.mul_(grad_output)
return grad_logits, None, None, None, None, None, None


def cross_entropy_1d(
vocab_logits: torch.Tensor,
labels: torch.Tensor,
@@ -149,6 +249,16 @@ def cross_entropy_1d(
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode)


def dist_log_prob_1d(
vocab_logits: torch.Tensor,
labels: torch.Tensor,
process_group: ProcessGroup = None,
vocab_size: int = None,
dtype: torch.dtype = None,
) -> torch.Tensor:
return DistLogProb.apply(vocab_logits, labels, process_group, vocab_size, dtype)


def dist_cross_entropy(
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
logits: torch.Tensor, # [B, S, Vocab_size]
@@ -243,3 +353,41 @@ def dist_cross_entropy(
loss, num_nonzero = loss[0], loss[1].detach()
loss = (loss / num_nonzero).squeeze()
return loss


def dist_log_prob(
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
logits: torch.Tensor, # [B, S, Vocab_size]
shard_config: ShardConfig,
vocab_size: int,
dtype: torch.dtype,
seq_dim: int = 1,
) -> torch.Tensor:
"""
Helper to compute log prob for most shardformer models supporting PP, TP.
"""
# Split labels if not gather output
parallel_output = shard_config.parallel_output
is_tp = shard_config.enable_tensor_parallelism

# TODO: support sequence parallelism (SP)
labels = labels[..., 1:]
logits = logits[..., :-1, :]
labels = labels.contiguous()
logits = logits.contiguous()
assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"

# Flatten the tokens
if is_tp and parallel_output:
log_prob = dist_log_prob_1d(
logits,
labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=vocab_size,
dtype=dtype,
)
else:
log_prob = log_softmax(logits, dim=-1)
log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))

return log_prob
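The key observation behind DistLogProb is that a log-softmax over a vocabulary sharded across tensor-parallel ranks needs only three all-reduces: the global max of the logits, the global sum of exp(logits - max), and a final sum that merges each rank's masked target log probs. A single-process sketch of that math, simulating two ranks by chunking the vocab dimension (all names, shapes, and the chunking itself are illustrative, not part of this PR):

import torch

torch.manual_seed(0)
logits = torch.randn(2, 4, 8)          # [B, S, V] full logits
labels = torch.randint(0, 8, (2, 4))   # [B, S] target token ids
shards = logits.chunk(2, dim=-1)       # two simulated TP shards of size V / 2

# Step 1: global max over the vocab (all-reduce MAX in DistLogProb)
global_max = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values
# Step 3: global sum of exp(logits - max) (all-reduce SUM)
sum_exp = sum(torch.exp(s - global_max.unsqueeze(-1)).sum(dim=-1) for s in shards)

# Steps 2 and 4: each rank contributes the target's log prob only if it owns that token id
log_prob = torch.zeros(labels.shape, dtype=logits.dtype)
for rank, shard in enumerate(shards):
    lo = rank * shard.size(-1)
    owned = (labels >= lo) & (labels < lo + shard.size(-1))
    local = shard - global_max.unsqueeze(-1) - torch.log(sum_exp).unsqueeze(-1)
    idx = (labels - lo).clamp(0, shard.size(-1) - 1)
    picked = local.gather(-1, idx.unsqueeze(-1)).squeeze(-1)
    log_prob += torch.where(owned, picked, torch.zeros_like(picked))

# Matches the unsharded log-softmax followed by a gather of the target ids
expected = torch.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(log_prob, expected, atol=1e-5)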
1 change: 0 additions & 1 deletion colossalai/shardformer/modeling/qwen2.py
@@ -832,7 +832,6 @@ def forward(
loss = None
if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)

if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
8 changes: 6 additions & 2 deletions colossalai/shardformer/policies/qwen2.py
@@ -430,8 +430,12 @@ def module_policy(self):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
target_module=VocabParallelLMHead1D,
kwargs=dict(
gather_output=not self.shard_config.parallel_output,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
)
],
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
52 changes: 52 additions & 0 deletions tests/test_shardformer/test_layer/test_dist_log_prob.py
@@ -0,0 +1,52 @@
import pytest
import torch
from coati.distributed.utils import log_probs_from_logits

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer import dist_log_prob_1d
from colossalai.testing import rerun_if_address_is_in_use, spawn

CONFIG = dict(
parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")),
)


def check_dist_log_prob(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")

# prepare data
pred = torch.randn(2, 4, 8, requires_grad=True).cuda()
labels = torch.randint(8, (2, 4)).cuda()

logprob = log_probs_from_logits(pred, labels)

pred.retain_grad()
logprob.mean().backward()

dist_pred = pred.clone().chunk(world_size, -1)[rank].detach()
dist_pred.requires_grad = True
dist_logprob = dist_log_prob_1d(dist_pred, labels)

dist_pred.retain_grad()
dist_logprob.squeeze(-1).mean().backward()

assert torch.allclose(
logprob, dist_logprob.squeeze(-1), atol=1e-5
), f"dist cross entropy logprob is not equal to orgin logprob\n{logprob}\n{dist_logprob.squeeze(-1)}"

pred_grad_partial = pred.grad.clone().chunk(world_size, -1)[rank].detach()
assert torch.allclose(
pred_grad_partial, dist_pred.grad
), f"dist grad is not equal to orgin grad\n{pred.grad}\n{dist_pred.grad}"


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_log_prob():
spawn(check_dist_log_prob, 2)


if __name__ == "__main__":
test_dist_log_prob()