@@ -1202,41 +1202,98 @@ def _dummy_run(
         self,
         num_tokens: int,
     ) -> torch.Tensor:
-        model = self.model
-        if self.is_multimodal_model:
-            input_ids = None
-            inputs_embeds = self.inputs_embeds[:num_tokens]
-        else:
-            input_ids = self.input_ids[:num_tokens]
-            inputs_embeds = None
-        if self.uses_mrope:
-            positions = self.mrope_positions[:, :num_tokens]
-        else:
-            positions = self.positions[:num_tokens]
 
-        if get_pp_group().is_first_rank:
-            intermediate_tensors = None
-        else:
-            if self.intermediate_tensors is None:
-                self.intermediate_tensors = (
-                    self.model.make_empty_intermediate_tensors(
-                        batch_size=self.max_num_tokens,
-                        dtype=self.model_config.dtype,
-                        device=self.device))
-            intermediate_tensors = IntermediateTensors({
-                k: v[:num_tokens]
-                for k, v in self.intermediate_tensors.items()
-            })
+        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
+        # for dummy run with LoRA so that the num_reqs collectively
+        # has num_tokens in total.
+        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
+        max_num_reqs = self.scheduler_config.max_num_seqs
+        num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
+        min_tokens_per_req = num_tokens // num_reqs
+        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
+        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
+        assert sum(num_scheduled_tokens_list) == num_tokens
+        assert len(num_scheduled_tokens_list) == num_reqs
+        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
+                                        dtype=np.int32)
 
-        with set_forward_context(None, self.vllm_config,
-                                 num_tokens=num_tokens):
-            hidden_states = model(
-                input_ids=input_ids,
-                positions=positions,
-                intermediate_tensors=intermediate_tensors,
-                inputs_embeds=inputs_embeds,
-            )
-        return hidden_states
+        with self.maybe_dummy_run_with_lora(self.lora_config,
+                                            num_scheduled_tokens):
+            model = self.model
+            if self.is_multimodal_model:
+                input_ids = None
+                inputs_embeds = self.inputs_embeds[:num_tokens]
+            else:
+                input_ids = self.input_ids[:num_tokens]
+                inputs_embeds = None
+            if self.uses_mrope:
+                positions = self.mrope_positions[:, :num_tokens]
+            else:
+                positions = self.positions[:num_tokens]
+
+            if get_pp_group().is_first_rank:
+                intermediate_tensors = None
+            else:
+                if self.intermediate_tensors is None:
+                    self.intermediate_tensors = (
+                        self.model.make_empty_intermediate_tensors(
+                            batch_size=self.max_num_tokens,
+                            dtype=self.model_config.dtype,
+                            device=self.device))
+                intermediate_tensors = IntermediateTensors({
+                    k: v[:num_tokens]
+                    for k, v in self.intermediate_tensors.items()
+                })
+
+            with set_forward_context(None,
+                                     self.vllm_config,
+                                     num_tokens=num_tokens):
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds,
+                )
+
+            logit_indices = np.cumsum(num_scheduled_tokens) - 1
+            return hidden_states[logit_indices]
+
+    @torch.inference_mode()
+    def _dummy_sampler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+
+        logits = self.model.compute_logits(hidden_states, None)
+        num_reqs = logits.size(0)
+
+        dummy_tensors = lambda v: torch.full(
+            (num_reqs, ), v, device=self.device)
+
+        dummy_metadata = SamplingMetadata(
+            temperature=dummy_tensors(0.5),
+            all_greedy=False,
+            all_random=False,
+            top_p=dummy_tensors(0.9),
+            top_k=dummy_tensors(logits.size(1) - 1),
+            min_p=None,
+            generators={},
+            max_num_logprobs=None,
+            no_penalties=True,
+            prompt_token_ids=None,
+            frequency_penalties=dummy_tensors(0.1),
+            presence_penalties=dummy_tensors(0.1),
+            repetition_penalties=dummy_tensors(0.1),
+            output_token_ids=[[] for _ in range(num_reqs)],
+            min_tokens={},
+            logit_bias=[None for _ in range(num_reqs)],
+            allowed_token_ids_mask=None,
+            bad_words_token_ids={},
+        )
+        sampler_output = self.model.sample(logits=logits,
+                                           sampling_metadata=dummy_metadata)
+
+        return sampler_output
 
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
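Note on the token-splitting arithmetic in the new `_dummy_run` above: `num_tokens` is spread as evenly as possible over at most `max_num_seqs` requests, any remainder goes to the last request, and `np.cumsum(num_scheduled_tokens) - 1` yields the flat index of each request's last token, i.e. the positions whose hidden states are kept for logit computation. A minimal standalone sketch of the same arithmetic with made-up sizes (10 tokens, at most 4 requests):

    import numpy as np

    num_tokens, max_num_reqs = 10, 4
    num_reqs = min(num_tokens, max_num_reqs)
    # Even split, remainder to the last request: [2, 2, 2, 4]
    num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
    num_scheduled_tokens[-1] += num_tokens % num_reqs
    assert num_scheduled_tokens.sum() == num_tokens
    # Last-token index of each request in the flattened batch: [1, 3, 5, 9]
    logit_indices = np.cumsum(num_scheduled_tokens) - 1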
@@ -1332,60 +1389,14 @@ def profile_run(self) -> None:
         # Cache the dummy encoder outputs.
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
 
-        # For profile, have maximum num_reqs and that collectively have
-        # maximum num_tokens.
-        num_reqs = self.scheduler_config.max_num_seqs
-        num_tokens = self.max_num_tokens
-        min_tokens_per_req = num_tokens // num_reqs
-
-        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
-        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
-        assert sum(num_scheduled_tokens_list) == num_tokens
-        assert len(num_scheduled_tokens_list) == num_reqs
-
-        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
-                                        dtype=np.int32)
-        logit_indices = np.cumsum(num_scheduled_tokens) - 1
-
-        with self.maybe_profile_with_lora(self.lora_config,
-                                          num_scheduled_tokens):
-            # Trigger compilation for general shape.
-            hidden_states = self._dummy_run(self.max_num_tokens)
-            if get_pp_group().is_last_rank:
-                hidden_states = hidden_states[logit_indices]
-                logits = self.model.compute_logits(hidden_states, None)
-                dummy_tensors = lambda v: torch.full(
-                    (num_reqs, ), v, device=self.device)
-                dummy_metadata = SamplingMetadata(
-                    temperature=dummy_tensors(0.5),
-                    all_greedy=False,
-                    all_random=False,
-                    top_p=dummy_tensors(0.9),
-                    top_k=dummy_tensors(logits.size(1) - 1),
-                    min_p=None,
-                    generators={},
-                    max_num_logprobs=None,
-                    no_penalties=True,
-                    prompt_token_ids=torch.ones_like(logits,
-                                                     dtype=torch.int64),
-                    frequency_penalties=dummy_tensors(0.1),
-                    presence_penalties=dummy_tensors(0.1),
-                    repetition_penalties=dummy_tensors(0.1),
-                    output_token_ids=[[] for _ in range(num_reqs)],
-                    min_tokens={},
-                    logit_bias=[None for _ in range(num_reqs)],
-                    allowed_token_ids_mask=None,
-                    bad_words_token_ids={},
-                )
-                sampler_output = self.model.sample(
-                    logits=logits, sampling_metadata=dummy_metadata)
-            else:
-                logits = None
-                sampler_output = None
-                dummy_metadata = None
-            torch.cuda.synchronize()
-            del hidden_states, logits, sampler_output, dummy_metadata
-            self.encoder_cache.clear()
+        hidden_states = self._dummy_run(self.max_num_tokens)
+        if get_pp_group().is_last_rank:
+            sampler_output = self._dummy_sampler_run(hidden_states)
+        else:
+            sampler_output = None
+        torch.cuda.synchronize()
+        del hidden_states, sampler_output
+        self.encoder_cache.clear()
 
         gc.collect()
 
     def capture_model(self) -> None:
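Note on `_dummy_sampler_run`: it warms up the sampler with fabricated batch-shaped parameters rather than real request state, broadcasting one value per parameter across the batch via `torch.full`. A minimal sketch of that pattern, with assumed sizes (4 requests, 32k vocabulary):

    import torch

    num_reqs, vocab_size = 4, 32000
    logits = torch.randn(num_reqs, vocab_size)

    dummy_tensors = lambda v: torch.full((num_reqs, ), v)
    temperature = dummy_tensors(0.5)       # same temperature for every request
    top_p = dummy_tensors(0.9)
    top_k = dummy_tensors(vocab_size - 1)  # keeps all but one token, so top-k barely truncates
    assert temperature.shape == (num_reqs, )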