Skip to content

Commit f3e466c

Browse files
Make the async task and sync task api actually return HTTP response codes from the user container (#676)
* comment
* custom status code in echo server
* catch InvalidRequestException from sync/streaming inference gateways
* mark where we'd probably add the status_code
* status code to dto
* add to sync dto also
* fix tests
* task queue gateway returns status code if possible
* forwarder stuff
* black
* try fixing integration tests
* try fixing unit tests
* missed spots
* test case for cov
* oops
* fix
* ...
* ugh
* eh just remove status code from result manually
* revert integration test changes
1 parent 69875fc commit f3e466c

20 files changed

+161
-22
lines changed

integration_tests/test_endpoints.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def test_sync_streaming_model_endpoint(capsys):
232232
for response in task_responses:
233233
assert (
234234
response.strip()
235-
== 'data: {"status":"SUCCESS","result":{"result":{"y":1}},"traceback":null}'
235+
== 'data: {"status":"SUCCESS","result":{"result":{"y":1}},"traceback":null,"status_code":200}'
236236
)
237237
finally:
238238
delete_model_endpoint(create_endpoint_request["name"], user)

model-engine/model_engine_server/api/tasks_v1.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ async def create_sync_inference_task(
116116
)
117117
except UpstreamServiceError as exc:
118118
return SyncEndpointPredictV1Response(
119-
status=TaskStatus.FAILURE, traceback=exc.content.decode()
119+
status=TaskStatus.FAILURE, traceback=exc.content.decode(), status_code=exc.status_code
120120
)
121121
except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
122122
raise HTTPException(
@@ -133,6 +133,11 @@ async def create_sync_inference_task(
133133
status_code=408,
134134
detail="Request timed out.",
135135
) from exc
136+
except InvalidRequestException as exc:
137+
raise HTTPException(
138+
status_code=400,
139+
detail=f"Invalid request: {str(exc)}",
140+
) from exc
136141

137142

138143
@inference_task_router_v1.post("/streaming-tasks")
@@ -164,7 +169,9 @@ async def event_generator():
164169
iter(
165170
(
166171
SyncEndpointPredictV1Response(
167-
status=TaskStatus.FAILURE, traceback=exc.content.decode()
172+
status=TaskStatus.FAILURE,
173+
traceback=exc.content.decode(),
174+
status_code=exc.status_code,
168175
).json(),
169176
)
170177
)
@@ -179,3 +186,8 @@ async def event_generator():
179186
status_code=400,
180187
detail=f"Unsupported inference type: {str(exc)}",
181188
) from exc
189+
except InvalidRequestException as exc:
190+
raise HTTPException(
191+
status_code=400,
192+
detail=f"Invalid request: {str(exc)}",
193+
) from exc

model-engine/model_engine_server/common/dtos/tasks.py

+2
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@ class GetAsyncTaskV1Response(BaseModel):
3434
status: TaskStatus
3535
result: Optional[ResponseSchema] = None
3636
traceback: Optional[str] = None
37+
status_code: Optional[int] = None
3738

3839

3940
class SyncEndpointPredictV1Response(BaseModel):
4041
status: TaskStatus
4142
result: Optional[Any] = None
4243
traceback: Optional[str] = None
44+
status_code: Optional[int] = None
4345

4446

4547
class EndpointPredictV1Request(BaseModel):

model-engine/model_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ forwarder:
2020
model_engine_unwrap: false
2121
serialize_results_as_string: false
2222
wrap_response: false
23+
forward_http_status_in_body: true

model-engine/model_engine_server/inference/configs/service--forwarder.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ forwarder:
1818
batch_route: null
1919
model_engine_unwrap: true
2020
serialize_results_as_string: true
21+
forward_http_status_in_body: true

model-engine/model_engine_server/inference/forwarding/celery_forwarder.py

+3
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ def exec_func(payload, arrival_timestamp, *ignored_args, **ignored_kwargs):
141141
logger.warning(f"Ignoring {len(ignored_kwargs)} keyword arguments: {ignored_kwargs=}")
142142
try:
143143
monitoring_metrics_gateway.emit_async_task_received_metric(queue_name)
144+
# Don't fail the celery task even if there's a status code
145+
# (otherwise we can't really control what gets put in the result attribute)
146+
# in the task (https://docs.celeryq.dev/en/stable/reference/celery.result.html#celery.result.AsyncResult.status)
144147
result = forwarder(payload)
145148
request_duration = datetime.now() - arrival_timestamp
146149
if request_duration > timedelta(seconds=DEFAULT_TASK_VISIBILITY_SECONDS):

model-engine/model_engine_server/inference/forwarding/echo_server.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,21 @@ async def predict(request: Request):
2727
print("Received request", dictionary, flush=True)
2828
if "delay" in dictionary:
2929
await asyncio.sleep(dictionary["delay"])
30+
if "status_code" in dictionary:
31+
return JSONResponse(content=dictionary, status_code=dictionary["status_code"])
3032
return dictionary
3133

3234

3335
@app.post("/predict500")
3436
async def predict500(request: Request):
35-
response = JSONResponse(content=await request.json(), status_code=500)
37+
dictionary = await request.json()
38+
if "delay" in dictionary:
39+
await asyncio.sleep(dictionary["delay"])
40+
if "status_code" in dictionary:
41+
status_code = dictionary["status_code"]
42+
else:
43+
status_code = 500
44+
response = JSONResponse(content=dictionary, status_code=status_code)
3645
return response
3746

3847

model-engine/model_engine_server/inference/forwarding/forwarding.py

+35-7
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import time
55
from dataclasses import dataclass
66
from pathlib import Path
7-
from typing import Any, AsyncGenerator, Iterable, List, Optional, Sequence, Tuple
7+
from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Sequence, Tuple
88

99
import aiohttp
1010
import orjson
@@ -101,13 +101,24 @@ def unwrap_json_payload(self, json_payload: Any) -> Tuple[Any, bool]:
101101
return json_payload, using_serialize_results_as_string
102102

103103
@staticmethod
104-
def get_response_payload(using_serialize_results_as_string: bool, response: Any):
104+
def get_response_payload(
105+
using_serialize_results_as_string: bool,
106+
forward_http_status_in_body: bool,
107+
response: Any,
108+
status_code: int,
109+
) -> Any:
105110
# Model Engine expects a JSON object with a "result" key.
111+
112+
response_payload: Dict[str, Any] = {}
106113
if using_serialize_results_as_string:
107114
response_as_string: str = json.dumps(response)
108-
return {"result": response_as_string}
115+
response_payload["result"] = response_as_string
116+
else:
117+
response_payload["result"] = response
109118

110-
return {"result": response}
119+
if forward_http_status_in_body:
120+
response_payload["status_code"] = status_code
121+
return response_payload
111122

112123
@staticmethod
113124
def get_response_payload_stream(using_serialize_results_as_string: bool, response: str):
@@ -148,7 +159,12 @@ class Forwarder(ModelEngineSerializationMixin):
148159
model_engine_unwrap: bool
149160
serialize_results_as_string: bool
150161
wrap_response: bool
151-
forward_http_status: bool
162+
# See celery_task_queue_gateway.py for why we should keep wrap_response as True
163+
# for async. tl;dr is we need to convey both the result as well as status code.
164+
forward_http_status: bool # Forwards http status in JSONResponse
165+
# Forwards http status in the response body. Only used if wrap_response is True
166+
# We do this to avoid having to put this data in any sync response and only do it for async responses
167+
forward_http_status_in_body: bool
152168
post_inference_hooks_handler: Optional[PostInferenceHooksHandler] = None
153169

154170
async def forward(self, json_payload: Any) -> Any:
@@ -191,7 +207,12 @@ async def forward(self, json_payload: Any) -> Any:
191207
)
192208

