
Commit 9b550d1

[Misc] Support docker for the latest vllm integration (LMCache#316)
* add docker-related stuff
* remove comments
* fix format
* fix bash.sh to include docker patch
1 parent 00e9da9 commit 9b550d1

10 files changed (+340 -1 lines)

docker/Dockerfile

+115
@@ -0,0 +1,115 @@
# The vLLM Dockerfile is used to construct vLLM image that can be directly used
# to run the OpenAI compatible server.

# Please update any changes made here to
# docs/source/dev/dockerfile/dockerfile.rst and
# docs/source/assets/dev/dockerfile-stages-dependency.png

ARG CUDA_VERSION=12.4.1
#################### BASE BUILD IMAGE ####################
# prepare basic build environment
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
ARG CUDA_VERSION=12.4.1
ARG PYTHON_VERSION=3.12
ENV DEBIAN_FRONTEND=noninteractive

# Install Python and other dependencies
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
    && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
    && apt-get update -y \
    && apt-get install -y ccache software-properties-common git curl sudo \
    && add-apt-repository ppa:deadsnakes/ppa \
    && apt-get update -y \
    && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
    && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
    && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
    && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
    && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
    && python3 --version && python3 -m pip --version

# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/

WORKDIR /workspace

# install build and runtime dependencies
COPY requirements-common.txt requirements-common.txt
COPY requirements-cuda.txt requirements-cuda.txt
RUN --mount=type=cache,target=/root/.cache/pip \
    python3 -m pip install -r requirements-cuda.txt


# cuda arch list used by torch
# can be useful for both `dev` and `test`
# explicitly set the list to avoid issues with torch 2.2
# see https://github.com/pytorch/pytorch/pull/123243
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# Override the arch list for flash-attn to reduce the binary size
ARG vllm_fa_cmake_gpu_arches='80-real;90-real'
ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches}
#################### BASE BUILD IMAGE ####################

#################### WHEEL BUILD IMAGE ####################
FROM base AS build

# install build dependencies
COPY requirements-build.txt requirements-build.txt

# max jobs used by Ninja to build extensions
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
# number of threads used by nvcc
ARG nvcc_threads=8
ENV NVCC_THREADS=$nvcc_threads


RUN --mount=type=cache,target=/root/.cache/pip \
    python3 -m pip install -r requirements-build.txt

ARG LMCACHE_COMMIT_ID=1

RUN git clone https://github.com/LMCache/LMCache.git
RUN git clone https://github.com/LMCache/torchac_cuda.git


WORKDIR /workspace/LMCache
RUN --mount=type=cache,target=/root/.cache/ccache \
    --mount=type=cache,target=/root/.cache/pip \
    python3 setup.py bdist_wheel --dist-dir=dist_lmcache

WORKDIR /workspace/torchac_cuda
RUN --mount=type=cache,target=/root/.cache/ccache \
    --mount=type=cache,target=/root/.cache/pip \
    python3 setup.py bdist_wheel --dist-dir=/workspace/LMCache/dist_lmcache


#################### vLLM installation IMAGE ####################
# Install the LMCache and torchac_cuda wheels into the vLLM image
FROM vllm/vllm-openai:v0.6.6.post1 AS vllm-openai
RUN --mount=type=bind,from=build,src=/workspace/LMCache/dist_lmcache,target=/vllm-workspace/dist_lmcache \
    --mount=type=cache,target=/root/.cache/pip \
    pip install dist_lmcache/*.whl --verbose

# Copy the LMCache connector files into vllm
COPY patches/factory.py \
    /usr/local/lib/python3.12/dist-packages/vllm/distributed/kv_transfer/kv_connector/
COPY patches/lmcache_connector.py \
    /usr/local/lib/python3.12/dist-packages/vllm/distributed/kv_transfer/kv_connector/

# Use a diff/patch when the file is too large to copy wholesale
COPY patches/parallel_state.patch \
    /usr/local/lib/python3.12/dist-packages/vllm/distributed/
COPY patches/config.patch \
    /usr/local/lib/python3.12/dist-packages/vllm/

RUN patch /usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py \
    /usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.patch
RUN patch /usr/local/lib/python3.12/dist-packages/vllm/config.py \
    /usr/local/lib/python3.12/dist-packages/vllm/config.patch


ENTRYPOINT ["vllm", "serve"]
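
A minimal build invocation for this Dockerfile might look like the sketch below. It assumes the requirements files and a patches/ directory are present in the build context, as the COPY instructions above expect, and that <IMAGE_NAME>:<TAG> is a placeholder tag; adjust the build args to the target machine.

docker build \
    --build-arg CUDA_VERSION=12.4.1 \
    --build-arg max_jobs=8 \
    --build-arg nvcc_threads=8 \
    -f Dockerfile -t <IMAGE_NAME>:<TAG> .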

docker/example_run.sh

+14
@@ -0,0 +1,14 @@
IMAGE=<IMAGE_NAME>:<TAG>
docker run --runtime nvidia --gpus all \
    --env "HF_TOKEN=<YOUR_HUGGINGFACE_TOKEN>" \
    --env "LMCACHE_USE_EXPERIMENTAL=True" \
    --env "chunk_size=256" \
    --env "local_cpu=True" \
    --env "max_local_cpu_size=5" \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --network host \
    --entrypoint "/usr/local/bin/vllm" \
    $IMAGE \
    serve mistralai/Mistral-7B-Instruct-v0.2 --kv-transfer-config \
    '{"kv_connector":"LMCacheConnector","kv_role":"kv_both"}' \
    --enable-chunked-prefill false
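
Once the container above is up, a quick smoke test against the OpenAI-compatible endpoint could look like this sketch; it assumes vLLM's default port 8000 and reuses the model name from the run script.

curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "mistralai/Mistral-7B-Instruct-v0.2",
         "prompt": "Explain KV cache offloading in one sentence.",
         "max_tokens": 64}'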

docker/patch/config.patch

+13
@@ -0,0 +1,13 @@
--- original/config.py	2025-01-19 20:05:02.376220126 -0600
+++ config.py	2025-01-19 20:01:35.864391306 -0600
@@ -2559,7 +2559,9 @@
         return KVTransferConfig.model_validate_json(cli_value)
 
     def model_post_init(self, __context: Any) -> None:
-        supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"]
+        supported_kv_connector = ["PyNcclConnector",
+                                  "MooncakeConnector",
+                                  "LMCacheConnector"]
         if all([
                 self.kv_connector is not None, self.kv_connector
                 not in supported_kv_connector

docker/patch/factory.py

+27
@@ -0,0 +1,27 @@
from typing import TYPE_CHECKING

from .base import KVConnectorBase

if TYPE_CHECKING:
    from vllm.config import VllmConfig


class KVConnectorFactory:

    @staticmethod
    def create_connector(rank: int, local_rank: int,
                         config: "VllmConfig") -> KVConnectorBase:
        supported_kv_connector = [
            "PyNcclConnector", "MooncakeConnector", "LMCacheConnector"
        ]
        kv_connector = config.kv_transfer_config.kv_connector
        if kv_connector in supported_kv_connector:
            if kv_connector in ["PyNcclConnector", "MooncakeConnector"]:
                from .simple_connector import SimpleConnector
                return SimpleConnector(rank, local_rank, config)
            elif kv_connector in ["LMCacheConnector"]:
                from .lmcache_connector import LMCacheConnector
                return LMCacheConnector(rank, local_rank, config)
        else:
            raise ValueError(f"Unsupported connector type: "
                             f"{kv_connector}")

docker/patch/lmcache_connector.py

+104
@@ -0,0 +1,104 @@
"""
Simple KV Cache Connector for Distributed Machine Learning Inference

The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker
(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache;
(2) offload and share KV caches. Only (2) is supported for now.
"""

from typing import TYPE_CHECKING, List, Tuple, Union

import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

logger = init_logger(__name__)


class LMCacheConnector(KVConnectorBase):

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
    ):

        self.transfer_config = config.kv_transfer_config
        self.vllm_config = config

        from lmcache.integration.vllm.vllm_adapter import (RetrieveStatus,
                                                           StoreStatus,
                                                           init_lmcache_engine,
                                                           lmcache_retrieve_kv,
                                                           lmcache_store_kv)

        logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
                    self.transfer_config)

        # TODO (Jiayi): Find model_config, parallel_config, and cache_config
        self.engine = init_lmcache_engine(config.model_config,
                                          config.parallel_config,
                                          config.cache_config)

        self.model_config = config.model_config
        self.parallel_config = config.parallel_config
        self.cache_config = config.cache_config
        self.lmcache_retrieve_kv = lmcache_retrieve_kv
        self.lmcache_store_kv = lmcache_store_kv
        self.store_status = StoreStatus
        self.retrieve_status = RetrieveStatus

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor]
    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:

        # TODO(Jiayi): This shouldn't be none for disagg prefill
        hidden_or_intermediate_states = None

        # TODO (Jiayi): Only normal prefill is supported for now
        retrieve_status = [self.retrieve_status.PREFILL]

        model_input, bypass_model_exec = self.lmcache_retrieve_kv(
            model_executable, model_input, self.cache_config, kv_caches,
            retrieve_status)

        return hidden_or_intermediate_states, bypass_model_exec, model_input

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        num_reqs = 0
        seq_group_list = model_input.sampling_metadata.seq_groups
        assert seq_group_list is not None
        for seq_group in seq_group_list:
            seq_ids = seq_group.seq_ids
            for seq_id in seq_ids:
                num_reqs += 1

        # TODO (Jiayi): Only normal prefill is supported for now
        store_status = [self.store_status.PREFILL] * num_reqs
        self.lmcache_store_kv(
            self.model_config,
            self.parallel_config,
            model_executable,
            model_input,
            kv_caches,
            store_status,
        )

    def close(self):
        self.engine.close()

docker/patch/parallel_state.patch

+14
@@ -0,0 +1,14 @@
--- original/parallel_state.py	2025-01-19 20:05:02.012220433 -0600
+++ parallel_state.py	2025-01-19 20:07:24.844098884 -0600
@@ -1075,9 +1075,9 @@
 
     if vllm_config.kv_transfer_config is None:
         return
-
+
     if all([
-            vllm_config.kv_transfer_config.need_kv_parallel_group,
+            vllm_config.kv_transfer_config.is_kv_transfer_instance,
             _KV_TRANSFER is None
     ]):
         _KV_TRANSFER = kv_transfer.KVTransferAgent(

docker/requirements-build.txt

+9
@@ -0,0 +1,9 @@
# Should be mirrored in pyproject.toml
cmake>=3.26
ninja
packaging
setuptools>=61
setuptools-scm>=8
torch==2.4.0
wheel
jinja2

docker/requirements-common.txt

+33
@@ -0,0 +1,33 @@
psutil
sentencepiece # Required for LLaMA tokenizer.
numpy < 2.0.0
requests
tqdm
py-cpuinfo
transformers >= 4.45.0 # Required for Llama 3.2.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
fastapi < 0.113.0; python_version < '3.9'
fastapi >= 0.114.1; python_version >= '3.9'
aiohttp
openai >= 1.40.0 # Ensure modern openai package (ensure types module present)
uvicorn[standard]
pydantic >= 2.9 # Required for fastapi >= 0.113.0
pillow # Required for image processing
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.10.6
outlines >= 0.0.43, < 0.1
typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
partial-json-parser # used for parsing partial JSON outputs
pyzmq
msgspec
gguf == 0.10.0
importlib_metadata
mistral_common >= 1.4.3
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
einops # Required for Qwen2-VL.

docker/requirements-cuda.txt

+10
@@ -0,0 +1,10 @@
# Common dependencies
-r requirements-common.txt

# Dependencies for NVIDIA GPUs
ray >= 2.9
nvidia-ml-py # for pynvml package
torch == 2.4.0
# These must be updated alongside torch
torchvision == 0.19 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
xformers == 0.0.27.post2; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.4.0

format.sh

+1 -1
@@ -193,7 +193,7 @@ if [[ "$1" == '--files' ]]; then
    # If `--all` is passed, then any further arguments are ignored and the
    # entire python directory is linted.
 elif [[ "$1" == '--all' ]]; then
-   lint lmcache tests
+   lint lmcache tests docker
 else
    # Format only the files that changed in last commit.
    lint_changed
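
With this change, a full lint pass now also covers the new docker/ directory, for example:

./format.sh --all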
