Skip to content

Commit 4153ebb

Browse files
committed
Enable inference serving capabilities on sagemaker endpoint using tornado
1 parent 33b6986 commit 4153ebb

File tree

13 files changed

+479
-1
lines changed

13 files changed

+479
-1
lines changed

template/v3/Dockerfile

+1-1
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ RUN mkdir -p $SAGEMAKER_LOGGING_DIR && \
190190
&& ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh ${HOME_DIR} python \
191191
&& rm -rf ${HOME_DIR}/oss_compliance*
192192

193-
ENV PATH="/opt/conda/bin:/opt/conda/condabin:$PATH"
193+
ENV PATH="/etc/sagemaker-inference-server:/opt/conda/bin:/opt/conda/condabin:$PATH"
194194
WORKDIR "/home/${NB_USER}"
195195
ENV SHELL=/bin/bash
196196
ENV OPENSSL_MODULES=/opt/conda/lib64/ossl-modules/
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from __future__ import absolute_import
2+
3+
import utils.logger
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
#!/bin/bash
2+
python /etc/sagemaker-inference-server/serve.py
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from __future__ import absolute_import
2+
3+
from tornado_server.server import TornadoServer
4+
5+
inference_server = TornadoServer()
6+
inference_server.serve()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from __future__ import absolute_import
2+
3+
import pathlib
4+
import sys
5+
6+
# make the utils modules accessible to modules from within the tornado_server folder
7+
utils_path = pathlib.Path(__file__).parent.parent / "utils"
8+
sys.path.insert(0, str(utils_path.resolve()))
9+
10+
# make the tornado_server modules accessible to each other
11+
tornado_module_path = pathlib.Path(__file__).parent
12+
sys.path.insert(0, str(tornado_module_path.resolve()))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from __future__ import absolute_import
2+
3+
import asyncio
4+
import logging
5+
from typing import AsyncGenerator, Generator
6+
7+
import tornado.web
8+
from stream_handler import StreamHandler
9+
10+
from utils.environment import Environment
11+
from utils.exception import AsyncInvocationsException
12+
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER
13+
14+
logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)
15+
16+
17+
class InvocationsHandler(tornado.web.RequestHandler, StreamHandler):
18+
"""Handler mapped to the /invocations POST route.
19+
20+
This handler wraps the async handler retrieved from the inference script
21+
and encapsulates it behind the post() method. The post() method is done
22+
asynchronously.
23+
"""
24+
25+
def initialize(self, handler: callable, environment: Environment):
26+
"""Initializes the handler function and the serving environment."""
27+
28+
self._handler = handler
29+
self._environment = environment
30+
31+
async def post(self):
32+
"""POST method used to encapsulate and invoke the async handle method asynchronously"""
33+
34+
try:
35+
response = await self._handler(self.request)
36+
37+
if isinstance(response, Generator):
38+
await self.stream(response)
39+
elif isinstance(response, AsyncGenerator):
40+
await self.astream(response)
41+
else:
42+
self.write(response)
43+
except Exception as e:
44+
raise AsyncInvocationsException(e)
45+
46+
47+
class PingHandler(tornado.web.RequestHandler):
48+
"""Handler mapped to the /ping GET route.
49+
50+
Ping handler to monitor the health of the Tornados server.
51+
"""
52+
53+
def get(self):
54+
"""Simple GET method to assess the health of the server."""
55+
56+
self.write("")
57+
58+
59+
async def handle(handler: callable, environment: Environment):
60+
"""Serves the async handler function using Tornado.
61+
62+
Opens the /invocations and /ping routes used by a SageMaker Endpoint
63+
for inference serving capabilities.
64+
"""
65+
66+
logger.info("Starting inference server in asynchronous mode...")
67+
68+
app = tornado.web.Application(
69+
[
70+
(r"/invocations", InvocationsHandler, dict(handler=handler, environment=environment)),
71+
(r"/ping", PingHandler),
72+
]
73+
)
74+
app.listen(environment.port)
75+
logger.debug(f"Asynchronous inference server listening on port: `{environment.port}`")
76+
await asyncio.Event().wait()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from __future__ import absolute_import
2+
3+
import asyncio
4+
import importlib
5+
import logging
6+
import subprocess
7+
import sys
8+
from pathlib import Path
9+
10+
from utils.environment import Environment
11+
from utils.exception import (
12+
InferenceCodeLoadException,
13+
RequirementsInstallException,
14+
ServerStartException,
15+
)
16+
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER
17+
18+
logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)
19+
20+
21+
class TornadoServer:
22+
"""Holds serving logic using the Tornado framework.
23+
24+
The serve.py script will invoke TornadoServer.serve() to start the serving process.
25+
The TornadoServer will install the runtime requirements specified through a requirements file.
26+
It will then load an handler function within an inference script and then front it will an /invocations
27+
route using the Tornado framework.
28+
"""
29+
30+
def __init__(self):
31+
"""Initialize the serving behaviors.
32+
33+
Defines the serving behavior through Environment() and locate where
34+
the inference code is contained.
35+
"""
36+
37+
self._environment = Environment()
38+
logger.setLevel(self._environment.logging_level)
39+
logger.debug(f"Environment: {str(self._environment)}")
40+
41+
self._path_to_inference_code = (
42+
Path(self._environment.base_directory).joinpath(self._environment.code_directory)
43+
if self._environment.code_directory
44+
else Path(self._environment.base_directory)
45+
)
46+
logger.debug(f"Path to inference code: `{str(self._path_to_inference_code)}`")
47+
48+
self._handler = None
49+
50+
def initialize(self):
51+
"""Initialize the serving artifacts and dependencies.
52+
53+
Install the runtime requirements and then locate the handler function from
54+
the inference script.
55+
"""
56+
57+
logger.info("Initializing inference server...")
58+
self._install_runtime_requirements()
59+
self._handler = self._load_inference_handler()
60+
61+
def serve(self):
62+
"""Orchestrate the initialization and server startup behavior.
63+
64+
Call the initalize() method, determine the right Tornado serving behavior (async or sync),
65+
and then start the Tornado server through asyncio
66+
"""
67+
68+
logger.info("Serving inference requests using Tornado...")
69+
self.initialize()
70+
71+
if asyncio.iscoroutinefunction(self._handler):
72+
import async_handler as inference_handler
73+
else:
74+
import sync_handler as inference_handler
75+
76+
try:
77+
asyncio.run(inference_handler.handle(self._handler, self._environment))
78+
except Exception as e:
79+
raise ServerStartException(e)
80+
81+
def _install_runtime_requirements(self):
82+
"""Install the runtime requirements."""
83+
84+
logger.info("Installing runtime requirements...")
85+
requirements_txt = self._path_to_inference_code.joinpath(self._environment.requirements)
86+
if requirements_txt.is_file():
87+
try:
88+
subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", str(requirements_txt)])
89+
except Exception as e:
90+
raise RequirementsInstallException(e)
91+
else:
92+
logger.debug(f"No requirements file was found at `{str(requirements_txt)}`")
93+
94+
def _load_inference_handler(self) -> callable:
95+
"""Load the handler function from the inference script."""
96+
97+
logger.info("Loading inference handler...")
98+
inference_module_name, handle_name = self._environment.code.split(".")
99+
if inference_module_name and handle_name:
100+
inference_module_file = f"{inference_module_name}.py"
101+
module_spec = importlib.util.spec_from_file_location(
102+
inference_module_file, str(self._path_to_inference_code.joinpath(inference_module_file))
103+
)
104+
if module_spec:
105+
sys.path.insert(0, str(self._path_to_inference_code.resolve()))
106+
module = importlib.util.module_from_spec(module_spec)
107+
module_spec.loader.exec_module(module)
108+
109+
if hasattr(module, handle_name):
110+
handler = getattr(module, handle_name)
111+
else:
112+
logger.info(dir(inference_module))
113+
raise InferenceCodeLoadException(
114+
f"Handler `{handle_name}` could not be found in module `{inference_module_file}`"
115+
)
116+
logger.debug(f"Loaded handler `{handle_name}` from module `{inference_module_name}`")
117+
return handler
118+
else:
119+
raise InferenceCodeLoadException(
120+
f"Inference code could not be found at `{str(self._path_to_inference_code.joinpath(inference_module_file))}`"
121+
)
122+
raise InferenceCodeLoadException(
123+
f"Inference code expected in the format of `<module>.<handler>` but was provided as {self._environment.code}"
124+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from __future__ import absolute_import
2+
3+
from typing import AsyncGenerator, Generator
4+
5+
from tornado.ioloop import IOLoop
6+
7+
8+
class StreamHandler:
9+
"""Mixin that enables async and sync streaming capabilities to the async and sync handlers
10+
11+
stream() runs a provided generator fn in an async manner.
12+
astream() runs a provided async generator fn in an async manner.
13+
"""
14+
15+
async def stream(self, generator: Generator):
16+
"""Streams the response from a sync response generator
17+
18+
A sync generator must be manually iterated through asynchronously.
19+
In a loop, iterate through each next(generator) call in an async execution.
20+
"""
21+
22+
self._set_stream_headers()
23+
24+
while True:
25+
try:
26+
chunk = await IOLoop.current().run_in_executor(None, next, generator)
27+
# Some generators do not throw a StopIteration upon exhaustion.
28+
# Instead, they return an empty response. Account for this case.
29+
if not chunk:
30+
raise StopIteration()
31+
32+
self.write(chunk)
33+
await self.flush()
34+
except StopIteration:
35+
break
36+
except Exception as e:
37+
logger.error("Unexpected exception occurred when streaming response...")
38+
break
39+
40+
async def astream(self, agenerator: AsyncGenerator):
41+
"""Streams the response from an async response generator"""
42+
43+
self._set_stream_headers()
44+
45+
async for chunk in agenerator:
46+
self.write(chunk)
47+
await self.flush()
48+
49+
def _set_stream_headers(self):
50+
"""Set the headers in preparation for the streamed response"""
51+
52+
self.set_header("Content-Type", "text/event-stream")
53+
self.set_header("Cache-Control", "no-cache")
54+
self.set_header("Connection", "keep-alive")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from __future__ import absolute_import
2+
3+
import asyncio
4+
import logging
5+
from typing import AsyncGenerator, Generator
6+
7+
import tornado.web
8+
from stream_handler import StreamHandler
9+
from tornado.ioloop import IOLoop
10+
11+
from utils.environment import Environment
12+
from utils.exception import SyncInvocationsException
13+
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER
14+
15+
logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)
16+
17+
18+
class InvocationsHandler(tornado.web.RequestHandler, StreamHandler):
19+
"""Handler mapped to the /invocations POST route.
20+
21+
This handler wraps the sync handler retrieved from the inference script
22+
and encapsulates it behind the post() method. The post() method is done
23+
asynchronously.
24+
"""
25+
26+
def initialize(self, handler: callable, environment: Environment):
27+
"""Initializes the handler function and the serving environment."""
28+
29+
self._handler = handler
30+
self._environment = environment
31+
32+
async def post(self):
33+
"""POST method used to encapsulate and invoke the sync handle method asynchronously"""
34+
35+
try:
36+
response = await IOLoop.current().run_in_executor(None, self._handler, self.request)
37+
38+
if isinstance(response, Generator):
39+
await self.stream(response)
40+
elif isinstance(response, AsyncGenerator):
41+
await self.astream(response)
42+
else:
43+
self.write(response)
44+
except Exception as e:
45+
raise SyncInvocationsException(e)
46+
47+
48+
class PingHandler(tornado.web.RequestHandler):
49+
"""Handler mapped to the /ping GET route.
50+
51+
Ping handler to monitor the health of the Tornados server.
52+
"""
53+
54+
def get(self):
55+
"""Simple GET method to assess the health of the server."""
56+
57+
self.write("")
58+
59+
60+
async def handle(handler: callable, environment: Environment):
61+
"""Serves the sync handler function using Tornado.
62+
63+
Opens the /invocations and /ping routes used by a SageMaker Endpoint
64+
for inference serving capabilities.
65+
"""
66+
67+
logger.info("Starting inference server in synchronous mode...")
68+
69+
app = tornado.web.Application(
70+
[
71+
(r"/invocations", InvocationsHandler, dict(handler=handler, environment=environment)),
72+
(r"/ping", PingHandler),
73+
]
74+
)
75+
app.listen(environment.port)
76+
logger.debug(f"Synchronous inference server listening on port: `{environment.port}`")
77+
await asyncio.Event().wait()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from __future__ import absolute_import

0 commit comments

Comments
 (0)