
Commit 1fc973c

[V1][Core] Fix memory issue with logits & sampling (vllm-project#14508)
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent c982ac5 commit 1fc973c

5 files changed, 139 insertions(+), 91 deletions(-)

tests/basic_correctness/test_cumem.py

+10 -1
@@ -142,7 +142,16 @@ def test_end_to_end(model: str, use_v1: bool):
     used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
     # now the memory usage is mostly cudagraph memory pool,
     # and it should be less than the model weights (1B model, 2GiB weights)
-    assert used_bytes < 2 * GiB_bytes
+
+    # NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
+    # is captured but cannot be released from PyTorch due to a known bug,
+    # therefore high memory usage after `llm.sleep` is called is expected.
+    # FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
+    # in V1.
+    if use_v1:
+        assert used_bytes < 7 * GiB_bytes
+    else:
+        assert used_bytes < 2 * GiB_bytes

     llm.wake_up()
     output2 = llm.generate(prompt, sampling_params)
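For a rough sense of scale behind the NOTE above: the retained logits buffer is on the order of max_num_reqs x vocab_size x 4 bytes in fp32. A minimal standalone sketch, using illustrative numbers that are not taken from the test or from any particular model config:

GiB = 1 << 30

# Hypothetical values, for illustration only.
max_num_reqs = 1024      # e.g. max_num_seqs
vocab_size = 128 * 1024  # e.g. a large tokenizer vocabulary
bytes_per_element = 4    # float32 logits

logits_buffer_bytes = max_num_reqs * vocab_size * bytes_per_element
print(f"logits buffer ~ {logits_buffer_bytes / GiB:.2f} GiB")  # ~0.50 GiB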

vllm/config.py

+5
@@ -3525,6 +3525,11 @@ def _set_cudagraph_sizes(self):
             not self.model_config.enforce_eager:
             batch_size_capture_list = [1, 2, 4
                                        ] + [i for i in range(8, 513, 8)]
+            max_num_tokens = self.scheduler_config.max_num_batched_tokens
+            batch_size_capture_list = [
+                size for size in batch_size_capture_list
+                if size <= max_num_tokens
+            ]

         self.compilation_config.init_with_cudagraph_sizes(
             batch_size_capture_list)
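The added filter only drops cudagraph capture sizes that exceed max_num_batched_tokens; everything else is unchanged. A minimal standalone sketch of the same list comprehension, with 256 as an example value for max_num_tokens:

# Sketch of the capture-size filtering added above (example value only).
batch_size_capture_list = [1, 2, 4] + [i for i in range(8, 513, 8)]

max_num_tokens = 256  # example: scheduler_config.max_num_batched_tokens
batch_size_capture_list = [
    size for size in batch_size_capture_list if size <= max_num_tokens
]

print(max(batch_size_capture_list))  # 256: sizes above the token budget are dropped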

vllm/v1/worker/gpu_model_runner.py

+98 -87
@@ -1202,41 +1202,98 @@ def _dummy_run(
         self,
         num_tokens: int,
     ) -> torch.Tensor:
-        model = self.model
-        if self.is_multimodal_model:
-            input_ids = None
-            inputs_embeds = self.inputs_embeds[:num_tokens]
-        else:
-            input_ids = self.input_ids[:num_tokens]
-            inputs_embeds = None
-        if self.uses_mrope:
-            positions = self.mrope_positions[:, :num_tokens]
-        else:
-            positions = self.positions[:num_tokens]

-        if get_pp_group().is_first_rank:
-            intermediate_tensors = None
-        else:
-            if self.intermediate_tensors is None:
-                self.intermediate_tensors = (
-                    self.model.make_empty_intermediate_tensors(
-                        batch_size=self.max_num_tokens,
-                        dtype=self.model_config.dtype,
-                        device=self.device))
-            intermediate_tensors = IntermediateTensors({
-                k: v[:num_tokens]
-                for k, v in self.intermediate_tensors.items()
-            })
+        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
+        # for dummy run with LoRA so that the num_reqs collectively
+        # has num_tokens in total.
+        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
+        max_num_reqs = self.scheduler_config.max_num_seqs
+        num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
+        min_tokens_per_req = num_tokens // num_reqs
+        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
+        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
+        assert sum(num_scheduled_tokens_list) == num_tokens
+        assert len(num_scheduled_tokens_list) == num_reqs
+        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
+                                        dtype=np.int32)

-        with set_forward_context(None, self.vllm_config,
-                                 num_tokens=num_tokens):
-            hidden_states = model(
-                input_ids=input_ids,
-                positions=positions,
-                intermediate_tensors=intermediate_tensors,
-                inputs_embeds=inputs_embeds,
-            )
-        return hidden_states
+        with self.maybe_dummy_run_with_lora(self.lora_config,
+                                            num_scheduled_tokens):
+            model = self.model
+            if self.is_multimodal_model:
+                input_ids = None
+                inputs_embeds = self.inputs_embeds[:num_tokens]
+            else:
+                input_ids = self.input_ids[:num_tokens]
+                inputs_embeds = None
+            if self.uses_mrope:
+                positions = self.mrope_positions[:, :num_tokens]
+            else:
+                positions = self.positions[:num_tokens]
+
+            if get_pp_group().is_first_rank:
+                intermediate_tensors = None
+            else:
+                if self.intermediate_tensors is None:
+                    self.intermediate_tensors = (
+                        self.model.make_empty_intermediate_tensors(
+                            batch_size=self.max_num_tokens,
+                            dtype=self.model_config.dtype,
+                            device=self.device))
+                intermediate_tensors = IntermediateTensors({
+                    k: v[:num_tokens]
+                    for k, v in self.intermediate_tensors.items()
+                })
+
+            with set_forward_context(None,
+                                     self.vllm_config,
+                                     num_tokens=num_tokens):
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds,
+                )
+
+        logit_indices = np.cumsum(num_scheduled_tokens) - 1
+        return hidden_states[logit_indices]
+
+    @torch.inference_mode()
+    def _dummy_sampler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+
+        logits = self.model.compute_logits(hidden_states, None)
+        num_reqs = logits.size(0)
+
+        dummy_tensors = lambda v: torch.full(
+            (num_reqs, ), v, device=self.device)
+
+        dummy_metadata = SamplingMetadata(
+            temperature=dummy_tensors(0.5),
+            all_greedy=False,
+            all_random=False,
+            top_p=dummy_tensors(0.9),
+            top_k=dummy_tensors(logits.size(1) - 1),
+            min_p=None,
+            generators={},
+            max_num_logprobs=None,
+            no_penalties=True,
+            prompt_token_ids=None,
+            frequency_penalties=dummy_tensors(0.1),
+            presence_penalties=dummy_tensors(0.1),
+            repetition_penalties=dummy_tensors(0.1),
+            output_token_ids=[[] for _ in range(num_reqs)],
+            min_tokens={},
+            logit_bias=[None for _ in range(num_reqs)],
+            allowed_token_ids_mask=None,
+            bad_words_token_ids={},
+        )
+        sampler_output = self.model.sample(logits=logits,
+                                           sampling_metadata=dummy_metadata)
+
+        return sampler_output

     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
@@ -1332,60 +1389,14 @@ def profile_run(self) -> None:
         # Cache the dummy encoder outputs.
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

-        # For profile, have maximum num_reqs and that collectively have
-        # maximum num_tokens.
-        num_reqs = self.scheduler_config.max_num_seqs
-        num_tokens = self.max_num_tokens
-        min_tokens_per_req = num_tokens // num_reqs
-
-        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
-        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
-        assert sum(num_scheduled_tokens_list) == num_tokens
-        assert len(num_scheduled_tokens_list) == num_reqs
-
-        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
-                                        dtype=np.int32)
-        logit_indices = np.cumsum(num_scheduled_tokens) - 1
-
-        with self.maybe_profile_with_lora(self.lora_config,
-                                          num_scheduled_tokens):
-            # Trigger compilation for general shape.
-            hidden_states = self._dummy_run(self.max_num_tokens)
-            if get_pp_group().is_last_rank:
-                hidden_states = hidden_states[logit_indices]
-                logits = self.model.compute_logits(hidden_states, None)
-                dummy_tensors = lambda v: torch.full(
-                    (num_reqs, ), v, device=self.device)
-                dummy_metadata = SamplingMetadata(
-                    temperature=dummy_tensors(0.5),
-                    all_greedy=False,
-                    all_random=False,
-                    top_p=dummy_tensors(0.9),
-                    top_k=dummy_tensors(logits.size(1) - 1),
-                    min_p=None,
-                    generators={},
-                    max_num_logprobs=None,
-                    no_penalties=True,
-                    prompt_token_ids=torch.ones_like(logits,
-                                                     dtype=torch.int64),
-                    frequency_penalties=dummy_tensors(0.1),
-                    presence_penalties=dummy_tensors(0.1),
-                    repetition_penalties=dummy_tensors(0.1),
-                    output_token_ids=[[] for _ in range(num_reqs)],
-                    min_tokens={},
-                    logit_bias=[None for _ in range(num_reqs)],
-                    allowed_token_ids_mask=None,
-                    bad_words_token_ids={},
-                )
-                sampler_output = self.model.sample(
-                    logits=logits, sampling_metadata=dummy_metadata)
-            else:
-                logits = None
-                sampler_output = None
-                dummy_metadata = None
-            torch.cuda.synchronize()
-            del hidden_states, logits, sampler_output, dummy_metadata
-            self.encoder_cache.clear()
+        hidden_states = self._dummy_run(self.max_num_tokens)
+        if get_pp_group().is_last_rank:
+            sampler_output = self._dummy_sampler_run(hidden_states)
+        else:
+            sampler_output = None
+        torch.cuda.synchronize()
+        del hidden_states, sampler_output
+        self.encoder_cache.clear()
         gc.collect()

     def capture_model(self) -> None:
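The token-splitting logic that `_dummy_run` gains above spreads `num_tokens` across at most `max_num_seqs` dummy requests, then uses a cumulative sum to pick out the last-token position of each request, so only those rows feed `compute_logits` and the dummy sampler. A small worked sketch with made-up numbers:

import numpy as np

# Made-up example: 10 tokens spread over at most 4 dummy requests.
num_tokens = 10
max_num_reqs = 4  # stands in for scheduler_config.max_num_seqs

num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs

num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
print(num_scheduled_tokens)  # [2 2 2 4]

# Index of each request's last token within the flat token batch.
logit_indices = np.cumsum(num_scheduled_tokens) - 1
print(logit_indices)  # [1 3 5 9]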

vllm/v1/worker/gpu_worker.py

+23
@@ -119,6 +119,8 @@ def init_device(self):
         self.model_runner: GPUModelRunner = GPUModelRunner(
             self.vllm_config, self.device)

+    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
+    # to hijack tensor allocation.
     def load_model(self) -> None:
         if self.vllm_config.model_config.enable_sleep_mode:
             allocator = CuMemAllocator.get_instance()
@@ -211,6 +213,27 @@ def compile_or_warm_up_model(self) -> None:
             self.model_runner._dummy_run(size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
+
+        # Warm up sampler and preallocate memory buffer for logits and other
+        # sampling related tensors of max possible shape to avoid memory
+        # fragmentation issue.
+        # NOTE: This is called after `capture_model` on purpose to prevent
+        # memory buffers from being cleared by `torch.cuda.empty_cache`.
+        try:
+            max_num_reqs = min(self.scheduler_config.max_num_seqs,
+                               self.scheduler_config.max_num_batched_tokens)
+            self.model_runner._dummy_sampler_run(
+                hidden_states=self.model_runner._dummy_run(
+                    num_tokens=max_num_reqs))
+        except RuntimeError as e:
+            if 'out of memory' in str(e):
+                raise RuntimeError(
+                    "CUDA out of memory occurred when warming up sampler. "
+                    "Please try lowering `gpu_memory_utilization` when "
+                    "initializing the engine.") from None
+            else:
+                raise e
+
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)
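The warm-up above is deliberately placed after `capture_model`, and an OOM raised while warming up the sampler is rewritten into an actionable message. A minimal standalone sketch of that error-wrapping pattern; `run_sampler_warmup` below is a hypothetical stand-in for the dummy run plus dummy sampler run, not a vLLM API:

def run_sampler_warmup() -> None:
    # Stand-in that simulates a CUDA OOM during warm-up.
    raise RuntimeError("CUDA out of memory. Tried to allocate ...")

def warm_up_sampler() -> None:
    try:
        run_sampler_warmup()
    except RuntimeError as e:
        if 'out of memory' in str(e):
            # Re-raise with actionable guidance; drop the original traceback.
            raise RuntimeError(
                "CUDA out of memory occurred when warming up sampler. "
                "Please try lowering `gpu_memory_utilization` when "
                "initializing the engine.") from None
        raise

if __name__ == "__main__":
    try:
        warm_up_sampler()
    except RuntimeError as e:
        print(e)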

vllm/v1/worker/lora_model_runner_mixin.py

+3 -3
@@ -83,8 +83,8 @@ def set_active_loras(self, input_batch: InputBatch,
                                lora_requests)

     @contextmanager
-    def maybe_profile_with_lora(self, lora_config: LoRAConfig,
-                                num_scheduled_tokens: np.ndarray):
+    def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
+                                  num_scheduled_tokens: np.ndarray):
         if lora_config is None:
             yield
         else:
@@ -145,4 +145,4 @@ def pin_lora(self, lora_id: int) -> bool:
     def list_loras(self) -> set[int]:
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")
-        return self.lora_manager.list_adapters()
\ No newline at end of file
+        return self.lora_manager.list_adapters()
