forked from aws/sagemaker-distribution
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathasync_handler.py
76 lines (54 loc) · 2.34 KB
/
async_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from __future__ import absolute_import
import asyncio
import logging
from typing import AsyncIterator, Iterator
import tornado.web
from stream_handler import StreamHandler
from utils.environment import Environment
from utils.exception import AsyncInvocationsException
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER
# Module-level logger for this file; its name is taken from the project's
# SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER constant (imported from utils.logger).
logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)
class InvocationsHandler(tornado.web.RequestHandler, StreamHandler):
    """Handler mapped to the /invocations POST route.

    This handler wraps the async handler retrieved from the inference script
    and exposes it through the post() method, which runs as a coroutine.
    """

    def initialize(self, handler: callable, environment: Environment):
        """Store the handler function and the serving environment.

        Tornado calls this hook with the keyword arguments supplied in the
        route definition.

        Args:
            handler: Awaitable callable that is invoked with the incoming
                request and returns the response payload.
            environment: Serving environment configuration.
        """
        self._handler = handler
        self._environment = environment

    async def post(self):
        """Invoke the wrapped async handler and send back its response.

        The response is dispatched on its type: an ``Iterator`` is passed to
        ``self.stream()``, an ``AsyncIterator`` to ``self.astream()`` (both
        presumably provided by ``StreamHandler`` -- confirm against that
        class), and anything else is written directly with ``self.write()``.

        Raises:
            AsyncInvocationsException: If invoking the handler or sending the
                response raises any ``Exception``. The original exception is
                attached as the explicit cause.
        """
        try:
            response = await self._handler(self.request)

            # Iterator is checked first, so an object implementing both
            # protocols is streamed synchronously.
            if isinstance(response, Iterator):
                await self.stream(response)
            elif isinstance(response, AsyncIterator):
                await self.astream(response)
            else:
                self.write(response)
        except Exception as e:
            # Chain explicitly so the traceback shows the original error as
            # the direct cause of the wrapped exception.
            raise AsyncInvocationsException(e) from e
class PingHandler(tornado.web.RequestHandler):
    """Handler mapped to the /ping GET route.

    Health-check endpoint used to monitor the Tornado server.
    """

    def get(self):
        """Answer the health check by writing an empty response body."""
        empty_body = ""
        self.write(empty_body)
async def handle(handler: callable, environment: Environment):
    """Serve the async handler function through a Tornado application.

    Registers the /invocations and /ping routes used by a SageMaker Endpoint
    for inference serving, listens on ``environment.port`` and then waits
    indefinitely.

    Args:
        handler: Async handler function forwarded to ``InvocationsHandler``.
        environment: Serving environment; ``environment.port`` is the port
            the application listens on.
    """
    logger.info("Starting inference server in asynchronous mode...")

    # Keyword arguments handed to InvocationsHandler.initialize() by Tornado.
    invocations_kwargs = {"handler": handler, "environment": environment}
    routes = [
        (r"/invocations", InvocationsHandler, invocations_kwargs),
        (r"/ping", PingHandler),
    ]

    server_app = tornado.web.Application(routes)
    server_app.listen(environment.port)
    logger.debug(f"Asynchronous inference server listening on port: `{environment.port}`")

    # The event is never set, so awaiting it parks this coroutine forever and
    # keeps the server running.
    never_set = asyncio.Event()
    await never_set.wait()