193209
if self.wrap_response:
194-
response = self.get_response_payload(using_serialize_results_as_string, response)
210+
response = self.get_response_payload(
211+
using_serialize_results_as_string,
212+
self.forward_http_status_in_body,
213+
response,
214+
response_raw.status,
215+
)
195216

196217
if self.forward_http_status:
197218
return JSONResponse(content=response, status_code=response_raw.status)
@@ -233,7 +254,12 @@ def __call__(self, json_payload: Any) -> Any:
233254
)
234255

235256
if self.wrap_response:
236-
response = self.get_response_payload(using_serialize_results_as_string, response)
257+
response = self.get_response_payload(
258+
using_serialize_results_as_string,
259+
self.forward_http_status_in_body,
260+
response,
261+
response_raw.status_code,
262+
)
237263

238264
if self.forward_http_status:
239265
return JSONResponse(content=response, status_code=response_raw.status_code)
@@ -263,6 +289,7 @@ class LoadForwarder:
263289
serialize_results_as_string: bool = True
264290
wrap_response: bool = True
265291
forward_http_status: bool = False
292+
forward_http_status_in_body: bool = False
266293

267294
def load(self, resources: Optional[Path], cache: Any) -> Forwarder:
268295
if self.use_grpc:
@@ -370,6 +397,7 @@ def endpoint(route: str) -> str:
370397
post_inference_hooks_handler=handler,
371398
wrap_response=self.wrap_response,
372399
forward_http_status=self.forward_http_status,
400+
forward_http_status_in_body=self.forward_http_status_in_body,
373401
)
374402

375403

model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def send_task(
6868
kwargs: Optional[Dict[str, Any]] = None,
6969
expires: Optional[int] = None,
7070
) -> CreateAsyncTaskV1Response:
71+
# Used for both endpoint infra creation and async tasks
7172
celery_dest = self._get_celery_dest()
7273

