[feat] grpo latest support npu #6242

Open

wants to merge 33 commits into base: grpo-latest-npu

Changes from all commits (33 commits)
ef1c1cb  fix inference rebatching bug (YeAnbang, Feb 20, 2025)
01f84de  fix num_train_step update (YeAnbang, Feb 20, 2025)
bc66524  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 20, 2025)
ccc512a  [misc] update torch version (#6206) (ver217, Feb 24, 2025)
bcb5b60  [hotfix] fix lora load (#6231) (ver217, Mar 1, 2025)
a7e3bec  [release] update version (#6236) (ver217, Mar 3, 2025)
265d430  [feat] add ops test to adapt npu (duanjunwen, Mar 11, 2025)
e61bb0a  [feat] test loss func & assert close (duanjunwen, Mar 11, 2025)
704866a  detach (Mar 11, 2025)
bc6e14a  [feat] support compare tools on npu (duanjunwen, Mar 11, 2025)
6930c7c  [fix] fix qwen policy, now use gather output as logits (duanjunwen, Mar 12, 2025)
4b24a03  [fix] fix qwen lmhead, now gather output for logits (duanjunwen, Mar 12, 2025)
05ca507  [feat] fix qwen Linear_Col --> VocalHead (duanjunwen, Mar 13, 2025)
2305f93  [fix] fix (duanjunwen, Mar 13, 2025)
03ce3c5  [fix] fix qwen VocabParallelLMHead1D and gather output (duanjunwen, Mar 13, 2025)
131eece  fix tp bug (Mar 13, 2025)
afddfde  fix consumer (Mar 13, 2025)
b835d1b  fix tp bug (Mar 13, 2025)
137ec17  fix consumer (Mar 13, 2025)
a9cf3aa  Merge branch 'grpo-latest' into grpo-latest-npu (duanjunwen, Mar 13, 2025)
4702d57  convert to 8 generation (Mar 13, 2025)
45ac6c6  print results (Mar 13, 2025)
57b49da  setup update (Mar 13, 2025)
bc0171d  fix transformers backend (YeAnbang, Mar 14, 2025)
7b3c310  Merge branch 'hpcaitech:grpo-latest' into grpo-latest (duanjunwen, Mar 17, 2025)
dcf3f9b  [fix] fix qwen VocabParallelLMHead1D and gather output (duanjunwen, Mar 13, 2025)
d90bf57  Merge branch 'grpo-latest' of github.com:duanjunwen/ColossalAI into g… (duanjunwen, Mar 18, 2025)
a53d4cd  Merge branch 'grpo-latest' into grpo-latest-npu (duanjunwen, Mar 18, 2025)
7795d4c  [Feature] Support Distributed LogProb for GRPO Training (#6247) (duanjunwen, Mar 18, 2025)
283a479  Merge branch 'hpcaitech:grpo-latest' into grpo-latest (duanjunwen, Mar 19, 2025)
4712ecc  Merge branch 'grpo-latest' of github.com:duanjunwen/ColossalAI into g… (duanjunwen, Mar 19, 2025)
d3fd485  Merge branch 'grpo-latest' into grpo-latest-npu (duanjunwen, Mar 19, 2025)
3f4818c  [feat] support hybrid test (duanjunwen, Mar 19, 2025)
2 changes: 1 addition & 1 deletion .compatibility
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
2.2.2-12.1.0
2.3.0-12.1.0
2.4.0-12.4.1
2.5.1-12.4.1
4 changes: 2 additions & 2 deletions .cuda_ext.json
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
{
"build": [
{
"torch_command": "pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121",
"torch_command": "pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121",
"cuda_image": "hpcaitech/cuda-conda:12.1"
},
{
"torch_command": "pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124",
"torch_command": "pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124",
"cuda_image": "hpcaitech/cuda-conda:12.4"
}
]
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,4 @@ coverage.xml
# log, test files - ColossalChat
applications/ColossalChat/logs
applications/ColossalChat/tests/logs
applications/ColossalChat/wandb
7 changes: 4 additions & 3 deletions applications/ColossalChat/coati/distributed/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,13 @@ def setup(self) -> None:
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)

plugin_config = dict(
tp_size=1,
tp_size=2,
pp_size=1,
precision="bf16",
zero_stage=1,
zero_stage=2,
parallel_output=False,
)
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
if plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
plugin_config["microbatch_size"] = self.microbatch_size
plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config)
Expand Down
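For readers unfamiliar with this setup pattern, the dict above is only a set of defaults: the trailing `plugin_config.update(self.plugin_config)` lets the caller's config win before `HybridParallelPlugin` is built. A minimal sketch of that merge order (the `user_config` values here are hypothetical, standing in for `self.plugin_config`):

```python
# Sketch of the dict-merge order used in setup(); illustrative values only.
defaults = dict(tp_size=2, pp_size=1, precision="bf16", zero_stage=2, parallel_output=False)
user_config = dict(pp_size=2, num_microbatches=4)  # hypothetical caller override

merged = dict(defaults)
merged.update(user_config)  # caller-supplied keys override the defaults
assert merged == dict(tp_size=2, pp_size=2, precision="bf16", zero_stage=2,
                      parallel_output=False, num_microbatches=4)
```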
11 changes: 8 additions & 3 deletions applications/ColossalChat/coati/distributed/grpo_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,20 +114,25 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:

need_update = (step_idx + 1) % self.num_microbatches == 0

ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
# ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
ctx = nullcontext()
with ctx:
policy_model_logits = self.policy_model(
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
action_log_probs = calc_action_log_probs(
policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
)

with torch.no_grad():
reference_model_logits = self.reference_model(
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
reference_action_log_probs = calc_action_log_probs(
reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
)

per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
Expand Down
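The hunk is truncated mid-expression here, but the visible first term matches the k3 KL estimator commonly used in GRPO: `exp(r) - r - 1` with `r = log pi_ref - log pi`, computed per token. A standalone sketch (not the PR's exact code):

```python
import torch

def k3_per_token_kl(ref_log_probs: torch.Tensor, log_probs: torch.Tensor) -> torch.Tensor:
    # k3 estimator of KL(pi || pi_ref): exp(r) - r - 1, with r = log pi_ref - log pi.
    # Convexity of exp makes every per-token estimate non-negative.
    r = ref_log_probs - log_probs
    return torch.exp(r) - r - 1

cur = torch.log_softmax(torch.randn(2, 6), dim=-1)
ref = torch.log_softmax(torch.randn(2, 6), dim=-1)
assert (k3_per_token_kl(ref, cur) >= 0).all()
```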
(file path missing from capture)

```diff
@@ -61,12 +61,22 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
         self.generate_config = generate_config.copy()
         self.generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.tokenizer = tokenizer
+        self.num_generations = 8

     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         micro_batch_size = input_ids.size(0)
         input_ids = input_ids.to(get_current_device())
         attention_mask = attention_mask.to(get_current_device())
-        out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config)
+        gt_answer = None
+        if "gt_answer" in kwargs:
+            gt_answer = kwargs.pop("gt_answer")
+        if self.num_generations > 1:
+            input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
+            attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
+        out = self.model.generate(
+            input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer
+        )
         input_len = input_ids.shape[-1]
         new_token_ids = out.sequences[:, input_len:]
         # get log probs
@@ -76,10 +86,13 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
             action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1]))
         action_log_probs = torch.cat(action_log_probs, dim=1)
         # get action mask
+        response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device())
         action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
         if self.tokenizer.eos_token_id is not None:
             for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id):
                 action_mask[indices[0], indices[1] + 1 :] = 0
+        response_idx[:, 0] = input_len
+        response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1

         if attention_mask.size(0) != action_mask.size(0):
             assert action_mask.size(0) % attention_mask.size(0) == 0
@@ -91,7 +104,15 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
             "attention_mask": attention_mask,
             "action_log_probs": action_log_probs,
             "action_mask": action_mask,
+            "response_idx": response_idx,
         }

+        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
+
+        if gt_answer is not None:
+            # repeat gt_answer for each prompt.
+            data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1)
+        data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data

     def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
@@ -154,7 +175,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(
         logprobs=0,
-        n=4,
+        n=8,
     )

     def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
```
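The interplay of `repeat_interleave` and the final `view` is easy to check in isolation: copies of each prompt stay adjacent after the repeat, so reshaping to `[micro_batch_size, num_generations, seq_len]` regroups completions by their originating prompt. A self-contained sketch:

```python
import torch

micro_batch_size, num_generations, seq_len = 2, 8, 4
input_ids = torch.arange(micro_batch_size * seq_len).view(micro_batch_size, seq_len)

# repeat_interleave keeps the copies of one prompt adjacent:
# [P0, P0, ..., P0, P1, P1, ..., P1]
expanded = input_ids.repeat_interleave(num_generations, dim=0)
assert expanded.shape == (micro_batch_size * num_generations, seq_len)

# ...so a plain view regroups rows by originating prompt.
grouped = expanded.view(micro_batch_size, num_generations, seq_len)
assert torch.equal(grouped[0, 3], input_ids[0])
assert torch.equal(grouped[1, 0], input_ids[1])
```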
11 changes: 3 additions & 8 deletions applications/ColossalChat/coati/distributed/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,10 @@ def forward(
) -> torch.Tensor:
skip = False
if action_mask is None:
ratio_ = (log_probs - old_log_probs).exp()
ratio = (log_probs - log_probs.detach()).exp()
else:
ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
ratio = ((log_probs - log_probs.detach()) * action_mask).exp()

# note that if dropout is disabled (recommanded), ratio will always be 1.
if ratio_.mean() > self.skip_threshold:
skip = True

ratio = ratio_.clamp(0.0, 10.0)
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
Expand All @@ -44,4 +39,4 @@ def forward(
else:
loss = loss.mean(dim=1)
loss = loss.mean()
return loss, skip, ratio_.max()
return loss, skip, ratio.max()
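For context, the surrounding function computes the standard clipped surrogate plus a KL penalty. A minimal standalone version under generic names (a sketch, not the repository's class):

```python
import torch

def clipped_surrogate_loss(log_probs, old_log_probs, advantages, per_token_kl,
                           clip_eps=0.2, beta=0.01):
    # Importance ratio between current and behavior policy, per token.
    ratio = (log_probs - old_log_probs).exp()
    surr1 = ratio * advantages                                    # unclipped surrogate
    surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages  # clipped surrogate
    return -torch.min(surr1, surr2) + beta * per_token_kl

lp = torch.randn(2, 5, requires_grad=True)
loss = clipped_surrogate_loss(lp, lp.detach(), torch.randn(2, 1).expand(2, 5),
                              torch.zeros(2, 5))
assert loss.shape == (2, 5)
```

Note that the PR's substitution of `log_probs.detach()` for `old_log_probs` makes the ratio identically 1 in the forward pass while still carrying the gradient of `log_probs`, which appears to reduce the objective to an unclipped policy-gradient term when each batch is used for a single update.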
6 changes: 5 additions & 1 deletion applications/ColossalChat/coati/distributed/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,11 @@ def __init__(

@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):
return self.model.generate(input_ids, attention_mask, **kwargs)
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
if self.producer_idx == 1:
print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))

return rollouts

def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)
20 changes: 17 additions & 3 deletions applications/ColossalChat/coati/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import torch

from colossalai.shardformer.layer.loss import dist_log_prob


def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
batches = []
Expand Down Expand Up @@ -66,18 +68,30 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
return per_label_logps.squeeze(-1)


def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
def calc_action_log_probs(
logits: torch.Tensor,
sequences: torch.LongTensor,
num_actions: int,
shard_config,
vocab_size: int = None,
) -> torch.Tensor:
"""Calculate action log probs.

Args:
output (torch.Tensor): Output tensor of Actor.forward.logits.
logits (torch.Tensor): Output tensor of Actor.forward.logits.
sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions.
shard_config
vocab_size


Returns:
torch.Tensor: Action log probs.
"""
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
# labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
# logits: torch.Tensor, # [B, S, Vocab_size]
log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
log_probs = log_probs.squeeze(-1)
return log_probs[:, -num_actions:]


Expand Down
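The shift in the replaced line (`logits[:, :-1, :]` against `sequences[:, 1:]`) is the computation that `dist_log_prob` reproduces across tensor-parallel vocabulary shards: the logits at position t score the token at position t+1. A single-device sketch of the same computation:

```python
import torch

def action_log_probs_single_device(logits: torch.Tensor,
                                   sequences: torch.LongTensor,
                                   num_actions: int) -> torch.Tensor:
    # logits[:, t] predicts sequences[:, t + 1], hence the one-position shift.
    log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
    per_token = log_probs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
    return per_token[:, -num_actions:]  # keep only the generated (action) tokens

logits = torch.randn(2, 10, 32)        # [B, S, vocab]
seq = torch.randint(0, 32, (2, 10))    # [B, S]
out = action_log_probs_single_device(logits, seq, num_actions=4)
assert out.shape == (2, 4)
```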
(file path missing from capture)

```diff
@@ -140,7 +140,7 @@ def make_experience(
         num_actions = 0

         for inference_mini_batch_id in range(0, input_ids.size(0), self.inference_batch_size):
-            s, e = inference_mini_batch_id, (inference_mini_batch_id + 1) * self.inference_batch_size
+            s, e = inference_mini_batch_id, inference_mini_batch_id + self.inference_batch_size
             if input_ids[s:e].size(0) == 0:
                 break
             sequences = generate(self.actor, input_ids[s:e], self.tokenizer, **generate_kwargs)
```
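The rebatching fix above is easy to verify: `range(0, N, batch_size)` already advances the start index by the batch size, so the old end index `(start + 1) * batch_size` overshoots from the second mini-batch on. A quick trace:

```python
batch_size, total = 4, 12
for start in range(0, total, batch_size):
    old_end = (start + 1) * batch_size  # buggy: 4, 20, 36
    new_end = start + batch_size        # fixed: 4, 8, 12
    print(f"[{start}:{old_end}] vs [{start}:{new_end}]")
# Buggy slices [4:20] and [8:36] overlap and over-read;
# fixed slices [0:4], [4:8], [8:12] tile the batch exactly.
```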
2 changes: 1 addition & 1 deletion applications/ColossalChat/coati/trainer/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,8 @@ def _criterion(outputs, inputs):
self.accumulative_meter.get("accuracy"),
global_step,
)
self.num_train_step += 1
self.accumulative_meter.reset()
self.num_train_step += 1

if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
# save checkpoint
Expand Down
4 changes: 2 additions & 2 deletions applications/ColossalChat/coati/trainer/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ def _training_step(self, experience: Experience):
experience:
sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
"""
self.num_train_step += 1
self.actor.train()
num_actions = experience.action_log_probs.size(1)
# policy loss
Expand Down Expand Up @@ -294,7 +293,7 @@ def _training_step(self, experience: Experience):
self.temperature_annealing_scheduler.step_forward()

# preparing logging model output and corresponding rewards.
if self.num_train_step % 10 == 1:
if self.num_train_step % 10 == 0:
response_text = self.experience_maker.tokenizer.batch_decode(
experience.sequences, skip_special_tokens=True
)
Expand Down Expand Up @@ -327,6 +326,7 @@ def _training_step(self, experience: Experience):
self.writer.add_scalar("approx_kl", self.accumulative_meter.get("kl"), global_step)
self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), global_step)
self.accumulative_meter.reset()
self.num_train_step += 1

def _learn(self, update_step: int):
"""
Expand Down
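Moving `self.num_train_step += 1` from the top of `_training_step` to the end means the counter holds the number of completed steps at the point where logging runs, which is why the cadence check flips from `% 10 == 1` to `% 10 == 0`. Both variants fire on the same steps, as a toy trace shows (simplified control flow, not the trainer's actual loop):

```python
# Old order: counter incremented on entry, so the first step sees counter == 1
# and the cadence test is `counter % 10 == 1`.
old_logs, counter = [], 0
for step in range(25):
    counter += 1                  # increment at the top
    if counter % 10 == 1:
        old_logs.append(step)     # logs at steps 0, 10, 20

# New order: counter incremented after logging, so it counts *completed* steps
# and the natural test is `counter % 10 == 0`.
new_logs, counter = [], 0
for step in range(25):
    if counter % 10 == 0:
        new_logs.append(step)     # logs at steps 0, 10, 20
    counter += 1                  # increment at the bottom

assert old_logs == new_logs == [0, 10, 20]
```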
2 changes: 1 addition & 1 deletion applications/ColossalChat/coati/trainer/kto.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def _train(self, epoch: int):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
self.num_train_step += 1
self.num_train_step += 1

step_bar.close()

Expand Down
2 changes: 1 addition & 1 deletion applications/ColossalChat/coati/trainer/orpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _train(self, epoch: int):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
self.num_train_step += 1
self.num_train_step += 1

step_bar.close()

Expand Down
4 changes: 2 additions & 2 deletions applications/ColossalChat/coati/trainer/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ def _training_step(self, experience: Experience):
experience:
sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
"""
self.num_train_step += 1
self.actor.train()
self.critic.train()
num_actions = experience.action_log_probs.size(1)
Expand Down Expand Up @@ -294,7 +293,7 @@ def _training_step(self, experience: Experience):
self.critic_scheduler.step()

# preparing logging model output and corresponding rewards.
if self.num_train_step % 10 == 1:
if self.num_train_step % 10 == 0:
response_text = self.experience_maker.tokenizer.batch_decode(
experience.sequences, skip_special_tokens=True
)
Expand Down Expand Up @@ -336,6 +335,7 @@ def _training_step(self, experience: Experience):
self.writer.add_scalar("value", self.accumulative_meter.get("value"), self.num_train_step)
self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), self.num_train_step)
self.accumulative_meter.reset()
self.num_train_step += 1

def _learn(self, update_step: int):
"""
Expand Down
2 changes: 1 addition & 1 deletion applications/ColossalChat/coati/trainer/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _train(self, epoch):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {(i + 1)/self.accumulation_steps} at folder {self.save_dir}"
)
self.num_train_step += 1
self.num_train_step += 1
step_bar.close()

def _eval(self, epoch):
Expand Down
2 changes: 1 addition & 1 deletion applications/ColossalChat/coati/trainer/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ def _train(self, epoch: int):
if self.writer:
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], global_step)
self.num_train_step += 1
self.accumulative_meter.reset()
step_bar.update()
self.num_train_step += 1

# Save checkpoint
if (
Expand Down