Commit 692d625

Enable inference serving capabilities on sagemaker endpoint using tornado
1 parent 6fdda45 commit 692d625

File tree

13 files changed (+504, -1 lines)


template/v3/Dockerfile

+1, -1
@@ -191,7 +191,7 @@ RUN mkdir -p $SAGEMAKER_LOGGING_DIR && \
     && ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh ${HOME_DIR} python \
     && rm -rf ${HOME_DIR}/oss_compliance*
 
-ENV PATH="/opt/conda/bin:/opt/conda/condabin:$PATH"
+ENV PATH="/etc/sagemaker-inference-server:/opt/conda/bin:/opt/conda/condabin:$PATH"
 WORKDIR "/home/${NB_USER}"
 ENV SHELL=/bin/bash
 ENV OPENSSL_MODULES=/opt/conda/lib64/ossl-modules/
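This is the only change to an existing file: it prepends the inference server directory to PATH so the entry script that lives there can be invoked by bare name (SageMaker hosting starts the container with the argument `serve`). A quick sanity check inside the built image might look like the sketch below; the script name `serve` is an assumption based on the shell wrapper added later in this commit.

# Hypothetical check inside the built image: the entry script should now be
# resolvable by name through the updated PATH ("serve" is an assumed name).
import shutil

print(shutil.which("serve"))  # e.g. /etc/sagemaker-inference-server/serve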
@@ -0,0 +1,3 @@
from __future__ import absolute_import

import utils.logger
@@ -0,0 +1,2 @@
#!/bin/bash
python /etc/sagemaker-inference-server/serve.py
@@ -0,0 +1,25 @@
from __future__ import absolute_import

"""
TODO: when adding support for more serving frameworks, move the below logic into a conditional statement.
We also need to define the right environment variable to signify which serving framework to use.

Ex.

inference_server = None
serving_framework = os.getenv("SAGEMAKER_INFERENCE_FRAMEWORK", None)

if serving_framework == "FastAPI":
    inference_server = FastApiServer()
elif serving_framework == "Flask":
    inference_server = FlaskServer()
else:
    inference_server = TornadoServer()

inference_server.serve()

"""
from tornado_server.server import TornadoServer

inference_server = TornadoServer()
inference_server.serve()
@@ -0,0 +1,12 @@
from __future__ import absolute_import

import pathlib
import sys

# make the utils modules accessible to modules from within the tornado_server folder
utils_path = pathlib.Path(__file__).parent.parent / "utils"
sys.path.insert(0, str(utils_path.resolve()))

# make the tornado_server modules accessible to each other
tornado_module_path = pathlib.Path(__file__).parent
sys.path.insert(0, str(tornado_module_path.resolve()))
@@ -0,0 +1,76 @@
from __future__ import absolute_import

import asyncio
import logging
from typing import AsyncIterator, Iterator

import tornado.web
from stream_handler import StreamHandler

from utils.environment import Environment
from utils.exception import AsyncInvocationsException
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER

logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)


class InvocationsHandler(tornado.web.RequestHandler, StreamHandler):
    """Handler mapped to the /invocations POST route.

    This handler wraps the async handler retrieved from the inference script
    and encapsulates it behind the post() method. The post() method is run
    asynchronously.
    """

    def initialize(self, handler: callable, environment: Environment):
        """Initializes the handler function and the serving environment."""

        self._handler = handler
        self._environment = environment

    async def post(self):
        """POST method used to encapsulate and invoke the async handle method asynchronously"""

        try:
            response = await self._handler(self.request)

            if isinstance(response, Iterator):
                await self.stream(response)
            elif isinstance(response, AsyncIterator):
                await self.astream(response)
            else:
                self.write(response)
        except Exception as e:
            raise AsyncInvocationsException(e)


class PingHandler(tornado.web.RequestHandler):
    """Handler mapped to the /ping GET route.

    Ping handler to monitor the health of the Tornado server.
    """

    def get(self):
        """Simple GET method to assess the health of the server."""

        self.write("")


async def handle(handler: callable, environment: Environment):
    """Serves the async handler function using Tornado.

    Opens the /invocations and /ping routes used by a SageMaker Endpoint
    for inference serving capabilities.
    """

    logger.info("Starting inference server in asynchronous mode...")

    app = tornado.web.Application(
        [
            (r"/invocations", InvocationsHandler, dict(handler=handler, environment=environment)),
            (r"/ping", PingHandler),
        ]
    )
    app.listen(environment.port)
    logger.debug(f"Asynchronous inference server listening on port: `{environment.port}`")
    await asyncio.Event().wait()
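For context, a minimal sketch of a user-supplied asynchronous handler that this module could serve; the function name and request parsing below are illustrative assumptions, not part of this commit. Returning a dict (or bytes/str) takes the plain write() path, while returning an async iterator triggers astream().

# Hypothetical async inference handler (names are illustrative only).
import asyncio
import json


async def handler(request):
    # `request` is the tornado.httputil.HTTPServerRequest forwarded by InvocationsHandler.
    payload = json.loads(request.body or b"{}")

    if payload.get("stream"):
        # Returning an async iterator makes InvocationsHandler stream the chunks.
        async def chunks():
            for token in ("partial ", "results ", "here"):
                await asyncio.sleep(0)  # hand control back to the event loop
                yield token.encode("utf-8")

        return chunks()

    # Returning a dict takes the non-streaming write() path.
    return {"echo": payload}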
@@ -0,0 +1,125 @@
from __future__ import absolute_import

import asyncio
import importlib.util
import logging
import subprocess
import sys
from pathlib import Path

from utils.environment import Environment
from utils.exception import (
    InferenceCodeLoadException,
    RequirementsInstallException,
    ServerStartException,
)
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER

logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)


class TornadoServer:
    """Holds serving logic using the Tornado framework.

    The serve.py script will invoke TornadoServer.serve() to start the serving process.
    The TornadoServer will install the runtime requirements specified through a requirements file.
    It will then load a handler function from within an inference script and front it with an
    /invocations route using the Tornado framework.
    """

    def __init__(self):
        """Initialize the serving behaviors.

        Defines the serving behavior through Environment() and locates where
        the inference code is contained.
        """

        self._environment = Environment()
        logger.setLevel(int(self._environment.logging_level))
        logger.debug(f"Environment: {str(self._environment)}")

        self._path_to_inference_code = (
            Path(self._environment.base_directory).joinpath(self._environment.code_directory)
            if self._environment.code_directory
            else Path(self._environment.base_directory)
        )
        logger.debug(f"Path to inference code: `{str(self._path_to_inference_code)}`")

    def initialize(self):
        """Initialize the serving artifacts and dependencies.

        Install the runtime requirements and then locate the handler function from
        the inference script.
        """

        logger.info("Initializing inference server...")
        self._install_runtime_requirements()
        self._handler = self._load_inference_handler()

    def serve(self):
        """Orchestrate the initialization and server startup behavior.

        Call the initialize() method, determine the right Tornado serving behavior (async or sync),
        and then start the Tornado server through asyncio.
        """

        logger.info("Serving inference requests using Tornado...")
        self.initialize()

        if asyncio.iscoroutinefunction(self._handler):
            import async_handler as inference_handler
        else:
            import sync_handler as inference_handler

        try:
            asyncio.run(inference_handler.handle(self._handler, self._environment))
        except Exception as e:
            raise ServerStartException(e)

    def _install_runtime_requirements(self):
        """Install the runtime requirements."""

        logger.info("Installing runtime requirements...")
        requirements_txt = self._path_to_inference_code.joinpath(self._environment.requirements)
        if requirements_txt.is_file():
            try:
                subprocess.check_call(["micromamba", "install", "--yes", "--file", str(requirements_txt)])
            except Exception:
                logger.error("Failed to install requirements using `micromamba install`. Falling back to `pip install`...")
                try:
                    subprocess.check_call(["pip", "install", "-r", str(requirements_txt)])
                except Exception as e:
                    raise RequirementsInstallException(e)
        else:
            logger.debug(f"No requirements file was found at `{str(requirements_txt)}`")

    def _load_inference_handler(self) -> callable:
        """Load the handler function from the inference script."""

        logger.info("Loading inference handler...")
        inference_module_name, handle_name = self._environment.code.split(".")
        if inference_module_name and handle_name:
            inference_module_file = f"{inference_module_name}.py"
            module_spec = importlib.util.spec_from_file_location(
                inference_module_file, str(self._path_to_inference_code.joinpath(inference_module_file))
            )
            if module_spec:
                sys.path.insert(0, str(self._path_to_inference_code.resolve()))
                module = importlib.util.module_from_spec(module_spec)
                module_spec.loader.exec_module(module)

                if hasattr(module, handle_name):
                    handler = getattr(module, handle_name)
                else:
                    raise InferenceCodeLoadException(
                        f"Handler `{handle_name}` could not be found in module `{inference_module_file}`"
                    )
                logger.debug(f"Loaded handler `{handle_name}` from module `{inference_module_name}`")
                return handler
            else:
                raise InferenceCodeLoadException(
                    f"Inference code could not be found at `{str(self._path_to_inference_code.joinpath(inference_module_file))}`"
                )
        raise InferenceCodeLoadException(
            f"Inference code expected in the format of `<module>.<handler>` but was provided as {self._environment.code}"
        )
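To make the loading contract concrete: the code setting is expected in `<module>.<handler>` form, so a value such as `inference.handler` (a hypothetical example; the real default lives in utils/environment.py, which is not shown in this diff) would make the server import inference.py from the configured code directory, install an optional requirements file found next to it, and serve the handler function. A minimal sketch of such a script:

# inference.py -- hypothetical user script; the server imports this module and
# resolves the attribute named after the dot in `<module>.<handler>`.
import json


def handler(request):
    # A plain synchronous function: server.py detects it is not a coroutine
    # function and serves it through sync_handler on an executor thread.
    payload = json.loads(request.body or b"{}")
    return {"prediction": payload.get("inputs")}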
@@ -0,0 +1,59 @@
from __future__ import absolute_import

import logging
from typing import AsyncIterator, Iterator

from tornado.ioloop import IOLoop

from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER

logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)


class StreamHandler:
    """Mixin that adds async and sync streaming capabilities to the async and sync handlers.

    stream() runs a provided iterator/generator fn in an async manner.
    astream() runs a provided async iterator/generator fn in an async manner.
    """

    async def stream(self, iterator: Iterator):
        """Streams the response from a sync response iterator.

        A sync iterator must be manually iterated through asynchronously.
        In a loop, iterate through each next(iterator) call in an async execution.
        """

        self._set_stream_headers()

        while True:
            try:
                chunk = await IOLoop.current().run_in_executor(None, next, iterator)
                # Some iterators do not throw a StopIteration upon exhaustion.
                # Instead, they return an empty response. Account for this case.
                if not chunk:
                    raise StopIteration()

                self.write(chunk)
                await self.flush()
            except StopIteration:
                break
            except Exception as e:
                logger.error("Unexpected exception occurred when streaming response...")
                break

    async def astream(self, aiterator: AsyncIterator):
        """Streams the response from an async response iterator."""

        self._set_stream_headers()

        async for chunk in aiterator:
            self.write(chunk)
            await self.flush()

    def _set_stream_headers(self):
        """Set the headers in preparation for the streamed response."""

        self.set_header("Content-Type", "text/event-stream")
        self.set_header("Cache-Control", "no-cache")
        self.set_header("Connection", "keep-alive")
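As an illustration of the synchronous streaming path, a handler that returns a plain generator is drained chunk by chunk through stream(); the `data:` framing below is only an assumption about how a caller might format chunks, since the mixin forwards whatever bytes it is given.

# Hypothetical synchronous handler returning a generator; each yielded chunk is
# written and flushed individually by StreamHandler.stream().
def handler(request):
    def chunks():
        for i in range(3):
            yield f"data: chunk {i}\n\n".encode("utf-8")

    return chunks()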
@@ -0,0 +1,77 @@
from __future__ import absolute_import

import asyncio
import logging
from typing import AsyncIterator, Iterator

import tornado.web
from stream_handler import StreamHandler
from tornado.ioloop import IOLoop

from utils.environment import Environment
from utils.exception import SyncInvocationsException
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER

logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)


class InvocationsHandler(tornado.web.RequestHandler, StreamHandler):
    """Handler mapped to the /invocations POST route.

    This handler wraps the sync handler retrieved from the inference script
    and encapsulates it behind the post() method. The post() method is run
    asynchronously.
    """

    def initialize(self, handler: callable, environment: Environment):
        """Initializes the handler function and the serving environment."""

        self._handler = handler
        self._environment = environment

    async def post(self):
        """POST method used to encapsulate and invoke the sync handle method asynchronously"""

        try:
            response = await IOLoop.current().run_in_executor(None, self._handler, self.request)

            if isinstance(response, Iterator):
                await self.stream(response)
            elif isinstance(response, AsyncIterator):
                await self.astream(response)
            else:
                self.write(response)
        except Exception as e:
            raise SyncInvocationsException(e)


class PingHandler(tornado.web.RequestHandler):
    """Handler mapped to the /ping GET route.

    Ping handler to monitor the health of the Tornado server.
    """

    def get(self):
        """Simple GET method to assess the health of the server."""

        self.write("")


async def handle(handler: callable, environment: Environment):
    """Serves the sync handler function using Tornado.

    Opens the /invocations and /ping routes used by a SageMaker Endpoint
    for inference serving capabilities.
    """

    logger.info("Starting inference server in synchronous mode...")

    app = tornado.web.Application(
        [
            (r"/invocations", InvocationsHandler, dict(handler=handler, environment=environment)),
            (r"/ping", PingHandler),
        ]
    )
    app.listen(environment.port)
    logger.debug(f"Synchronous inference server listening on port: `{environment.port}`")
    await asyncio.Event().wait()
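A quick way to exercise both routes once a server is running; the port and payload below are assumptions (SageMaker endpoints conventionally listen on 8080, but the actual value comes from the serving Environment).

# Hypothetical local smoke test against a running inference server.
import json
import urllib.request

BASE = "http://localhost:8080"  # assumed port

# Health check: /ping should return HTTP 200 with an empty body.
with urllib.request.urlopen(f"{BASE}/ping") as resp:
    print(resp.status)

# Invocation: POST a JSON payload to /invocations.
req = urllib.request.Request(
    f"{BASE}/invocations",
    data=json.dumps({"inputs": "hello"}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    print(resp.read().decode("utf-8"))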
@@ -0,0 +1 @@
from __future__ import absolute_import