7374
try:
@@ -84,6 +85,7 @@ def send_task(
8485
return CreateAsyncTaskV1Response(task_id=res.id)
8586

8687
def get_task(self, task_id: str) -> GetAsyncTaskV1Response:
88+
# Only used for async tasks
8789
celery_dest = self._get_celery_dest()
8890
res = celery_dest.AsyncResult(task_id)
8991
response_state = res.state
@@ -92,15 +94,27 @@ def get_task(self, task_id: str) -> GetAsyncTaskV1Response:
9294
# result_dict = (
9395
# response_result if type(response_result) is dict else {"result": response_result}
9496
# )
97+
status_code = None
98+
result = res.result
99+
if type(result) is dict and "status_code" in result:
100+
# Filter out status code from result if it was added by the forwarder
101+
# This is admittedly kinda hacky and would technically introduce an edge case
102+
# if we ever decide not to have async tasks wrap response.
103+
status_code = result["status_code"]
104+
del result["status_code"]
95105
return GetAsyncTaskV1Response(
96-
task_id=task_id, status=TaskStatus.SUCCESS, result=res.result
106+
task_id=task_id,
107+
status=TaskStatus.SUCCESS,
108+
result=result,
109+
status_code=status_code,
97110
)
98111

99112
elif response_state == "FAILURE":
100113
return GetAsyncTaskV1Response(
101114
task_id=task_id,
102115
status=TaskStatus.FAILURE,
103116
traceback=res.traceback,
117+
status_code=None, # probably
104118
)
105119
elif response_state == "RETRY":
106120
# Backwards compatibility, otherwise we'd need to add "RETRY" to the clients

model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,9 @@ async def streaming_predict(
235235
endpoint_name=endpoint_name or topic,
236236
)
237237
async for item in response:
238-
yield SyncEndpointPredictV1Response(status=TaskStatus.SUCCESS, result=item)
238+
yield SyncEndpointPredictV1Response(
239+
status=TaskStatus.SUCCESS, result=item, status_code=200
240+
)
239241
except UpstreamServiceError as exc:
240242
logger.error(f"Service error on streaming task: {exc.content!r}")
241243

@@ -258,4 +260,5 @@ async def streaming_predict(
258260
yield SyncEndpointPredictV1Response(
259261
status=TaskStatus.FAILURE,
260262
traceback=result_traceback,
263+
status_code=exc.status_code,
261264
)

model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,17 @@ async def predict(
238238
return SyncEndpointPredictV1Response(
239239
status=TaskStatus.FAILURE,
240240
traceback=result_traceback,
241+
status_code=exc.status_code,
241242
)
242243

243244
except Exception as e:
244245
logger.error(f"Failed to parse error: {e}")
245246
return SyncEndpointPredictV1Response(
246-
status=TaskStatus.FAILURE, traceback=exc.content.decode()
247+
status=TaskStatus.FAILURE,
248+
traceback=exc.content.decode(),
249+
status_code=exc.status_code,
247250
)
248251

249-
return SyncEndpointPredictV1Response(status=TaskStatus.SUCCESS, result=response)
252+
return SyncEndpointPredictV1Response(
253+
status=TaskStatus.SUCCESS, result=response, status_code=200
254+
)

model-engine/tests/unit/api/test_llms.py

+1
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def test_completion_sync_success(
116116
}"""
117117
},
118118
traceback=None,
119+
status_code=200,
119120
),
120121
)
121122
response_1 = client.post(

model-engine/tests/unit/api/test_tasks.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,8 @@ async def test_create_streaming_task_success(
410410
count = 0
411411
async for message in response.aiter_bytes():
412412
assert (
413-
message == b'data: {"status":"SUCCESS","result":null,"traceback":null}\r\n\r\n'
413+
message
414+
== b'data: {"status":"SUCCESS","result":null,"traceback":null,"status_code":200}\r\n\r\n'
414415
)
415416
count += 1
416417
assert count == 1

model-engine/tests/unit/conftest.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -1028,18 +1028,17 @@ def get_task_args(self, task_id: str):
10281028

10291029
def get_task(self, task_id: str) -> GetAsyncTaskV1Response:
10301030
result = None
1031+
status_code = None
10311032
if task_id in self.queue:
10321033
status = TaskStatus.PENDING
10331034
elif task_id in self.completed:
10341035
status = TaskStatus.SUCCESS
10351036
result = self.completed[task_id]
1037+
status_code = 200
10361038
else:
10371039
status = TaskStatus.UNDEFINED
10381040
return GetAsyncTaskV1Response(
1039-
task_id=task_id,
1040-
status=status,
1041-
result=result,
1042-
traceback=None,
1041+
task_id=task_id, status=status, result=result, traceback=None, status_code=status_code
10431042
)
10441043

10451044
def clear_queue(self, queue_name: str) -> bool:
@@ -1537,6 +1536,7 @@ def __init__(self):
15371536
status=TaskStatus.SUCCESS,
15381537
result=None,
15391538
traceback=None,
1539+
status_code=200,
15401540
)
15411541
]
15421542

@@ -1561,6 +1561,7 @@ def __init__(self, fake_sync_inference_content=None):
15611561
status=TaskStatus.SUCCESS,
15621562
result=None,
15631563
traceback=None,
1564+
status_code=200,
15641565
)
15651566
else:
15661567
self.response = fake_sync_inference_content
@@ -1662,6 +1663,7 @@ def get_task(self, task_id: str) -> GetAsyncTaskV1Response:
16621663
status=TaskStatus.SUCCESS,
16631664
result=None,
16641665
traceback=None,
1666+
status_code=200,
16651667
)
16661668

16671669
def get_last_request(self):

0 commit comments

Comments (0)