From 94193445994b8ffcd62824a400bef842b7dd5dd5 Mon Sep 17 00:00:00 2001 From: Naman Nandan Date: Thu, 9 May 2024 12:14:08 -0700 Subject: [PATCH] Add support for gRPC max connection age configuration (#3121) * Add support for gRPC max connection age configuration * Add separate gRPC connection age configurations for Inference and Management endpoints * Fix configuration parsing and default value * Link docs for gRPC max connection age configuration * fix spellcheck --------- Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com> --- docs/configuration.md | 12 +++++- .../java/org/pytorch/serve/ModelServer.java | 7 ++++ .../org/pytorch/serve/util/ConfigManager.java | 40 +++++++++++++++++++ ts_scripts/spellcheck_conf/wordlist.txt | 1 + 4 files changed, 58 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 9f2afb1b30..22e3a8bf06 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -93,7 +93,7 @@ inference_address=https://127.0.0.1:8443 inference_address=https://172.16.1.10:8080 ``` -### Configure TorchServe gRPC listening addresses and ports +### Configure TorchServe gRPC listening addresses, ports and max connection age The inference gRPC API is listening on port 7070, and the management gRPC API is listening on port 7071 on localhost by default. To configure different addresses use following properties @@ -106,7 +106,15 @@ To configure different ports use following properties * `grpc_inference_port`: Inference gRPC API binding port. Default: 7070 * `grpc_management_port`: management gRPC API binding port. Default: 7071 -Here are a couple of examples: +To configure [max connection age](https://grpc.github.io/grpc-java/javadoc/io/grpc/netty/NettyServerBuilder.html#maxConnectionAge(long,java.util.concurrent.TimeUnit)) (milliseconds) + +* `grpc_inference_max_connection_age_ms`: Inference gRPC max connection age. 
Default: Infinite +* `grpc_management_max_connection_age_ms`: Management gRPC max connection age. Default: Infinite + +To configure [max connection age grace](https://grpc.github.io/grpc-java/javadoc/io/grpc/netty/NettyServerBuilder.html#maxConnectionAgeGrace(long,java.util.concurrent.TimeUnit)) (in milliseconds) use the following properties + +* `grpc_inference_max_connection_age_grace_ms`: Inference gRPC max connection age grace. Default: Infinite +* `grpc_management_max_connection_age_grace_ms`: Management gRPC max connection age grace. Default: Infinite ### Enable SSL diff --git a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java index e05eb62943..1fd8189609 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java @@ -27,6 +27,7 @@ import java.util.ServiceLoader; import java.util.Set; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.DefaultParser; @@ -452,6 +453,12 @@ private Server startGRPCServer(ConnectorType connectorType) throws IOException { new InetSocketAddress( configManager.getGRPCAddress(connectorType), configManager.getGRPCPort(connectorType))) + .maxConnectionAge( + configManager.getGRPCMaxConnectionAge(connectorType), + TimeUnit.MILLISECONDS) + .maxConnectionAgeGrace( + configManager.getGRPCMaxConnectionAgeGrace(connectorType), + TimeUnit.MILLISECONDS) .maxInboundMessageSize(configManager.getMaxRequestSize()) .addService( ServerInterceptors.intercept( diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index e215b4d87f..829d57c530 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ 
b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -105,6 +105,14 @@ public final class ConfigManager { private static final String TS_GRPC_MANAGEMENT_ADDRESS = "grpc_management_address"; private static final String TS_GRPC_INFERENCE_PORT = "grpc_inference_port"; private static final String TS_GRPC_MANAGEMENT_PORT = "grpc_management_port"; + private static final String TS_GRPC_INFERENCE_MAX_CONNECTION_AGE_MS = + "grpc_inference_max_connection_age_ms"; + private static final String TS_GRPC_MANAGEMENT_MAX_CONNECTION_AGE_MS = + "grpc_management_max_connection_age_ms"; + private static final String TS_GRPC_INFERENCE_MAX_CONNECTION_AGE_GRACE_MS = + "grpc_inference_max_connection_age_grace_ms"; + private static final String TS_GRPC_MANAGEMENT_MAX_CONNECTION_AGE_GRACE_MS = + "grpc_management_max_connection_age_grace_ms"; private static final String TS_ENABLE_GRPC_SSL = "enable_grpc_ssl"; private static final String TS_INITIAL_WORKER_PORT = "initial_worker_port"; private static final String TS_INITIAL_DISTRIBUTION_PORT = "initial_distribution_port"; @@ -384,6 +392,30 @@ public int getGRPCPort(ConnectorType connectorType) throws IllegalArgumentExcept return Integer.parseInt(port); } + public long getGRPCMaxConnectionAge(ConnectorType connectorType) + throws IllegalArgumentException { + if (connectorType == ConnectorType.MANAGEMENT_CONNECTOR) { + return getLongProperty(TS_GRPC_MANAGEMENT_MAX_CONNECTION_AGE_MS, Long.MAX_VALUE); + } else if (connectorType == ConnectorType.INFERENCE_CONNECTOR) { + return getLongProperty(TS_GRPC_INFERENCE_MAX_CONNECTION_AGE_MS, Long.MAX_VALUE); + } else { + throw new IllegalArgumentException( + "Connector type not supported by gRPC: " + connectorType); + } + } + + public long getGRPCMaxConnectionAgeGrace(ConnectorType connectorType) + throws IllegalArgumentException { + if (connectorType == ConnectorType.MANAGEMENT_CONNECTOR) { + return getLongProperty(TS_GRPC_MANAGEMENT_MAX_CONNECTION_AGE_GRACE_MS, Long.MAX_VALUE); + } 
else if (connectorType == ConnectorType.INFERENCE_CONNECTOR) { + return getLongProperty(TS_GRPC_INFERENCE_MAX_CONNECTION_AGE_GRACE_MS, Long.MAX_VALUE); + } else { + throw new IllegalArgumentException( + "Connector type not supported by gRPC: " + connectorType); + } + } + public boolean isOpenInferenceProtocol() { String inferenceProtocol = System.getenv("TS_OPEN_INFERENCE_PROTOCOL"); if (inferenceProtocol != null && inferenceProtocol != "") { @@ -795,6 +827,14 @@ private int getIntProperty(String key, int def) { return Integer.parseInt(value); } + private long getLongProperty(String key, long def) { + String value = prop.getProperty(key); + if (value == null) { + return def; + } + return Long.parseLong(value); + } + public int getDefaultResponseTimeout() { return Integer.parseInt(prop.getProperty(TS_DEFAULT_RESPONSE_TIMEOUT, "120")); } diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index 222a5a837d..4ee9b0ee24 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1237,3 +1237,4 @@ SamplingParams lora vllm sql +TimeUnit