 import signal
 import subprocess
 import traceback
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Dict, List, Optional
 
 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest
 from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
+from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
+from vllm.sequence import Logprob
 from vllm.utils import random_uuid
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
@@ -76,13 +78,12 @@ async def generate(request: Request) -> Response:
     async def stream_results() -> AsyncGenerator[str, None]:
         last_output_text = ""
         async for request_output in results_generator:
+            log_probs = format_logprobs(request_output)
             ret = {
                 "text": request_output.outputs[-1].text[len(last_output_text) :],
                 "count_prompt_tokens": len(request_output.prompt_token_ids),
                 "count_output_tokens": len(request_output.outputs[0].token_ids),
-                "log_probs": (
-                    request_output.outputs[0].logprobs[-1] if sampling_params.logprobs else None
-                ),
+                "log_probs": log_probs[-1] if log_probs and sampling_params.logprobs else None,
                 "finished": request_output.finished,
             }
             last_output_text = request_output.outputs[-1].text
@@ -116,7 +117,7 @@ async def abort_request() -> None:
         "text": final_output.outputs[0].text,
         "count_prompt_tokens": len(final_output.prompt_token_ids),
         "count_output_tokens": len(final_output.outputs[0].token_ids),
-        "log_probs": final_output.outputs[0].logprobs,
+        "log_probs": format_logprobs(final_output),
         "tokens": tokens,
     }
     return Response(content=json.dumps(ret))
@@ -166,6 +167,18 @@ def debug(sig, frame):
     i.interact(message)
 
 
+def format_logprobs(request_output: RequestOutput) -> Optional[List[Dict[int, float]]]:
| 171 | + """Given a request output, format the logprobs if they exist.""" |
| 172 | + output_logprobs = request_output.outputs[0].logprobs |
| 173 | + if output_logprobs is None: |
| 174 | + return None |
| 175 | + |
| 176 | + def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]: |
| 177 | + return {k: v.logprob for k, v in logprobs.items()} |
| 178 | + |
| 179 | + return [extract_logprobs(logprobs) for logprobs in output_logprobs] |
| 180 | + |
| 181 | + |
 if __name__ == "__main__":
     check_unknown_startup_memory_usage()
     parser = argparse.ArgumentParser()
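
For context, CompletionOutput.logprobs is a list with one {token_id: Logprob} mapping per generated token, and Logprob is a dataclass that json.dumps cannot serialize directly, so format_logprobs flattens each mapping to plain floats before the response is built. Below is a minimal illustrative sketch of that reshaping, not part of the diff itself; the token ids and values are made up, and it assumes the Logprob dataclass from vllm.sequence accepts logprob as a keyword argument.

# Illustrative sketch only: the before/after shape that format_logprobs produces.
from vllm.sequence import Logprob

# One step of CompletionOutput.logprobs: token id -> Logprob object (hypothetical values).
raw_step = {1734: Logprob(logprob=-0.12), 42: Logprob(logprob=-2.31)}

# format_logprobs reduces each step to token id -> float, which json.dumps accepts.
flat_step = {token_id: lp.logprob for token_id, lp in raw_step.items()}
assert flat_step == {1734: -0.12, 42: -2.31}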