Enable split qkv for LLama and GPTBigCode #914

Merged 2 commits on Mar 25, 2025
3 changes: 3 additions & 0 deletions vllm/config.py
@@ -1019,6 +1019,7 @@ class CacheConfig:
prefix caching enabled.
enable_prefix_caching: Whether to enable prefix caching.
cpu_offload_gb: Size of the CPU offload buffer in GiB.
split_qkv: Whether to split the fused QKV projection into separate Q, K and V projections.
"""

def compute_hash(self) -> str:
@@ -1051,6 +1052,7 @@ def __init__(
enable_prefix_caching: bool = False,
cpu_offload_gb: float = 0,
calculate_kv_scales: Optional[bool] = None,
split_qkv: bool = False,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
@@ -1062,6 +1064,7 @@ def __init__(
self.enable_prefix_caching = enable_prefix_caching
self.cpu_offload_gb = cpu_offload_gb
self.calculate_kv_scales = calculate_kv_scales
self.split_qkv = split_qkv
self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()
8 changes: 8 additions & 0 deletions vllm/engine/arg_utils.py
@@ -208,6 +208,7 @@ class EngineArgs:
model_impl: str = "auto"

calculate_kv_scales: Optional[bool] = None
split_qkv: bool = False

additional_config: Optional[Dict[str, Any]] = None

@@ -1028,6 +1029,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'be loaded from the model checkpoint if available. '
'Otherwise, the scales will default to 1.0.')

parser.add_argument(
'--split-qkv',
action='store_true',
default=EngineArgs.split_qkv,
help='Whether to split the fused QKV projection into separate q, k and v projections.')

parser.add_argument(
"--additional-config",
type=json.loads,
@@ -1148,6 +1155,7 @@ def create_engine_config(self,
enable_prefix_caching=self.enable_prefix_caching,
cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales,
split_qkv=self.split_qkv,
)
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
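For orientation, here is a minimal sketch of how the new flag could be enabled programmatically on this branch. The model name is a placeholder; the `--split-qkv` CLI flag added above sets the same `EngineArgs.split_qkv` field, which `create_engine_config()` then copies into `CacheConfig` as shown in the hunk above.

```python
from vllm.engine.arg_utils import EngineArgs

# Placeholder model name; split_qkv defaults to False, so the separate
# q/k/v projections are only used when explicitly requested.
engine_args = EngineArgs(model="bigcode/starcoder", split_qkv=True)

# Equivalent from the command line, using the argument added above:
#   --split-qkv
```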
22 changes: 8 additions & 14 deletions vllm/engine/llm_engine.py
@@ -284,32 +284,26 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
usage_context,
extra_kvs={
# Common configuration
"dtype":
str(self.model_config.dtype),
"dtype": str(self.model_config.dtype),
"tensor_parallel_size":
self.parallel_config.tensor_parallel_size,
"block_size":
self.cache_config.block_size,
"block_size": self.cache_config.block_size,
"gpu_memory_utilization":
self.cache_config.gpu_memory_utilization,

# Quantization
"quantization":
self.model_config.quantization,
"kv_cache_dtype":
str(self.cache_config.cache_dtype),
"quantization": self.model_config.quantization,
"kv_cache_dtype": str(self.cache_config.cache_dtype),

# Feature flags
"enable_lora":
bool(self.lora_config),
"enable_prompt_adapter":
bool(self.prompt_adapter_config),
"enable_lora": bool(self.lora_config),
"enable_prompt_adapter": bool(self.prompt_adapter_config),
"enable_prefix_caching":
self.cache_config.enable_prefix_caching,
"enforce_eager":
self.model_config.enforce_eager,
"enforce_eager": self.model_config.enforce_eager,
"disable_custom_all_reduce":
self.parallel_config.disable_custom_all_reduce,
"split_qk_v": self.cache_config.split_qkv,
})

if self.tokenizer:
67 changes: 67 additions & 0 deletions vllm/model_executor/layers/linear.py
@@ -1011,6 +1011,73 @@ def weight_loader(self,
param_data.copy_(loaded_weight)


class SplitQKVParallelLinear(torch.nn.Module):
    """QKV projection layer that keeps Q, K and V as three separate
    column-parallel linear layers instead of a single fused matmul.

    The forward pass returns q, k and v individually, so callers do not
    need to split a fused output tensor afterwards.
    """

def __init__(self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: Optional[int] = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()

self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
if total_num_kv_heads is None:
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1
self.num_kv_head_replicas = divide(tp_size,
self.total_num_kv_heads)
else:
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
self.num_kv_head_replicas = 1

input_size = self.hidden_size
q_size = self.num_heads * self.head_size * tp_size
kv_size = self.num_kv_heads * self.head_size * tp_size

self.q_proj = ColumnParallelLinear(input_size=input_size,
output_size=q_size,
bias=bias,
gather_output=False,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)
self.k_proj = ColumnParallelLinear(input_size=input_size,
output_size=kv_size,
bias=bias,
gather_output=False,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)
self.v_proj = ColumnParallelLinear(input_size=input_size,
output_size=kv_size,
bias=bias,
gather_output=False,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)

def forward(self, input_):
q, output_bias = self.q_proj(input_)
k, _ = self.k_proj(input_)
v, _ = self.v_proj(input_)
return q, k, v, output_bias


class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism.

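As a rough illustration of what `SplitQKVParallelLinear` computes, here is a single-GPU sketch in plain PyTorch (no tensor parallelism, bias, or quantization) with arbitrary example sizes, contrasting three separate projections against the fused-projection-then-split path of `QKVParallelLinear`:

```python
import torch
import torch.nn as nn

hidden_size, head_size, num_heads, num_kv_heads = 1024, 64, 16, 4

# Split path: three independent projections; forward returns q, k, v directly.
q_proj = nn.Linear(hidden_size, num_heads * head_size, bias=False)
k_proj = nn.Linear(hidden_size, num_kv_heads * head_size, bias=False)
v_proj = nn.Linear(hidden_size, num_kv_heads * head_size, bias=False)

x = torch.randn(2, 8, hidden_size)          # (batch, seq_len, hidden)
q, k, v = q_proj(x), k_proj(x), v_proj(x)   # no .split() needed afterwards

# Fused path: one projection followed by a split along the last dimension.
fused = nn.Linear(hidden_size, (num_heads + 2 * num_kv_heads) * head_size,
                  bias=False)
q2, k2, v2 = fused(x).split(
    [num_heads * head_size, num_kv_heads * head_size,
     num_kv_heads * head_size],
    dim=-1)

assert q.shape == q2.shape and k.shape == k2.shape and v.shape == v2.shape
```

The split form trades one larger GEMM for three smaller ones; the PR keeps it behind a flag so the default fused behaviour is unchanged.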
105 changes: 77 additions & 28 deletions vllm/model_executor/models/gpt_bigcode.py
@@ -32,7 +32,8 @@
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
RowParallelLinear,
SplitQKVParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -78,15 +79,25 @@ def __init__(
total_num_kv_heads = total_num_heads
self.num_kv_heads = self.num_heads
self.kv_dim = self.head_dim * self.num_kv_heads
self.c_attn = QKVParallelLinear(
self.hidden_size,
self.head_dim,
total_num_heads,
total_num_kv_heads,
bias=True,
quant_config=quant_config,
)

self.split_qkv = cache_config.split_qkv
if self.split_qkv:
self.c_attn = SplitQKVParallelLinear(
self.hidden_size,
self.head_dim,
total_num_heads,
total_num_kv_heads,
bias=True,
quant_config=quant_config,
)
else:
self.c_attn = QKVParallelLinear(
self.hidden_size,
self.head_dim,
total_num_heads,
total_num_kv_heads,
bias=True,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
@@ -107,14 +118,17 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split(
[
self.hidden_size // self.tensor_model_parallel_world_size,
self.kv_dim, self.kv_dim
],
dim=-1,
)
if self.split_qkv:
q, k, v, _ = self.c_attn(hidden_states)
else:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split(
[
self.hidden_size // self.tensor_model_parallel_world_size,
self.kv_dim, self.kv_dim
],
dim=-1,
)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output, _ = self.c_proj(attn_output)
return attn_output
@@ -280,6 +294,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "")
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
cache_config = vllm_config.cache_config

self.config = config
self.lora_config = lora_config
@@ -302,6 +317,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "")
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
self.split_qkv = cache_config.split_qkv

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
@@ -339,6 +355,13 @@ def sample(

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (new_name, orig_name, shard_id)
(".c_attn.q_proj", ".c_attn", "q"),
(".c_attn.k_proj", ".c_attn", "k"),
(".c_attn.v_proj", ".c_attn", "v"),
]

params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
@@ -350,15 +373,41 @@ def load_weights(self, weights: Iterable[Tuple[str,
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
# TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
weight_loader(param, loaded_weight, 'q')
weight_loader(param, loaded_weight, 'k')
weight_loader(param, loaded_weight, 'v')

if self.split_qkv and ".c_attn" in name:
attn_block = self.transformer.h[
self.transformer.start_layer].attn
weight = {}
weight['q'], weight['k'], weight['v'] = loaded_weight.split(
[
attn_block.num_heads * attn_block.head_dim *
attn_block.tensor_model_parallel_world_size,
attn_block.num_kv_heads * attn_block.head_dim *
attn_block.tensor_model_parallel_world_size,
attn_block.num_kv_heads * attn_block.head_dim *
attn_block.tensor_model_parallel_world_size,
],
dim=0)
for param_name, weight_name, shard_id in stacked_params_mapping:
new_name = name.replace(weight_name, param_name)
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight[shard_id])
loaded_params.add(new_name)
continue
else:
weight_loader(param, loaded_weight)
loaded_params.add(name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

# TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
if ("c_attn.input_scale" in name
or "c_attn.weight_scale" in name):
weight_loader(param, loaded_weight, 'q')
weight_loader(param, loaded_weight, 'k')
weight_loader(param, loaded_weight, 'v')
else:
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
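To make the weight-loading branch above concrete: when `split_qkv` is enabled, the fused `c_attn` weight from the checkpoint is sliced along its output dimension and each shard is routed to `c_attn.q_proj` / `c_attn.k_proj` / `c_attn.v_proj` via `stacked_params_mapping`. Below is a standalone sketch of that split with made-up shapes (GPTBigCode uses multi-query attention, so there is a single KV head):

```python
import torch

# Hypothetical sizes: hidden_size = num_heads * head_dim, tensor-parallel size 1.
hidden_size, head_dim, num_heads, num_kv_heads, tp_size = 2048, 128, 16, 1, 1

q_size = num_heads * head_dim * tp_size
kv_size = num_kv_heads * head_dim * tp_size

# Fused c_attn weight as stored in the checkpoint: (q_size + 2 * kv_size, hidden_size).
c_attn_weight = torch.randn(q_size + 2 * kv_size, hidden_size)

# The same dim-0 split the loader performs before handing each shard
# to its corresponding projection layer.
q_w, k_w, v_w = c_attn_weight.split([q_size, kv_size, kv_size], dim=0)

assert q_w.shape == (q_size, hidden_size)
assert k_w.shape == (kv_size, hidden_size) == v_w.shape
```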