Commit df7416f

da-x and njhill committed
simple_connector.py: more efficient use of GPU memory in send
This makes more efficient use of GPU memory by preallocating the `keys` and `values` tensors and copying into them, rather than using `torch.cat` on sub-tensors, which would otherwise take double the GPU memory.

Signed-off-by: Dan Aloni <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent 0a049c7 commit df7416f

1 file changed (+10, -8 lines)


vllm/distributed/kv_transfer/kv_connector/simple_connector.py (+10, -8)
@@ -208,10 +208,15 @@ def send_kv_caches_and_hidden_states(
 
             current_tokens = input_tokens_tensor[start_pos:end_pos]
 
-            keys, values = [], []
+            # Preallocate the tensors
+            keys = kv_caches[0].new_empty(
+                (end_layer - start_layer, end_pos - start_pos, num_heads,
+                 head_size))
+            values = torch.empty_like(keys)
 
-            for layer_id in range(start_layer, end_layer):
-                kv_cache = kv_caches[layer_id - start_layer]
+            # Copy the relevant parts from kvcache
+            for layer_idx in range(end_layer - start_layer):
+                kv_cache = kv_caches[layer_idx]
 
                 if self.is_deepseek_mla and self.use_mla_opt:
                     key_cache = kv_cache.reshape(-1, num_heads, head_size)
@@ -222,11 +227,8 @@ def send_kv_caches_and_hidden_states(
 
                 current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
 
-                keys.append(key_cache[current_slot_mapping].unsqueeze(0))
-                values.append(value_cache[current_slot_mapping].unsqueeze(0))
-
-            keys = torch.cat(keys, dim=0)
-            values = torch.cat(values, dim=0)
+                keys[layer_idx].copy_(key_cache[current_slot_mapping])
+                values[layer_idx].copy_(value_cache[current_slot_mapping])
 
             self.insert(current_tokens,
                         torch.ones_like(current_tokens,
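
For context, the sketch below (not the actual vLLM code) contrasts the two patterns in the diff: accumulating per-layer slices and `torch.cat`-ing them versus preallocating the destination and copying into it. All names and sizes (`num_layers`, `cache_tokens`, `slot_mapping`, and so on) are illustrative assumptions, not values taken from `simple_connector.py`.

```python
import torch

# Illustrative sizes only (assumptions, not taken from simple_connector.py).
num_layers, cache_tokens, num_tokens, num_heads, head_size = 4, 1024, 128, 8, 64

# One KV cache per layer, flattened to (tokens, heads, head_size).
kv_caches = [torch.randn(cache_tokens, num_heads, head_size)
             for _ in range(num_layers)]
# Cache slots holding the current sequence's tokens.
slot_mapping = torch.randint(0, cache_tokens, (num_tokens,))

# Old pattern: gather each layer into a Python list, then torch.cat.
# All gathered slices stay alive until the cat, and the concatenated output
# is a fresh allocation, so peak memory is roughly double the final tensor.
keys_cat = torch.cat(
    [cache[slot_mapping].unsqueeze(0) for cache in kv_caches], dim=0)

# New pattern: preallocate the destination once and copy each layer in place.
# Only one temporary per-layer slice is alive at any given time.
keys = kv_caches[0].new_empty((num_layers, num_tokens, num_heads, head_size))
for layer_idx, cache in enumerate(kv_caches):
    keys[layer_idx].copy_(cache[slot_mapping])

assert torch.equal(keys_cat, keys)  # same result, lower peak memory
```

In the connector itself, `values` is handled the same way, with the destination created via `torch.empty_like(keys)`.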
