
Commit 10d84ca

fix return_token_log_probs on vLLM > 0.3.3 endpoints (#498)
* fix return_token_log_probs
* fix fr
* undo extra change
1 parent 040622a commit 10d84ca
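For context, a minimal sketch of the shape change behind this fix, assuming the vLLM behavior after 0.3.3 (e.g. the 0.4.0.post1 pinned here) where each entry of `CompletionOutput.logprobs` is a `Logprob` object rather than a plain float; the token ids and values below are invented for illustration:

```python
# Sketch only: illustrates why return_token_log_probs broke on newer vLLM.
from typing import Dict, List


class Logprob:
    """Stand-in for vllm.sequence.Logprob (only the .logprob field is used here)."""

    def __init__(self, logprob: float):
        self.logprob = logprob


# vLLM <= 0.3.3: logprobs were already plain floats, directly JSON-serializable.
old_style: List[Dict[int, float]] = [{4: -0.1}, {5: -0.2}]

# vLLM > 0.3.3: each value is a Logprob object, so the endpoint has to unwrap it.
new_style: List[Dict[int, Logprob]] = [{4: Logprob(-0.1)}, {5: Logprob(-0.2)}]

# This is the conversion the PR's new format_logprobs helper performs:
unwrapped = [{tok: lp.logprob for tok, lp in step.items()} for step in new_style]
assert unwrapped == old_style
```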

File tree

3 files changed: +30 -9 lines changed

@@ -1,2 +1,2 @@
-vllm==0.3.3
+vllm==0.4.0.post1
 pydantic>=2.0

model-engine/model_engine_server/inference/vllm/vllm_server.py

+18 -5
@@ -4,7 +4,7 @@
 import signal
 import subprocess
 import traceback
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Dict, List, Optional
 
 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
@@ -13,7 +13,9 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest
 from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
+from vllm.outputs import CompletionOutput
 from vllm.sampling_params import SamplingParams
+from vllm.sequence import Logprob
 from vllm.utils import random_uuid
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
@@ -76,13 +78,12 @@ async def generate(request: Request) -> Response:
     async def stream_results() -> AsyncGenerator[str, None]:
         last_output_text = ""
         async for request_output in results_generator:
+            log_probs = format_logprobs(request_output)
             ret = {
                 "text": request_output.outputs[-1].text[len(last_output_text) :],
                 "count_prompt_tokens": len(request_output.prompt_token_ids),
                 "count_output_tokens": len(request_output.outputs[0].token_ids),
-                "log_probs": (
-                    request_output.outputs[0].logprobs[-1] if sampling_params.logprobs else None
-                ),
+                "log_probs": log_probs[-1] if log_probs and sampling_params.logprobs else None,
                 "finished": request_output.finished,
             }
             last_output_text = request_output.outputs[-1].text
@@ -116,7 +117,7 @@ async def abort_request() -> None:
         "text": final_output.outputs[0].text,
         "count_prompt_tokens": len(final_output.prompt_token_ids),
         "count_output_tokens": len(final_output.outputs[0].token_ids),
-        "log_probs": final_output.outputs[0].logprobs,
+        "log_probs": format_logprobs(final_output),
         "tokens": tokens,
     }
     return Response(content=json.dumps(ret))
@@ -166,6 +167,18 @@ def debug(sig, frame):
     i.interact(message)
 
 
+def format_logprobs(request_output: CompletionOutput) -> Optional[List[Dict[int, float]]]:
+    """Given a request output, format the logprobs if they exist."""
+    output_logprobs = request_output.outputs[0].logprobs
+    if output_logprobs is None:
+        return None
+
+    def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]:
+        return {k: v.logprob for k, v in logprobs.items()}
+
+    return [extract_logprobs(logprobs) for logprobs in output_logprobs]
+
+
 if __name__ == "__main__":
     check_unknown_startup_memory_usage()
     parser = argparse.ArgumentParser()

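For reference, an illustrative sketch (not part of the PR) of the response payloads the endpoint produces after this change, assuming a two-token completion with `return_token_log_probs` enabled; the token ids, logprob values, and text are invented:

```python
# Invented example payloads; field names follow the ret dicts in the diff above.

# Non-streaming response: "log_probs" is the full list returned by
# format_logprobs(final_output), one {token_id: logprob} dict per generated token.
non_streaming_response = {
    "text": "Hello world",
    "count_prompt_tokens": 3,
    "count_output_tokens": 2,
    "log_probs": [{15496: -0.3}, {995: -1.2}],
    "tokens": ["Hello", " world"],
}

# Streaming chunk: only the newest token's entry is sent, taken from log_probs[-1]
# (and only when sampling_params.logprobs is set).
streaming_chunk = {
    "text": " world",
    "count_prompt_tokens": 3,
    "count_output_tokens": 2,
    "log_probs": {995: -1.2},
    "finished": True,
}
```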
model-engine/tests/unit/inference/conftest.py

+11 -3
@@ -58,29 +58,37 @@ def create_batch_completions_request_content():
 
 @pytest.fixture
 def create_vllm_request_outputs():
+    class Logprob:
+        """mock, from https://github.com/vllm-project/vllm/blob/v0.4.1/vllm/sequence.py#L18"""
+
+        def __init__(self, logprob: float):
+            self.logprob = logprob
+
     mock_vllm_request_output1 = MagicMock()
     mock_vllm_request_output1.outputs = [
         MagicMock(text="text1"),
     ]
     mock_vllm_request_output1.prompt_token_ids = [1, 2, 3]
     mock_vllm_request_output1.outputs[0].token_ids = [4]
-    mock_vllm_request_output1.outputs[0].logprobs = [{4: 0.1}]
+    mock_vllm_request_output1.outputs[0].logprobs = [{4: Logprob(0.1)}]
 
     mock_vllm_request_output2 = MagicMock()
     mock_vllm_request_output2.outputs = [
         MagicMock(text="text1 text2"),
     ]
     mock_vllm_request_output2.prompt_token_ids = [1, 2, 3]
     mock_vllm_request_output2.outputs[0].token_ids = [4, 5]
-    mock_vllm_request_output2.outputs[0].logprobs = [{4: 0.1, 5: 0.2}]
+    mock_vllm_request_output2.outputs[0].logprobs = [{4: Logprob(0.1), 5: Logprob(0.2)}]
 
     mock_vllm_request_output3 = MagicMock()
     mock_vllm_request_output3.outputs = [
         MagicMock(text="text1 text2 text3"),
     ]
     mock_vllm_request_output3.prompt_token_ids = [1, 2, 3]
     mock_vllm_request_output3.outputs[0].token_ids = [4, 5, 6]
-    mock_vllm_request_output3.outputs[0].logprobs = [{4: 0.1, 5: 0.2, 6: 0.3}]
+    mock_vllm_request_output3.outputs[0].logprobs = [
+        {4: Logprob(0.1), 5: Logprob(0.2), 6: Logprob(0.3)}
+    ]
     return [mock_vllm_request_output1, mock_vllm_request_output2, mock_vllm_request_output3]
 

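A hypothetical test (not from this PR) showing how mocks shaped like the fixture above exercise the new `format_logprobs` helper; the import path is an assumption and the test names are invented:

```python
# Hypothetical test sketch; assumes the server module is importable in the test env.
from unittest.mock import MagicMock

from model_engine_server.inference.vllm.vllm_server import format_logprobs  # assumed import path


class Logprob:
    """Same minimal mock as in the fixture: only the .logprob attribute matters."""

    def __init__(self, logprob: float):
        self.logprob = logprob


def test_format_logprobs_unwraps_logprob_objects():
    mock_output = MagicMock()
    mock_output.outputs = [MagicMock()]
    mock_output.outputs[0].logprobs = [{4: Logprob(0.1)}, {5: Logprob(0.2)}]

    # Logprob wrappers are collapsed back into plain floats keyed by token id.
    assert format_logprobs(mock_output) == [{4: 0.1}, {5: 0.2}]


def test_format_logprobs_returns_none_when_logprobs_missing():
    mock_output = MagicMock()
    mock_output.outputs = [MagicMock()]
    mock_output.outputs[0].logprobs = None

    assert format_logprobs(mock_output) is None
```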