Skip to content

Commit 14e8d6f

Browse files
authored
refactor(kserve): make the torchserve configuration loading more resiliant (#2995)
The many defaults defined cannot currently be used as if the matching keys are not present in the properties file, it's loading will fail. This patch fixes that and also ignores lines starting with # as they should be.
1 parent 2e26323 commit 14e8d6f

File tree

1 file changed

+21
-31
lines changed

1 file changed

+21
-31
lines changed

kubernetes/kserve/kserve_wrapper/__main__.py

+21-31
Original file line numberDiff line numberDiff line change
@@ -28,52 +28,42 @@ def parse_config():
2828
model_store: the path in which the .mar file resides
2929
"""
3030
separator = "="
31-
keys = {}
31+
ts_configuration = {}
3232
config_path = os.environ.get("CONFIG_PATH", DEFAULT_CONFIG_PATH)
3333

3434
logging.info(f"Wrapper: loading configuration from {config_path}")
3535

3636
with open(config_path) as f:
3737
for line in f:
38-
if separator in line:
39-
# Find the name and value by splitting the string
40-
name, value = line.split(separator, 1)
41-
42-
# Assign key value pair to dict
43-
# strip() removes white space from the ends of strings
44-
keys[name.strip()] = value.strip()
45-
46-
keys["model_snapshot"] = json.loads(keys["model_snapshot"])
47-
inference_address, management_address, grpc_inference_port, model_store = (
48-
keys["inference_address"],
49-
keys["management_address"],
50-
keys["grpc_inference_port"],
51-
keys["model_store"],
38+
if not line.startswith("#"):
39+
if separator in line:
40+
name, value = line.split(separator, 1)
41+
ts_configuration[name.strip()] = value.strip()
42+
43+
ts_configuration["model_snapshot"] = json.loads(
44+
ts_configuration.get("model_snapshot", "{}")
5245
)
5346

54-
models = keys["model_snapshot"]["models"]
55-
model_names = []
47+
inference_address = ts_configuration.get(
48+
"inference_address", DEFAULT_INFERENCE_ADDRESS
49+
)
50+
management_address = ts_configuration.get(
51+
"management_address", DEFAULT_MANAGEMENT_ADDRESS
52+
)
53+
grpc_inference_port = ts_configuration.get(
54+
"grpc_inference_port", DEFAULT_GRPC_INFERENCE_PORT
55+
)
56+
model_store = ts_configuration.get("model_store", DEFAULT_MODEL_STORE)
5657

5758
# Get all the model_names
58-
for model, value in models.items():
59-
model_names.append(model)
59+
model_names = ts_configuration["model_snapshot"].get("models", {}).keys()
6060

61-
if not inference_address:
62-
inference_address = DEFAULT_INFERENCE_ADDRESS
6361
if not model_names:
6462
model_names = [DEFAULT_MODEL_NAME]
65-
if not inference_address:
66-
inference_address = DEFAULT_INFERENCE_ADDRESS
67-
if not management_address:
68-
management_address = DEFAULT_MANAGEMENT_ADDRESS
63+
6964
inf_splits = inference_address.split(":")
70-
if not grpc_inference_port:
71-
grpc_inference_address = inf_splits[1] + ":" + DEFAULT_GRPC_INFERENCE_PORT
72-
else:
73-
grpc_inference_address = inf_splits[1] + ":" + grpc_inference_port
65+
grpc_inference_address = inf_splits[1] + ":" + grpc_inference_port
7466
grpc_inference_address = grpc_inference_address.replace("/", "")
75-
if not model_store:
76-
model_store = DEFAULT_MODEL_STORE
7767

7868
logging.info(
7969
"Wrapper : Model names %s, inference address %s, management address %s, grpc_inference_address, %s, model store %s",

0 commit comments

Comments
 (0)