Commit 87d07b0 (parent 1bd0dfe)

feat(services): add custom callbacks service

Signed-off-by: Yanik Ammann <[email protected]>

File tree: 9 files changed, +176 -20 lines changed


custom_callbacks_example.py

+18
@@ -0,0 +1,18 @@
+from fastapi import Request, Response
+
+from vllm_router.services.callbacks_service.custom_callbacks import (
+    CustomCallbackHandler,
+)
+
+
+class MyCustomCallbackHandler(CustomCallbackHandler):
+    def pre_request(self, request: Request, request_body: bytes, request_json: any):
+        if b"coffee" in request_body:
+            return Response("I'm a teapot", 418)
+
+    def post_request(self, request: Request, response_content: bytes):
+        with open("/tmp/response.txt", "ab") as f:
+            f.write(response_content)
+
+
+my_callback_handler_instance = MyCustomCallbackHandler()
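
Assuming the router is launched from the directory containing this example file, the instance above can be wired in with ``--callbacks custom_callbacks_example.my_callback_handler_instance``; the dotted-path convention is spelled out in the cmd.rst change below.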

docs/source/user_manual/router/cmd.rst

+44
@@ -44,6 +44,7 @@ Logging Options
 +++++++++++++++

 - ``--log-stats``: Log statistics every 30 seconds.
+- ``--callbacks``: The path to the callback instance extending CustomCallbackHandler (e.g. ``my_callbacks.my_callback_handler_instance``).


 Build docker image
@@ -78,3 +79,46 @@ You can install the router using the following command:
     --engine-stats-interval 10 \
     --log-stats \
     --routing-logic roundrobin
+
+
+Hooking into custom callbacks
+-----------------------------
+
+The router can be extended with custom callbacks at various points in the request lifecycle.
+
+To do this, create a custom callback handler instance that implements at least one of the available callback methods. All available callbacks, along with detailed descriptions, are documented in the abstract `CustomCallbackHandler <https://github.com/vllm-project/production-stack/tree/main/src/vllm_router/services/callbacks_service/custom_callbacks.py>`_ class.
+
+.. code-block:: python
+
+    # my_callbacks.py
+
+    from fastapi import Request, Response
+
+    from vllm_router.services.callbacks_service.custom_callbacks import CustomCallbackHandler
+
+
+    class MyCustomCallbackHandler(CustomCallbackHandler):
+        def pre_request(self, request: Request, request_body: bytes, request_json: any) -> Response | None:
+            """
+            Receives the request object before it gets proxied.
+            """
+            if b"coffee" in request_body:
+                return Response("I'm a teapot", 418)
+
+        def post_request(self, request: Request, response_content: bytes) -> None:
+            """
+            Executed as a background task; receives the request object
+            and the complete response_content.
+            """
+            with open("/tmp/response.txt", "ab") as f:
+                f.write(response_content)
+
+
+    my_callback_handler_instance = MyCustomCallbackHandler()
+
+
+Pass the instance to the router as the module path (the file name without the ``.py`` extension), followed by the instance name, separated by a dot:
+
+.. code-block:: bash
+
+    vllm-router ... --callbacks my_callbacks.my_callback_handler_instance
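
As a quick sanity check of the example handler, here is a minimal probe script. It is a sketch, not part of this commit: it assumes the router is listening on http://localhost:8000 and was started with the ``--callbacks`` flag above, and the model name is a placeholder.

    # probe_callbacks.py (hypothetical smoke test, not part of this commit)
    import json
    import urllib.error
    import urllib.request

    req = urllib.request.Request(
        "http://localhost:8000/v1/chat/completions",
        data=json.dumps(
            {"model": "some-model", "messages": [{"role": "user", "content": "coffee"}]}
        ).encode(),
        headers={"Content-Type": "application/json"},
    )
    try:
        urllib.request.urlopen(req)
    except urllib.error.HTTPError as e:
        # pre_request short-circuits the request before it reaches a backend
        print(e.code)  # expect 418 ("I'm a teapot")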

src/vllm_router/app.py

+4
@@ -11,6 +11,7 @@
     initialize_dynamic_config_watcher,
 )
 from vllm_router.experimental import get_feature_gates, initialize_feature_gates
+from vllm_router.services.callbacks_service.callbacks import initialize_custom_callbacks

 try:
     # Semantic cache integration
@@ -206,6 +207,9 @@ def initialize_all(app: FastAPI, args):
         args.dynamic_config_json, 10, init_config, app
     )

+    if args.callbacks:
+        initialize_custom_callbacks(args.callbacks, app)
+

 app = FastAPI(lifespan=lifespan)
 app.include_router(main_router)

src/vllm_router/parsers/parser.py

+6
@@ -108,6 +108,12 @@ def parse_args():
         default=None,
         help="The key (in the header) to identify a session.",
     )
+    parser.add_argument(
+        "--callbacks",
+        type=str,
+        default=None,
+        help="Path to the callback instance extending CustomCallbackHandler. Consists of <file path without .py ending>.<instance variable name>.",
+    )

     # Request rewriter arguments
     parser.add_argument(

src/vllm_router/routers/main_router.py

+17-15
@@ -1,6 +1,6 @@
 import json

-from fastapi import APIRouter, Request
+from fastapi import APIRouter, BackgroundTasks, Request
 from fastapi.responses import JSONResponse, Response

 from vllm_router.dynamic_config import get_dynamic_config_watcher
@@ -40,7 +40,7 @@


 @main_router.post("/v1/chat/completions")
-async def route_chat_completion(request: Request):
+async def route_chat_completion(request: Request, background_tasks: BackgroundTasks):
     if semantic_cache_available:
         # Check if the request can be served from the semantic cache
         logger.debug("Received chat completion request, checking semantic cache")
@@ -51,37 +51,39 @@ async def route_chat_completion(request: Request):
         return cache_response

     logger.debug("No cache hit, forwarding request to backend")
-    return await route_general_request(request, "/v1/chat/completions")
+    return await route_general_request(
+        request, "/v1/chat/completions", background_tasks
+    )


 @main_router.post("/v1/completions")
-async def route_completion(request: Request):
-    return await route_general_request(request, "/v1/completions")
+async def route_completion(request: Request, background_tasks: BackgroundTasks):
+    return await route_general_request(request, "/v1/completions", background_tasks)


 @main_router.post("/v1/embeddings")
-async def route_embeddings(request: Request):
-    return await route_general_request(request, "/v1/embeddings")
+async def route_embeddings(request: Request, background_tasks: BackgroundTasks):
+    return await route_general_request(request, "/v1/embeddings", background_tasks)


 @main_router.post("/v1/rerank")
-async def route_v1_rerank(request: Request):
-    return await route_general_request(request, "/v1/rerank")
+async def route_v1_rerank(request: Request, background_tasks: BackgroundTasks):
+    return await route_general_request(request, "/v1/rerank", background_tasks)


 @main_router.post("/rerank")
-async def route_rerank(request: Request):
-    return await route_general_request(request, "/rerank")
+async def route_rerank(request: Request, background_tasks: BackgroundTasks):
+    return await route_general_request(request, "/rerank", background_tasks)


 @main_router.post("/v1/score")
-async def route_v1_score(request: Request):
-    return await route_general_request(request, "/v1/score")
+async def route_v1_score(request: Request, background_tasks: BackgroundTasks):
+    return await route_general_request(request, "/v1/score", background_tasks)


 @main_router.post("/score")
-async def route_score(request: Request):
-    return await route_general_request(request, "/score")
+async def route_score(request: Request, background_tasks: BackgroundTasks):
+    return await route_general_request(request, "/score", background_tasks)


 @main_router.get("/version")

src/vllm_router/services/callbacks_service/__init__.py

Whitespace-only changes.
src/vllm_router/services/callbacks_service/callbacks.py

+19

@@ -0,0 +1,19 @@
+import importlib
+
+from fastapi import FastAPI
+
+from vllm_router.log import init_logger
+
+logger = init_logger(__name__)
+
+
+def initialize_custom_callbacks(callbacks_file_location: str, app: FastAPI):
+    # Split the path by dots to separate the module from the instance
+    parts = callbacks_file_location.split(".")
+
+    # The module path is everything but the last part; the instance name is the last part
+    module_name = ".".join(parts[:-1])
+    instance_name = parts[-1]
+
+    module = importlib.import_module(module_name)
+    app.state.callbacks = getattr(module, instance_name)
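
In other words, the ``--callbacks`` string is resolved as ``<module path>.<attribute>``: everything before the final dot is imported with ``importlib``, and the last segment is looked up on that module, so the module must be importable (e.g. on ``sys.path`` or in the working directory). A minimal sketch of the same resolution, with a hypothetical path:

    import importlib

    # "my_callbacks.my_callback_handler_instance" splits into module
    # "my_callbacks" and attribute "my_callback_handler_instance"
    location = "my_callbacks.my_callback_handler_instance"
    module_name, _, instance_name = location.rpartition(".")
    handler = getattr(importlib.import_module(module_name), instance_name)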
src/vllm_router/services/callbacks_service/custom_callbacks.py

+42

@@ -0,0 +1,42 @@
+from abc import abstractmethod
+
+from fastapi import Request, Response
+
+
+class CustomCallbackHandler:
+    """
+    Abstract class.
+
+    Callbacks can be injected at multiple points within the request lifecycle.
+    This can be used to validate the request or log the response.
+    """
+
+    @abstractmethod
+    def pre_request(
+        self, request: Request, request_body: bytes, request_json: any
+    ) -> Response | None:
+        """
+        Receives the request object before it gets proxied.
+        This can be used to validate the request or raise HTTP responses.
+
+        Args:
+            request: The original request.
+            request_body: The request body as a byte array.
+            request_json: The request body as a JSON object.
+
+        Returns:
+            Either None or a Response object, which will end the request.
+        """
+        return None
+
+    @abstractmethod
+    def post_request(self, request: Request, response_content: bytes) -> None:
+        """
+        Executed as a background task; receives the request object and the complete response_content.
+        This can be used to log the response or further process it.
+
+        Args:
+            request: The original request.
+            response_content: The complete response content, as a byte array, after the request has completed.
+        """
+        pass
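
Both hooks are optional, so a handler can implement only the one it needs. For instance, ``pre_request`` can serve as a lightweight request gate. The following is a hypothetical sketch (the header name and key are purely illustrative, not part of this commit):

    from fastapi import Request, Response

    from vllm_router.services.callbacks_service.custom_callbacks import CustomCallbackHandler


    class ApiKeyGate(CustomCallbackHandler):
        # hypothetical header/key values, for illustration only
        EXPECTED_KEY = "secret-key"

        def pre_request(self, request: Request, request_body: bytes, request_json: any):
            # returning a Response ends the request before it is proxied;
            # returning None lets it through unchanged
            if request.headers.get("x-api-key") != self.EXPECTED_KEY:
                return Response("Unauthorized", 401)


    api_key_gate_instance = ApiKeyGate()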

src/vllm_router/services/request_service/request.py

+26-5
@@ -4,7 +4,7 @@
 import time
 import uuid

-from fastapi import Request
+from fastapi import BackgroundTasks, Request
 from fastapi.responses import JSONResponse, StreamingResponse

 from vllm_router.log import init_logger
@@ -42,7 +42,13 @@


 async def process_request(
-    request: Request, body, backend_url, request_id, endpoint, debug_request=None
+    request: Request,
+    body,
+    backend_url,
+    request_id,
+    endpoint,
+    background_tasks: BackgroundTasks,
+    debug_request=None,
 ):
     """
     Process a request by sending it to the chosen backend.
@@ -78,7 +84,7 @@ async def process_request(
         pass

     # For non-streaming requests, collect the full response to cache it properly
-    full_response = bytearray() if not is_streaming else None
+    full_response = bytearray()

     async with request.app.state.httpx_client_wrapper().stream(
         method=request.method,
@@ -111,13 +117,19 @@ async def process_request(
     # Store in semantic cache if applicable
     # Use the full response for non-streaming requests, or the last chunk for streaming
     if request.app.state.semantic_cache_available:
-        cache_chunk = bytes(full_response) if full_response is not None else chunk
+        cache_chunk = bytes(full_response) if not is_streaming else chunk
         await store_in_semantic_cache(
             endpoint=endpoint, method=request.method, body=body, chunk=cache_chunk
         )
+    if background_tasks and hasattr(request.app.state, "callbacks"):
+        background_tasks.add_task(
+            request.app.state.callbacks.post_request, request, full_response
+        )


-async def route_general_request(request: Request, endpoint: str):
+async def route_general_request(
+    request: Request, endpoint: str, background_tasks: BackgroundTasks
+):
     """
     Route the incoming request to the backend server and stream the response back to the client.

@@ -138,6 +150,14 @@ async def route_general_request(request: Request, endpoint: str):
     request_id = str(uuid.uuid4())
     request_body = await request.body()
     request_json = await request.json()  # TODO (ApostaC): merge two awaits into one
+
+    if hasattr(request.app.state, "callbacks") and (
+        response_overwrite := request.app.state.callbacks.pre_request(
+            request, request_body, request_json
+        )
+    ):
+        return response_overwrite
+
     requested_model = request_json.get("model", None)
     if requested_model is None:
         return JSONResponse(
@@ -185,6 +205,7 @@ async def route_general_request(request: Request, endpoint: str):
         request_body,
         server_url,
         request_id,
+        background_tasks,
         endpoint=endpoint,
     )
     headers, status_code = await anext(stream_generator)
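
The ``post_request`` hook rides on FastAPI's BackgroundTasks, so it runs only after the response has been delivered to the client. A standalone sketch of that pattern, separate from the router code, with hypothetical names:

    from fastapi import BackgroundTasks, FastAPI

    app = FastAPI()


    def log_after_send(payload: bytes) -> None:
        # runs only after the HTTP response has been sent
        print(f"logged {len(payload)} bytes")


    @app.post("/echo")
    async def echo(background_tasks: BackgroundTasks):
        body = b"hello"
        # add_task schedules the callable; FastAPI invokes it post-response,
        # mirroring how process_request defers callbacks.post_request
        background_tasks.add_task(log_after_send, body)
        return {"ok": True}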
