diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 2d1ae341df3..a42df3a5e6a 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -81,7 +81,7 @@
 
 LORA_WARMUP_RANK = 8
 
 VLLM_DELAYED_SAMPLING = os.environ.get('VLLM_DELAYED_SAMPLING',
-                                       'false').lower() == 'true'
+                                       'true').lower() == 'true'
 DUMMY_TOKEN_ID = -1
 
@@ -736,6 +736,8 @@ def __init__(
                 "Speculative decoding is not supported with "
                 "contiguous PA, please set VLLM_CONTIGUOUS_PA=false")
         # For both multi-step scheduling and delayed sampling
+        self.is_single_step = \
+            self.vllm_config.scheduler_config.num_scheduler_steps == 1
         self.cached_step_outputs: List[torch.Tensor] = []
         self.is_pooler = False
         # For delayed sampling
@@ -1883,9 +1885,7 @@ def warmup_scenario(self,
             profiler.start()
         for _ in range(times):
             inputs = self.prepare_model_input(seqs)
-            is_single_step = \
-                self.vllm_config.scheduler_config.num_scheduler_steps == 1
-            if is_prompt or is_single_step:
+            if is_prompt or self.is_single_step:
                 self.execute_model(inputs, kv_caches, warmup_mode=True)
             else:  # decode with multi-step
                 inputs = dataclasses.replace(inputs,
@@ -2451,9 +2451,9 @@ def execute_model(
         previous_hidden_states: Optional[torch.Tensor] = None,
         seqs=None,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
-        use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
-        assert not (use_delayed_sampling and num_steps != 1), \
-            'Delayed sampling is not compatible with MSS!'
+        # Delayed sampling is only supported for single step scheduling
+        use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode \
+            and self.is_single_step and not is_fake_hpu()
         assert model_input.input_tokens is not None
         if use_delayed_sampling and not model_input.is_prompt and \
                 self.is_driver_worker:
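
Reviewer note: the net effect of this patch is that delayed sampling is now on
by default (the VLLM_DELAYED_SAMPLING env default flips from 'false' to 'true'),
and rather than asserting when combined with multi-step scheduling (MSS), the
feature is silently disabled unless the scheduler runs a single step on real
HPU hardware. A minimal sketch of the new gating, assuming the names from the
diff (is_fake_hpu is the fork's existing helper; the standalone function below
is illustrative, not the actual method):

    import os

    # After this patch the feature defaults to enabled.
    VLLM_DELAYED_SAMPLING = os.environ.get('VLLM_DELAYED_SAMPLING',
                                           'true').lower() == 'true'

    def should_use_delayed_sampling(warmup_mode: bool, is_single_step: bool,
                                    fake_hpu: bool) -> bool:
        # Hypothetical free-function restatement of the execute_model check:
        # the old assert on num_steps != 1 becomes a quiet opt-out, so MSS
        # and fake-HPU runs simply fall back to regular sampling.
        return (VLLM_DELAYED_SAMPLING and not warmup_mode
                and is_single_step and not fake_hpu)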