feature(whl): add rlhf pipeline. #748

Status: Open · wants to merge 18 commits into base: main
17 changes: 17 additions & 0 deletions ding/bonus/config.py
@@ -167,6 +167,13 @@ def get_instance_config(env_id: str, algorithm: str) -> EasyDict:
cfg.batch_size = 320
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-4
elif env_id == 'chat':
cfg.epoch_per_collect = 1
cfg.batch_size = 1
cfg.learning_rate = 5e-7
cfg.answers_per_question = 3
cfg.kl_penalty_weight = 0.1
cfg.ppo_param_init = False
else:
raise KeyError("not supported env type: {}".format(env_id))
else:
@@ -315,6 +322,16 @@ def get_instance_env(env_id: str) -> BaseEnv:
)
cfg = EasyDict(cfg)
return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg))
elif env_id == 'chat':
from dizoo.chat.env import ChatEnv
return ChatEnv(
batch_size=1,
reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en/recover",
tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en",
data_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/data/ppo_data",
maxlen_prompt=128,
maxlen_res=128,
)
else:
raise KeyError("not supported env type: {}".format(env_id))

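For context, a minimal usage sketch of the new 'chat' entry, assuming it is driven through the PPOF bonus API in the same way as the other supported env ids (the exp_name and step values below are placeholders, and the hardcoded MOSS-RLHF paths above must exist locally):

from ding.bonus import PPOF

# Hypothetical driver script; 'chat' routes through get_instance_config / get_instance_env above.
agent = PPOF(env_id='chat', exp_name='rlhf_chat_demo')
agent.train(step=10000)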
25 changes: 19 additions & 6 deletions ding/bonus/ppof.py
@@ -1,3 +1,4 @@
import copy
from typing import Optional, Union, List
from ditk import logging
from easydict import EasyDict
@@ -9,7 +10,7 @@
import torch
from ding.framework import task, OnlineRLContext
from ding.framework.middleware import interaction_evaluator_ttorch, PPOFStepCollector, multistep_trainer, CkptSaver, \
wandb_online_logger, offline_data_saver, termination_checker, ppof_adv_estimator
wandb_online_logger, offline_data_saver, termination_checker, ppof_adv_estimator, ChatCollector
from ding.envs import BaseEnv, BaseEnvManagerV2, SubprocessEnvManagerV2
from ding.policy import PPOFPolicy, single_env_forward_wrapper_ttorch
from ding.utils import set_pkg_seed
@@ -62,6 +63,8 @@ class PPOF:
'Hopper-v3',
'HalfCheetah-v3',
'Walker2d-v3',
# rlhf
'chat'
]
"""
Overview:
@@ -170,6 +173,8 @@ def __init__(
action_shape = int(action_space.n)
elif isinstance(action_space, (gym.spaces.Tuple, gymnasium.spaces.Tuple)):
action_shape = get_hybrid_shape(action_space)
elif action_space is None:
pass
else:
action_shape = action_space.shape

@@ -191,7 +196,11 @@ def __init__(
popart_head=True,
**self.cfg.model
)
self.policy = PPOFPolicy(self.cfg, model=model)
if self.cfg.chat_data:
orig_model = copy.deepcopy(model)
else:
orig_model = None
self.policy = PPOFPolicy(self.cfg, model=model, orig_model=orig_model)
if policy_state_dict is not None:
self.policy.load_state_dict(policy_state_dict)
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
@@ -246,10 +255,14 @@ def train(
pass

with task.start(ctx=OnlineRLContext()):
task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
task.use(ppof_adv_estimator(self.policy))
if self.policy._cfg.chat_data:
# task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
# task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
task.use(ChatCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
else:
task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
task.use(multistep_trainer(self.policy, log_freq=n_iter_log_show))
task.use(
wandb_online_logger(
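The deep copy of the model above keeps a frozen reference network next to the trained policy, and the 'chat' config sets kl_penalty_weight = 0.1. How PPOFPolicy uses orig_model is not shown in this diff; the sketch below illustrates the common RLHF scheme such a pair enables, namely subtracting a per-token KL estimate from the reward (an assumption, not the confirmed implementation):

import torch
import torch.nn.functional as F


def kl_penalized_reward(reward: torch.Tensor, logits: torch.Tensor, ref_logits: torch.Tensor,
                        actions: torch.Tensor, kl_penalty_weight: float = 0.1) -> torch.Tensor:
    # Log-probabilities of the generated tokens under the trained policy and the frozen reference model.
    logp = F.log_softmax(logits, dim=-1).gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    ref_logp = F.log_softmax(ref_logits, dim=-1).gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    approx_kl = logp - ref_logp  # per-token estimate of KL(policy || reference)
    return reward - kl_penalty_weight * approx_kl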
2 changes: 1 addition & 1 deletion ding/framework/middleware/__init__.py
@@ -1,5 +1,5 @@
from .functional import *
from .collector import StepCollector, EpisodeCollector, PPOFStepCollector
from .collector import StepCollector, EpisodeCollector, PPOFStepCollector, ChatCollector
from .learner import OffPolicyLearner, HERLearner
from .ckpt_handler import CkptSaver
from .distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
61 changes: 61 additions & 0 deletions ding/framework/middleware/collector.py
@@ -1,3 +1,4 @@
import copy
from typing import TYPE_CHECKING
from easydict import EasyDict
import treetensor.torch as ttorch
@@ -190,4 +191,64 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
break


class ChatCollector:
"""
Overview:
The collector for chat (RLHF) data. For each batch of prompts it generates several answers per \
question with the actor model, obtains rewards from the environment, and packs the results into \
training data. Use the `__call__` method to execute the whole collection process.
"""

def __new__(cls, *args, **kwargs):
if task.router.is_active and not task.has_role(task.role.COLLECTOR):
return task.void()
return super(ChatCollector, cls).__new__(cls)

def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll_len: int = 1) -> None:
"""
Arguments:
- seed (:obj:`int`): Random seed.
- policy (:obj:`Policy`): The policy to be collected.
- env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
its derivatives are supported.
- n_sample (:obj:`int`): The target number of samples to collect.
- unroll_len (:obj:`int`): The unroll length of the collected data, default to 1.
"""
self.env = env
self.env.seed(seed)
self.env.launch()
self.env = self.env._envs[0]
self.policy = policy
self.n_sample = n_sample
self.unroll_len = unroll_len

def __call__(self, ctx: "OnlineRLContext") -> None:
"""
Overview:
Generate ``answers_per_question`` responses for the current batch of prompts, step the environment \
to obtain rewards and masks, and store the packed result in ``ctx.train_data``.
Input of ctx:
- env_step (:obj:`int`): The env steps which will increase during collection.
"""
device = self.policy._device

obs = ttorch.as_tensor(self.env.last_batch['text_vec'])
batch_size = obs.shape[0]
obs = obs.to(device)

total_action = [[] for _ in range(batch_size)] # [B, answers_per_question, T]
for _ in range(self.policy._cfg.answers_per_question):
_, inference_output = self.policy._model.actor.generate(obs, **ctx.collect_kwargs)
for i in range(batch_size):
total_action[i].append(copy.deepcopy(inference_output[i]))

mask, resp, rew = self.env.step(total_action)
ctx.env_step += 1
ctx.env_episode += 1

train_data = {}
train_data['obs'] = resp  # [B * answers_per_question, T]
train_data['reward'] = rew  # [B * answers_per_question, ]
train_data['mask'] = mask  # [B * answers_per_question, T]

ctx.train_data = ttorch.as_tensor(train_data)


# TODO battle collector
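For reference, a shape sketch of the data that ChatCollector places in ctx.train_data; the concrete token layout depends on ChatEnv.step, which is not part of this diff, and B, A, T below simply mirror the batch_size, answers_per_question, and maxlen_res settings used in the 'chat' config:

import torch
import treetensor.torch as ttorch

B, A, T = 1, 3, 128  # batch_size, answers_per_question, maxlen_res (illustrative values)
train_data = ttorch.as_tensor({
    'obs': torch.zeros(B * A, T, dtype=torch.long),  # generated responses as token ids
    'reward': torch.zeros(B * A),  # one scalar reward per generated response
    'mask': torch.ones(B * A, T, dtype=torch.long),  # marks valid (non-padding) tokens
})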
2 changes: 1 addition & 1 deletion ding/model/common/__init__.py
@@ -2,4 +2,4 @@
QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \
independent_normal_dist, AttentionPolicyHead, PopArtVHead, EnsembleHead
from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder
from .utils import create_model
from .utils import create_model, top_p_logits
15 changes: 15 additions & 0 deletions ding/model/common/tests/test_utils.py
@@ -0,0 +1,15 @@
import pytest
import torch
from ding.model.common.utils import top_p_logits


@pytest.mark.unittest
class TestUtils:

def test_top_p_logits(self):
test_logit = torch.Tensor([[0., 0.91, 0.05, 0.04], [0.04, 0.46, 0.46, 0.04]])

gt_logit = torch.Tensor([[0., 1., 0., 0.], [0., 0.5, 0.5, 0.]])

pred_logit = top_p_logits(test_logit)
assert torch.sum((gt_logit - pred_logit) ** 2).item() < 1e-8
26 changes: 26 additions & 0 deletions ding/model/common/utils.py
@@ -29,3 +29,29 @@ def create_model(cfg: EasyDict) -> torch.nn.Module:
import_module(cfg.pop('import_names', []))
# here we must use the pop operation to ensure compatibility
return MODEL_REGISTRY.build(cfg.pop("type"), **cfg)


def top_p_logits(logits: torch.Tensor, topp: float = 0.9, filter_value: float = 0, min_topk: int = 1):
"""
Overview:
Filter a distribution using nucleus (top-p) filtering. The output has the same shape as the input, but \
entries outside the top-p nucleus are set to ``filter_value`` and the remaining entries are renormalized.
Arguments:
- logits (:obj:`torch.Tensor`): The input distribution for top-p filtering (non-negative values, \
e.g. softmax outputs).
- topp (:obj:`float`): The top-p value, such as 0.9.
- filter_value (:obj:`float`): The value assigned to masked entries of the output, default to 0.
- min_topk (:obj:`int`): The minimum number of entries to keep, default to 1 (i.e. at least one entry \
is never masked).
Returns:
- cum_logits (:obj:`torch.Tensor`): The output distribution after masking and renormalization.
"""
cum_logits = logits.clone()
if topp > 0:
logits_sorted, inds = torch.sort(logits, dim=-1, descending=True)
mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp
mask[..., :min_topk] = False
# Remove tokens with cumulative top_p above the threshold
mask = torch.zeros_like(mask).to(torch.bool).scatter_(dim=-1, index=inds, src=mask)
cum_logits[mask] = filter_value
cum_logits.div_(cum_logits.sum(dim=-1, keepdim=True))
return cum_logits
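A usage sketch for top_p_logits (not part of the diff): since the function renormalizes its input by dividing by the sum, it expects non-negative values such as softmax outputs rather than raw logits, and the filtered result can be fed directly to torch.multinomial:

import torch
from ding.model.common.utils import top_p_logits

probs = torch.softmax(torch.randn(2, 32000), dim=-1)  # [batch, vocab] probabilities
filtered = top_p_logits(probs, topp=0.9)  # zero out the tail, keep the top-p nucleus
next_token = torch.multinomial(filtered, num_samples=1)  # sample one token id per row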
1 change: 1 addition & 0 deletions ding/model/template/__init__.py
@@ -5,6 +5,7 @@
from .vac import VAC, DREAMERVAC
from .bc import DiscreteBC, ContinuousBC
from .language_transformer import LanguageTransformer
from .lm_vac import LlamaVAC
# algorithm-specific
from .pg import PG
from .ppg import PPG