|
4 | 4 | import time
|
5 | 5 | from dataclasses import dataclass
|
6 | 6 | from pathlib import Path
|
7 |
| -from typing import Any, AsyncGenerator, Iterable, List, Optional, Sequence, Tuple |
| 7 | +from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Sequence, Tuple |
8 | 8 |
|
9 | 9 | import aiohttp
|
10 | 10 | import orjson
|
@@ -101,13 +101,24 @@ def unwrap_json_payload(self, json_payload: Any) -> Tuple[Any, bool]:
|
101 | 101 | return json_payload, using_serialize_results_as_string
|
102 | 102 |
|
103 | 103 | @staticmethod
|
104 |
| - def get_response_payload(using_serialize_results_as_string: bool, response: Any): |
| 104 | + def get_response_payload( |
| 105 | + using_serialize_results_as_string: bool, |
| 106 | + forward_http_status_in_body: bool, |
| 107 | + response: Any, |
| 108 | + status_code: int, |
| 109 | + ) -> Any: |
105 | 110 | # Model Engine expects a JSON object with a "result" key.
|
| 111 | + |
| 112 | + response_payload: Dict[str, Any] = {} |
106 | 113 | if using_serialize_results_as_string:
|
107 | 114 | response_as_string: str = json.dumps(response)
|
108 |
| - return {"result": response_as_string} |
| 115 | + response_payload["result"] = response_as_string |
| 116 | + else: |
| 117 | + response_payload["result"] = response |
109 | 118 |
|
110 |
| - return {"result": response} |
| 119 | + if forward_http_status_in_body: |
| 120 | + response_payload["status_code"] = status_code |
| 121 | + return response_payload |
111 | 122 |
|
112 | 123 | @staticmethod
|
113 | 124 | def get_response_payload_stream(using_serialize_results_as_string: bool, response: str):
|
@@ -148,7 +159,12 @@ class Forwarder(ModelEngineSerializationMixin):
|
148 | 159 | model_engine_unwrap: bool
|
149 | 160 | serialize_results_as_string: bool
|
150 | 161 | wrap_response: bool
|
151 |
| - forward_http_status: bool |
| 162 | + # See celery_task_queue_gateway.py for why we should keep wrap_response as True |
| 163 | + # for async. tl;dr is we need to convey both the result as well as status code. |
| 164 | + forward_http_status: bool # Forwards http status in JSONResponse |
| 165 | + # Forwards http status in the response body. Only used if wrap_response is True |
| 166 | + # We do this to avoid having to put this data in any sync response and only do it for async responses |
| 167 | + forward_http_status_in_body: bool |
152 | 168 | post_inference_hooks_handler: Optional[PostInferenceHooksHandler] = None
|
153 | 169 |
|
154 | 170 | async def forward(self, json_payload: Any) -> Any:
|
@@ -191,7 +207,12 @@ async def forward(self, json_payload: Any) -> Any:
|
191 | 207 | )
|
192 | 208 |
|
193 | 209 | if self.wrap_response:
|
194 |
| - response = self.get_response_payload(using_serialize_results_as_string, response) |
| 210 | + response = self.get_response_payload( |
| 211 | + using_serialize_results_as_string, |
| 212 | + self.forward_http_status_in_body, |
| 213 | + response, |
| 214 | + response_raw.status, |
| 215 | + ) |
195 | 216 |
|
196 | 217 | if self.forward_http_status:
|
197 | 218 | return JSONResponse(content=response, status_code=response_raw.status)
|
@@ -233,7 +254,12 @@ def __call__(self, json_payload: Any) -> Any:
|
233 | 254 | )
|
234 | 255 |
|
235 | 256 | if self.wrap_response:
|
236 |
| - response = self.get_response_payload(using_serialize_results_as_string, response) |
| 257 | + response = self.get_response_payload( |
| 258 | + using_serialize_results_as_string, |
| 259 | + self.forward_http_status_in_body, |
| 260 | + response, |
| 261 | + response_raw.status_code, |
| 262 | + ) |
237 | 263 |
|
238 | 264 | if self.forward_http_status:
|
239 | 265 | return JSONResponse(content=response, status_code=response_raw.status_code)
|
@@ -263,6 +289,7 @@ class LoadForwarder:
|
263 | 289 | serialize_results_as_string: bool = True
|
264 | 290 | wrap_response: bool = True
|
265 | 291 | forward_http_status: bool = False
|
| 292 | + forward_http_status_in_body: bool = False |
266 | 293 |
|
267 | 294 | def load(self, resources: Optional[Path], cache: Any) -> Forwarder:
|
268 | 295 | if self.use_grpc:
|
@@ -370,6 +397,7 @@ def endpoint(route: str) -> str:
|
370 | 397 | post_inference_hooks_handler=handler,
|
371 | 398 | wrap_response=self.wrap_response,
|
372 | 399 | forward_http_status=self.forward_http_status,
|
| 400 | + forward_http_status_in_body=self.forward_http_status_in_body, |
373 | 401 | )
|
374 | 402 |
|
375 | 403 |
|
|
0 commit comments