
Commit

update
Jintao-Huang committed Feb 11, 2025
1 parent 3df3cf6 commit 76a098c
Showing 7 changed files with 125 additions and 120 deletions.
18 changes: 0 additions & 18 deletions swift/llm/argument/base_args/model_args.py
@@ -48,24 +48,6 @@ class ModelArguments:
# this parameter specifies the path to the locally downloaded repository.
local_repo_path: Optional[str] = None

@staticmethod
def parse_to_dict(value: Union[str, Dict, None], strict: bool = True) -> Union[str, Dict]:
"""Convert a JSON string or JSON file into a dict"""
# If the value could potentially be a string, it is generally advisable to set strict to False.
if value is None:
value = {}
elif isinstance(value, str):
if os.path.exists(value): # local path
with open(value, 'r', encoding='utf-8') as f:
value = json.load(f)
else: # json str
try:
value = json.loads(value)
except json.JSONDecodeError:
if strict:
logger.error(f"Unable to parse string: '{value}'")
raise
return value

def _init_device_map(self):
"""Prepare device map args"""
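The helper removed above is not dropped: later in this commit it reappears as a module-level function in swift/utils/argumens.py. A minimal sketch of the call-site change, assuming the swift.utils export shown further down (illustrative only, not part of the diff):

# Before this commit the helper was a staticmethod:
#     value = ModelArguments.parse_to_dict('{"image": 10}')
# After this commit it is a plain utility function:
from swift.utils import parse_to_dict

value = parse_to_dict('{"image": 10}')  # -> {'image': 10}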
82 changes: 0 additions & 82 deletions swift/llm/argument/infer_args.py
@@ -16,88 +16,6 @@
logger = get_logger()


@dataclass
class LmdeployArguments:
"""
LmdeployArguments is a dataclass that holds the configuration for lmdeploy.
Args:
tp (int): Tensor parallelism size. Default is 1.
session_len(Optional[int]): The session length, default None.
cache_max_entry_count (float): Maximum entry count for cache. Default is 0.8.
quant_policy (int): Quantization policy, e.g., 4, 8. Default is 0.
vision_batch_size (int): Maximum batch size in VisionConfig. Default is 1.
"""

# lmdeploy
tp: int = 1
session_len: Optional[int] = None
cache_max_entry_count: float = 0.8
quant_policy: int = 0 # e.g. 4, 8
vision_batch_size: int = 1 # max_batch_size in VisionConfig

def get_lmdeploy_engine_kwargs(self):
return {
'tp': self.tp,
'session_len': self.session_len,
'cache_max_entry_count': self.cache_max_entry_count,
'quant_policy': self.quant_policy,
'vision_batch_size': self.vision_batch_size
}


@dataclass
class VllmArguments:
"""
VllmArguments is a dataclass that holds the configuration for vllm.
Args:
gpu_memory_utilization (float): GPU memory utilization. Default is 0.9.
tensor_parallel_size (int): Tensor parallelism size. Default is 1.
pipeline_parallel_size(int): Pipeline parallelism size. Default is 1.
max_num_seqs (int): Maximum number of sequences. Default is 256.
max_model_len (Optional[int]): Maximum model length. Default is None.
disable_custom_all_reduce (bool): Flag to disable custom all-reduce. Default is False.
enforce_eager (bool): Flag to enforce eager execution. Default is False.
limit_mm_per_prompt (Optional[str]): Limit multimedia per prompt. Default is None.
vllm_max_lora_rank (int): Maximum LoRA rank. Default is 16.
enable_prefix_caching (bool): Flag to enable automatic prefix caching. Default is False.
"""
# vllm
gpu_memory_utilization: float = 0.9
tensor_parallel_size: int = 1
pipeline_parallel_size: int = 1
max_num_seqs: int = 256
max_model_len: Optional[int] = None
disable_custom_all_reduce: bool = False
enforce_eager: bool = False
limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 10, "video": 5}'
vllm_max_lora_rank: int = 16
enable_prefix_caching: bool = False

def __post_init__(self):
self.limit_mm_per_prompt = ModelArguments.parse_to_dict(self.limit_mm_per_prompt)

def get_vllm_engine_kwargs(self):
adapters = self.adapters
if hasattr(self, 'adapter_mapping'):
adapters = adapters + list(self.adapter_mapping.values())
return {
'gpu_memory_utilization': self.gpu_memory_utilization,
'tensor_parallel_size': self.tensor_parallel_size,
'pipeline_parallel_size': self.pipeline_parallel_size,
'max_num_seqs': self.max_num_seqs,
'max_model_len': self.max_model_len,
'disable_custom_all_reduce': self.disable_custom_all_reduce,
'enforce_eager': self.enforce_eager,
'limit_mm_per_prompt': self.limit_mm_per_prompt,
'max_lora_rank': self.vllm_max_lora_rank,
'enable_lora': len(adapters) > 0,
'max_loras': max(len(adapters), 1),
'enable_prefix_caching': self.enable_prefix_caching,
}


@dataclass
class InferArguments(MergeArguments, VllmArguments, LmdeployArguments, BaseArguments):
"""
31 changes: 17 additions & 14 deletions swift/llm/argument/rlhf_args.py
@@ -4,20 +4,22 @@
from typing import List, Literal, Optional

from swift.llm import MODEL_MAPPING
from swift.utils import get_logger
from swift.utils import get_logger, VllmArguments
from .train_args import TrainArguments

logger = get_logger()


@dataclass
class PPOArguments:
class RewardModelArguments:
reward_model: Optional[str] = None
reward_adapters: List[str] = field(default_factory=list)
reward_model_type: Optional[str] = field(
default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
reward_model_revision: Optional[str] = None

@dataclass
class PPOArguments:
num_ppo_epochs: int = 4
whiten_rewards: bool = False
kl_coef: float = 0.05
@@ -31,12 +33,20 @@ class PPOArguments:
local_rollout_forward_batch_size: int = 64
num_sample_generations: int = 10
response_length: int = 512
temperature: float = 0.7
missing_eos_penalty: Optional[float] = None


@dataclass
class RLHFArguments(PPOArguments, TrainArguments):
class GRPOArguments(VllmArguments):
num_generations: int = 8 # G in the GRPO paper
max_completion_length: int = 512
reward_funcs: List[str] = field(default_factory=list)
# vLLM in GRPO
use_vllm: bool = False
vllm_device: Optional[str] = 'auto' # 'cuda:1'

@dataclass
class RLHFArguments(PPOArguments, RewardModelArguments, TrainArguments):
"""
RLHFArguments is a dataclass that holds arguments specific to the Reinforcement
Learning with Human Feedback (RLHF) training backend.
@@ -62,6 +72,7 @@ class RLHFArguments(PPOArguments, TrainArguments):

beta: Optional[float] = None
label_smoothing: float = 0
loss_scale: Optional[str] = None # 'last_round'
# DPO
rpo_alpha: float = 1.
# CPO
@@ -71,16 +82,8 @@
# KTO
desirable_weight: float = 1.0
undesirable_weight: float = 1.0
# GRPO
num_generations: int = 8 # G in the GRPO paper
max_completion_length: int = 512
reward_funcs: List[str] = field(default_factory=list)
# vLLM in GRPO
use_vllm: bool = False
vllm_device: Optional[str] = 'auto' # 'cuda:1'
vllm_gpu_memory_utilization: float = 0.9
vllm_max_model_len: Optional[int] = None
loss_scale: Optional[str] = None
# PPO/GRPO
temperature: float = 0.7

def __post_init__(self):
self._init_grpo()
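The reward-model, PPO, and GRPO fields are now grouped into separate mixins that RLHFArguments stacks together. A toy sketch of how this dataclass-mixin pattern composes fields into one arguments class (simplified stand-ins with a placeholder model name, not the actual ms-swift classes):

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class RewardModelArguments:
    reward_model: Optional[str] = None
    reward_adapters: List[str] = field(default_factory=list)

@dataclass
class PPOArguments:
    num_ppo_epochs: int = 4
    temperature: float = 0.7

@dataclass
class RLHFArguments(PPOArguments, RewardModelArguments):
    beta: Optional[float] = None

# Fields from every mixin end up on the combined class.
args = RLHFArguments(reward_model='my-reward-model', num_ppo_epochs=2)
print(args.temperature)  # 0.7, inherited from PPOArguments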
6 changes: 3 additions & 3 deletions swift/llm/argument/train_args.py
@@ -10,7 +10,7 @@
from swift.plugin import LOSS_MAPPING
from swift.trainers import TrainerFactory
from swift.utils import (add_version_to_work_dir, get_logger, get_pai_tensorboard_dir, is_liger_available,
is_local_master, is_mp, is_pai_training_job, use_torchacc)
is_local_master, is_mp, is_pai_training_job, use_torchacc, parse_to_dict)
from .base_args import BaseArguments, to_abspath
from .tuner_args import TunerArguments

@@ -65,9 +65,9 @@ def __post_init__(self):
else:
self.learning_rate = 1e-4
if self.lr_scheduler_kwargs:
self.lr_scheduler_kwargs = self.parse_to_dict(self.lr_scheduler_kwargs)
self.lr_scheduler_kwargs = parse_to_dict(self.lr_scheduler_kwargs)
if getattr(self, 'gradient_checkpointing_kwargs', None):
self.gradient_checkpointing_kwargs = self.parse_to_dict(self.gradient_checkpointing_kwargs)
self.gradient_checkpointing_kwargs = parse_to_dict(self.gradient_checkpointing_kwargs)
self._init_eval_strategy()


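Both kwargs above now flow through the relocated helper. A minimal usage sketch of the input forms parse_to_dict accepts (the values below are made-up examples):

from swift.utils import parse_to_dict

parse_to_dict('{"min_lr": 1e-6}')        # JSON string -> {'min_lr': 1e-06}
parse_to_dict(None)                      # -> {}
parse_to_dict({'use_reentrant': False})  # a dict passes through unchanged
# parse_to_dict('lr_kwargs.json')        # an existing .json file would be loaded from disk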
4 changes: 2 additions & 2 deletions swift/trainers/rlhf_arguments.py
@@ -8,7 +8,7 @@
from trl import PPOConfig as HfPPOConfig
from trl import RewardConfig as HfRewardConfig

from .arguments import SwiftArgumentsMixin
from .arguments import SwiftArgumentsMixin, VllmArguments


@dataclass
@@ -42,5 +42,5 @@ class PPOConfig(SwiftArgumentsMixin, HfPPOConfig):


@dataclass
class GRPOConfig(SwiftArgumentsMixin, HfGRPOConfig):
class GRPOConfig(VllmArguments, SwiftArgumentsMixin, HfGRPOConfig):
pass
3 changes: 2 additions & 1 deletion swift/utils/__init__.py
@@ -12,5 +12,6 @@
from .torch_utils import (Serializer, activate_parameters, find_all_linears, find_embedding, find_norm,
freeze_parameters, get_model_parameter_info, safe_ddp_context, show_layers, time_synchronize)
from .utils import (add_version_to_work_dir, check_json_format, deep_getattr, find_free_port, get_env_args, lower_bound,
parse_args, patch_getattr, read_multi_line, seed_everything, split_list, subprocess_run, test_time,
patch_getattr, read_multi_line, seed_everything, split_list, subprocess_run, test_time,
upper_bound)
from .argumens import parse_to_dict, parse_args
101 changes: 101 additions & 0 deletions swift/utils/argumens.py
@@ -0,0 +1,101 @@
import json
import os
from dataclasses import dataclass
from typing import Dict, Optional, Union

from .logger import get_logger

logger = get_logger()


@dataclass
class VllmArguments:
"""
VllmArguments is a dataclass that holds the configuration for vllm.
Args:
gpu_memory_utilization (float): GPU memory utilization. Default is 0.9.
tensor_parallel_size (int): Tensor parallelism size. Default is 1.
pipeline_parallel_size (int): Pipeline parallelism size. Default is 1.
max_num_seqs (int): Maximum number of sequences. Default is 256.
max_model_len (Optional[int]): Maximum model length. Default is None.
disable_custom_all_reduce (bool): Flag to disable custom all-reduce. Default is False.
enforce_eager (bool): Flag to enforce eager execution. Default is False.
limit_mm_per_prompt (Optional[Union[dict, str]]): Limit multimedia per prompt. Default is None.
vllm_max_lora_rank (int): Maximum LoRA rank. Default is 16.
enable_prefix_caching (bool): Flag to enable automatic prefix caching. Default is False.
"""
# vllm
gpu_memory_utilization: float = 0.9
tensor_parallel_size: int = 1
pipeline_parallel_size: int = 1
max_num_seqs: int = 256
max_model_len: Optional[int] = None
disable_custom_all_reduce: bool = False
enforce_eager: bool = False
limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 10, "video": 5}'
vllm_max_lora_rank: int = 16
enable_prefix_caching: bool = False

def __post_init__(self):
self.limit_mm_per_prompt = parse_to_dict(self.limit_mm_per_prompt)

def get_vllm_engine_kwargs(self):
adapters = self.adapters
if hasattr(self, 'adapter_mapping'):
adapters = adapters + list(self.adapter_mapping.values())
return {
'gpu_memory_utilization': self.gpu_memory_utilization,
'tensor_parallel_size': self.tensor_parallel_size,
'pipeline_parallel_size': self.pipeline_parallel_size,
'max_num_seqs': self.max_num_seqs,
'max_model_len': self.max_model_len,
'disable_custom_all_reduce': self.disable_custom_all_reduce,
'enforce_eager': self.enforce_eager,
'limit_mm_per_prompt': self.limit_mm_per_prompt,
'max_lora_rank': self.vllm_max_lora_rank,
'enable_lora': len(adapters) > 0,
'max_loras': max(len(adapters), 1),
'enable_prefix_caching': self.enable_prefix_caching,
}


@dataclass
class LmdeployArguments:
"""
LmdeployArguments is a dataclass that holds the configuration for lmdeploy.
Args:
tp (int): Tensor parallelism size. Default is 1.
session_len (Optional[int]): The session length. Default is None.
cache_max_entry_count (float): Maximum entry count for cache. Default is 0.8.
quant_policy (int): Quantization policy, e.g., 4, 8. Default is 0.
vision_batch_size (int): Maximum batch size in VisionConfig. Default is 1.
"""

# lmdeploy
tp: int = 1
session_len: Optional[int] = None
cache_max_entry_count: float = 0.8
quant_policy: int = 0 # e.g. 4, 8
vision_batch_size: int = 1 # max_batch_size in VisionConfig

def get_lmdeploy_engine_kwargs(self):
return {
'tp': self.tp,
'session_len': self.session_len,
'cache_max_entry_count': self.cache_max_entry_count,
'quant_policy': self.quant_policy,
'vision_batch_size': self.vision_batch_size
}


def parse_to_dict(value: Union[str, Dict, None], strict: bool = True) -> Union[str, Dict]:
"""Convert a JSON string or JSON file into a dict"""
# If the value could potentially be a string, it is generally advisable to set strict to False.
if value is None:
value = {}
elif isinstance(value, str):
if os.path.exists(value): # local path
with open(value, 'r', encoding='utf-8') as f:
value = json.load(f)
else: # json str
try:
value = json.loads(value)
except json.JSONDecodeError:
if strict:
logger.error(f"Unable to parse string: '{value}'")
raise
return value
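The moved dataclasses keep their public helpers. A minimal sketch of how the lmdeploy options end up as engine kwargs, assuming the import path of the module added in this commit:

from swift.utils.argumens import LmdeployArguments

args = LmdeployArguments(tp=2, quant_policy=4)
print(args.get_lmdeploy_engine_kwargs())
# {'tp': 2, 'session_len': None, 'cache_max_entry_count': 0.8,
#  'quant_policy': 4, 'vision_batch_size': 1}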
