Skip to content

Commit bf13d40

Browse files
authored
[core] Pass all driver env vars to ray workers unless excluded (vllm-project#14099)
Signed-off-by: Rui Qiao <[email protected]>
1 parent 989f4f4 commit bf13d40

File tree

1 file changed

+27
-3
lines changed

1 file changed

+27
-3
lines changed

vllm/executor/ray_distributed_executor.py

+27-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
import asyncio
4+
import json
45
import os
56
from collections import defaultdict
67
from dataclasses import dataclass
@@ -48,6 +49,24 @@ class RayWorkerMetaData:
4849

4950

5051
class RayDistributedExecutor(DistributedExecutorBase):
52+
"""Ray-based distributed executor"""
53+
54+
# These env vars are worker-specific, therefore are NOT copied
55+
# from the driver to the workers
56+
WORKER_SPECIFIC_ENV_VARS = {
57+
"VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES"
58+
}
59+
60+
config_home = envs.VLLM_CONFIG_ROOT
61+
# This file contains a list of env vars that should not be copied
62+
# from the driver to the Ray workers.
63+
non_carry_over_env_vars_file = os.path.join(
64+
config_home, "ray_non_carry_over_env_vars.json")
65+
if os.path.exists(non_carry_over_env_vars_file):
66+
with open(non_carry_over_env_vars_file) as f:
67+
non_carry_over_env_vars = set(json.load(f))
68+
else:
69+
non_carry_over_env_vars = set()
5170

5271
uses_ray: bool = True
5372

@@ -311,9 +330,9 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
311330

312331
# Environment variables to copy from driver to workers
313332
env_vars_to_copy = [
314-
"VLLM_ATTENTION_BACKEND", "TPU_CHIPS_PER_HOST_BOUNDS",
315-
"TPU_HOST_BOUNDS", "VLLM_USE_V1", "VLLM_TRACE_FUNCTION",
316-
"VLLM_TORCH_PROFILER_DIR", "VLLM_TEST_ENABLE_EP"
333+
v for v in envs.environment_variables
334+
if v not in self.WORKER_SPECIFIC_ENV_VARS
335+
and v not in self.non_carry_over_env_vars
317336
]
318337

319338
# Copy existing env vars to each worker's args
@@ -323,9 +342,14 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
323342
if name in os.environ:
324343
args[name] = os.environ[name]
325344

345+
logger.info("non_carry_over_env_vars from config: %s",
346+
self.non_carry_over_env_vars)
326347
logger.info(
327348
"Copying the following environment variables to workers: %s",
328349
[v for v in env_vars_to_copy if v in os.environ])
350+
logger.info(
351+
"If certain env vars should NOT be copied to workers, add them to "
352+
"%s file", self.non_carry_over_env_vars_file)
329353

330354
self._env_vars_for_all_workers = (
331355
all_args_to_update_environment_variables)

0 commit comments

Comments
 (0)