
Commit b0746fa

Authored on Mar 10, 2025
[Frontend] support image embeds (vllm-project#13955)
Signed-off-by: chaunceyjiang <[email protected]>
1 parent 60a98b2 commit b0746fa

File tree

4 files changed (+201 -12 lines)


docs/source/serving/multimodal_inputs.md

+66 -1
@@ -462,4 +462,69 @@ export VLLM_AUDIO_FETCH_TIMEOUT=<timeout>
 
 ### Embedding Inputs
 
-TBD
+To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
+pass a tensor of the expected shape to the corresponding field of the multi-modal dictionary.
+
+#### Image Embedding Inputs
+For image embeddings, you can pass the base64-encoded tensor to the `image_embeds` field.
+The following example demonstrates how to pass image embeddings to the OpenAI server:
+
+```python
+image_embedding = torch.load(...)
+grid_thw = torch.load(...)  # Required by Qwen/Qwen2-VL-2B-Instruct
+
+buffer = io.BytesIO()
+torch.save(image_embedding, buffer)
+buffer.seek(0)
+binary_data = buffer.read()
+base64_image_embedding = base64.b64encode(binary_data).decode('utf-8')
+
+client = OpenAI(
+    # defaults to os.environ.get("OPENAI_API_KEY")
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+
+# Basic usage - this is equivalent to the LLaVA example for offline inference
+model = "llava-hf/llava-1.5-7b-hf"
+embeds = {
+    "type": "image_embeds",
+    "image_embeds": f"{base64_image_embedding}"
+}
+
+# Pass additional parameters (available to Qwen2-VL and MiniCPM-V)
+model = "Qwen/Qwen2-VL-2B-Instruct"
+embeds = {
+    "type": "image_embeds",
+    "image_embeds": {
+        "image_embeds": f"{base64_image_embedding}",  # Required
+        "image_grid_thw": f"{base64_image_grid_thw}"  # Required by Qwen/Qwen2-VL-2B-Instruct
+    },
+}
+model = "openbmb/MiniCPM-V-2_6"
+embeds = {
+    "type": "image_embeds",
+    "image_embeds": {
+        "image_embeds": f"{base64_image_embedding}",  # Required
+        "image_sizes": f"{base64_image_sizes}"  # Required by openbmb/MiniCPM-V-2_6
+    },
+}
+chat_completion = client.chat.completions.create(
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": [
+            {
+                "type": "text",
+                "text": "What's in this image?",
+            },
+            embeds,
+        ],
+        },
+    ],
+    model=model,
+)
+```
+
+:::{note}
+Only one message can contain `{"type": "image_embeds"}`.
+If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc.
+:::
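
The snippet above references `base64_image_grid_thw` and `base64_image_sizes` without showing how they are produced. A minimal sketch, assuming the same `torch.save` + base64 round trip used for `base64_image_embedding` (the helper name and the tensor values are illustrative placeholders, not part of this commit):

```python
import base64
import io

import torch


def tensor_to_base64(t: torch.Tensor) -> str:
    """Serialize a tensor with torch.save and return it as a base64 string."""
    buffer = io.BytesIO()
    torch.save(t, buffer)
    buffer.seek(0)
    return base64.b64encode(buffer.read()).decode("utf-8")


# Placeholder grid for one image whose visual tokens form a 16x16 grid (Qwen2-VL).
base64_image_grid_thw = tensor_to_base64(torch.tensor([[1, 16, 16]]))
# Placeholder original (height, width) of one image (MiniCPM-V).
base64_image_sizes = tensor_to_base64(torch.tensor([[448, 448]]))
```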

vllm/entrypoints/chat_utils.py

+103 -10
@@ -56,6 +56,17 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
     """The type of the content part."""
 
 
+class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
+    image_embeds: Required[Union[str, dict[str, str]]]
+    """
+    The image embeddings. It can be either:
+    - A single base64 string.
+    - A dictionary where each value is a base64 string.
+    """
+    type: Required[Literal["image_embeds"]]
+    """The type of the content part."""
+
+
 class VideoURL(TypedDict, total=False):
     url: Required[str]
     """
@@ -109,6 +120,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
     ChatCompletionContentPartInputAudioParam,
     ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
     CustomChatCompletionContentSimpleImageParam,
+    ChatCompletionContentPartImageEmbedsParam,
     CustomChatCompletionContentSimpleAudioParam,
     CustomChatCompletionContentSimpleVideoParam, str]
 
@@ -350,7 +362,7 @@ def resolve_chat_template_content_format(
     return detected_format
 
 
-ModalityStr = Literal["image", "audio", "video"]
+ModalityStr = Literal["image", "audio", "video", "image_embeds"]
 _T = TypeVar("_T")
 
 
@@ -391,7 +403,7 @@ def _placeholder_str(self, modality: ModalityStr,
         hf_config = self._model_config.hf_config
         model_type = hf_config.model_type
 
-        if modality == "image":
+        if modality in ["image", "image_embeds"]:
             if model_type == "phi3_v":
                 # Workaround since this token is not defined in the tokenizer
                 return f"<|image_{current_count}|>"
@@ -470,10 +482,27 @@ def create_parser(self) -> "BaseMultiModalContentParser":
 class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
 
     def all_mm_data(self) -> Optional[MultiModalDataDict]:
-        if self._items_by_modality:
-            return dict(self._items_by_modality)
-
-        return None
+        if not self._items_by_modality:
+            return None
+        mm_inputs = {}
+        items_by_modality = dict(self._items_by_modality)
+        if "image" in items_by_modality and "image_embeds" in items_by_modality:
+            raise ValueError(
+                "Mixing raw image and embedding inputs is not allowed")
+
+        if "image_embeds" in items_by_modality:
+            image_embeds_lst = items_by_modality["image_embeds"]
+            if len(image_embeds_lst) > 1:
+                raise ValueError(
+                    "Only one message can have {'type': 'image_embeds'}")
+            mm_inputs["image"] = image_embeds_lst[0]
+        elif "image" in items_by_modality:
+            mm_inputs["image"] = items_by_modality["image"]  # A list of images
+        elif "audio" in items_by_modality:
+            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
+        elif "video" in items_by_modality:
+            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
+        return mm_inputs
 
     def create_parser(self) -> "BaseMultiModalContentParser":
         return MultiModalContentParser(self)
@@ -482,13 +511,31 @@ def create_parser(self) -> "BaseMultiModalContentParser":
 class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
 
     async def all_mm_data(self) -> Optional[MultiModalDataDict]:
-        if self._items_by_modality:
-            return {
+        if not self._items_by_modality:
+            return None
+        mm_inputs = {}
+        items_by_modality = {
             modality: await asyncio.gather(*items)
             for modality, items in self._items_by_modality.items()
         }
 
-        return None
+        if "image" in items_by_modality and "image_embeds" in items_by_modality:
+            raise ValueError(
+                "Mixing raw image and embedding inputs is not allowed")
+
+        if "image_embeds" in items_by_modality:
+            image_embeds_lst = items_by_modality["image_embeds"]
+            if len(image_embeds_lst) > 1:
+                raise ValueError(
+                    "Only one message can have {'type': 'image_embeds'}")
+            mm_inputs["image"] = image_embeds_lst[0]
+        elif "image" in items_by_modality:
+            mm_inputs["image"] = items_by_modality["image"]  # A list of images
+        elif "audio" in items_by_modality:
+            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
+        elif "video" in items_by_modality:
+            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
+        return mm_inputs
 
     def create_parser(self) -> "BaseMultiModalContentParser":
         return AsyncMultiModalContentParser(self)
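
Both `all_mm_data` implementations above apply the same merging rules. A simplified standalone restatement of those rules, for illustration only (not vLLM code):

```python
def merge_mm_items(items_by_modality: dict[str, list]) -> dict[str, list]:
    """Mirror of the validation and merging performed by all_mm_data()."""
    if "image" in items_by_modality and "image_embeds" in items_by_modality:
        raise ValueError("Mixing raw image and embedding inputs is not allowed")

    mm_inputs = {}
    if "image_embeds" in items_by_modality:
        embeds = items_by_modality["image_embeds"]
        if len(embeds) > 1:
            raise ValueError(
                "Only one message can have {'type': 'image_embeds'}")
        # The single embeddings item is forwarded under the "image" key.
        mm_inputs["image"] = embeds[0]
    elif "image" in items_by_modality:
        mm_inputs["image"] = items_by_modality["image"]
    elif "audio" in items_by_modality:
        mm_inputs["audio"] = items_by_modality["audio"]
    elif "video" in items_by_modality:
        mm_inputs["video"] = items_by_modality["video"]
    return mm_inputs


assert merge_mm_items({"image_embeds": ["<tensor>"]}) == {"image": "<tensor>"}
```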
@@ -513,6 +560,11 @@ def mm_placeholder_counts(self) -> dict[str, int]:
     def parse_image(self, image_url: str) -> None:
         raise NotImplementedError
 
+    @abstractmethod
+    def parse_image_embeds(self,
+                           image_embeds: Union[str, dict[str, str]]) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def parse_audio(self, audio_url: str) -> None:
         raise NotImplementedError
@@ -543,6 +595,21 @@ def parse_image(self, image_url: str) -> None:
         placeholder = self._tracker.add("image", image)
         self._add_placeholder(placeholder)
 
+    def parse_image_embeds(self,
+                           image_embeds: Union[str, dict[str, str]]) -> None:
+        if isinstance(image_embeds, dict):
+            embeds = {
+                k: self._connector.fetch_image_embedding(v)
+                for k, v in image_embeds.items()
+            }
+            placeholder = self._tracker.add("image_embeds", embeds)
+
+        if isinstance(image_embeds, str):
+            embedding = self._connector.fetch_image_embedding(image_embeds)
+            placeholder = self._tracker.add("image_embeds", embedding)
+
+        self._add_placeholder(placeholder)
+
     def parse_audio(self, audio_url: str) -> None:
         audio = self._connector.fetch_audio(audio_url)
 
@@ -579,6 +646,25 @@ def parse_image(self, image_url: str) -> None:
         placeholder = self._tracker.add("image", image_coro)
         self._add_placeholder(placeholder)
 
+    def parse_image_embeds(self,
+                           image_embeds: Union[str, dict[str, str]]) -> None:
+        future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()
+
+        if isinstance(image_embeds, dict):
+            embeds = {
+                k: self._connector.fetch_image_embedding(v)
+                for k, v in image_embeds.items()
+            }
+            future.set_result(embeds)
+
+        if isinstance(image_embeds, str):
+            embedding = self._connector.fetch_image_embedding(
+                image_embeds)
+            future.set_result(embedding)
+
+        placeholder = self._tracker.add("image_embeds", future)
+        self._add_placeholder(placeholder)
+
     def parse_audio(self, audio_url: str) -> None:
         audio_coro = self._connector.fetch_audio_async(audio_url)
 
@@ -684,6 +770,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
 # No need to validate using Pydantic again
 _TextParser = partial(cast, ChatCompletionContentPartTextParam)
 _ImageParser = partial(cast, ChatCompletionContentPartImageParam)
+_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
 _AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
 _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
 _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
@@ -700,6 +787,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
     lambda part: _TextParser(part).get("text", ""),
     "image_url":
     lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
+    "image_embeds":
+    lambda part: _ImageEmbedsParser(part).get("image_embeds", {}),
     "audio_url":
     lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
     "input_audio":
@@ -769,6 +858,7 @@ def _parse_chat_message_content_mm_part(
 
 
 VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
+                                       "image_embeds",
                                        "audio_url", "input_audio", "video_url")
 
 
@@ -843,7 +933,10 @@ def _parse_chat_message_content_part(
         str_content = cast(str, content)
         mm_parser.parse_image(str_content)
         return {'type': 'image'} if wrap_dicts else None
-
+    if part_type == "image_embeds":
+        content = cast(Union[str, dict[str, str]], content)
+        mm_parser.parse_image_embeds(content)
+        return {'type': 'image'} if wrap_dicts else None
     if part_type == "audio_url":
         str_content = cast(str, content)
         mm_parser.parse_audio(str_content)

vllm/multimodal/image.py

+19
@@ -134,3 +134,22 @@ def encode_base64(
         data = buffer.getvalue()
 
         return base64.b64encode(data).decode('utf-8')
+
+
+class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def load_bytes(self, data: bytes) -> torch.Tensor:
+        buffer = BytesIO(data)
+        return torch.load(buffer, weights_only=True)
+
+    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
+        return self.load_bytes(base64.b64decode(data))
+
+    def load_file(self, filepath: Path) -> torch.Tensor:
+        return torch.load(filepath)
+
+    def encode_base64(self, media: torch.Tensor) -> str:
+        return base64.b64encode(media.numpy()).decode('utf-8')
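
To see how the client-side serialization from the docs change pairs with this class, here is a small round-trip sketch (the tensor shape is a placeholder; this is an illustration, not a test from the commit):

```python
import base64
import io

import torch

from vllm.multimodal.image import ImageEmbeddingMediaIO

# Client side: serialize an embeddings tensor exactly as the docs example does.
image_embedding = torch.randn(576, 4096)  # placeholder shape
buffer = io.BytesIO()
torch.save(image_embedding, buffer)
base64_image_embedding = base64.b64encode(buffer.getvalue()).decode("utf-8")

# Server side: decode the base64 payload back into a tensor.
io_handler = ImageEmbeddingMediaIO()
restored = io_handler.load_base64("", base64_image_embedding)
assert torch.equal(restored, image_embedding)
```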

vllm/multimodal/utils.py

+13 -1
@@ -7,6 +7,7 @@
 
 import numpy as np
 import numpy.typing as npt
+import torch
 from PIL import Image
 
 import vllm.envs as envs
@@ -16,7 +17,7 @@
 
 from .audio import AudioMediaIO
 from .base import MediaIO
-from .image import ImageMediaIO
+from .image import ImageEmbeddingMediaIO, ImageMediaIO
 from .inputs import PlaceholderRange
 from .video import VideoMediaIO
 
@@ -245,6 +246,17 @@ async def fetch_video_async(
             fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
         )
 
+    def fetch_image_embedding(
+        self,
+        data: str,
+    ) -> torch.Tensor:
+        """
+        Load an image embedding from a base64-encoded string.
+        """
+        image_embedding_io = ImageEmbeddingMediaIO()
+
+        return image_embedding_io.load_base64("", data)
+
 
 global_media_connector = MediaConnector()
 """The global :class:`MediaConnector` instance used by vLLM."""
