@@ -56,6 +56,17 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
56
56
"""The type of the content part."""
57
57
58
58
59
+ class ChatCompletionContentPartImageEmbedsParam (TypedDict , total = False ):
60
+ image_embeds : Required [Union [str , dict [str , str ]]]
61
+ """
62
+ The image embeddings. It can be either:
63
+ - A single base64 string.
64
+ - A dictionary where each value is a base64 string.
65
+ """
66
+ type : Required [Literal ["image_embeds" ]]
67
+ """The type of the content part."""
68
+
69
+
59
70
class VideoURL (TypedDict , total = False ):
60
71
url : Required [str ]
61
72
"""
@@ -109,6 +120,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
109
120
ChatCompletionContentPartInputAudioParam ,
110
121
ChatCompletionContentPartVideoParam , ChatCompletionContentPartRefusalParam ,
111
122
CustomChatCompletionContentSimpleImageParam ,
123
+ ChatCompletionContentPartImageEmbedsParam ,
112
124
CustomChatCompletionContentSimpleAudioParam ,
113
125
CustomChatCompletionContentSimpleVideoParam , str ]
114
126
@@ -350,7 +362,7 @@ def resolve_chat_template_content_format(
350
362
return detected_format
351
363
352
364
353
- ModalityStr = Literal ["image" , "audio" , "video" ]
365
+ ModalityStr = Literal ["image" , "audio" , "video" , "image_embeds" ]
354
366
_T = TypeVar ("_T" )
355
367
356
368
@@ -391,7 +403,7 @@ def _placeholder_str(self, modality: ModalityStr,
391
403
hf_config = self ._model_config .hf_config
392
404
model_type = hf_config .model_type
393
405
394
- if modality == "image" :
406
+ if modality in [ "image" , "image_embeds" ] :
395
407
if model_type == "phi3_v" :
396
408
# Workaround since this token is not defined in the tokenizer
397
409
return f"<|image_{ current_count } |>"
@@ -470,10 +482,27 @@ def create_parser(self) -> "BaseMultiModalContentParser":
470
482
class MultiModalItemTracker (BaseMultiModalItemTracker [object ]):
471
483
472
484
def all_mm_data (self ) -> Optional [MultiModalDataDict ]:
473
- if self ._items_by_modality :
474
- return dict (self ._items_by_modality )
475
-
476
- return None
485
+ if not self ._items_by_modality :
486
+ return None
487
+ mm_inputs = {}
488
+ items_by_modality = dict (self ._items_by_modality )
489
+ if "image" in items_by_modality and "image_embeds" in items_by_modality :
490
+ raise ValueError (\
491
+ "Mixing raw image and embedding inputs is not allowed" )
492
+
493
+ if "image_embeds" in items_by_modality :
494
+ image_embeds_lst = items_by_modality ["image_embeds" ]
495
+ if len (image_embeds_lst ) > 1 :
496
+ raise ValueError (\
497
+ "Only one message can have {'type': 'image_embeds'}" )
498
+ mm_inputs ["image" ] = image_embeds_lst [0 ]
499
+ elif "image" in items_by_modality :
500
+ mm_inputs ["image" ] = items_by_modality ["image" ] # A list of images
501
+ elif "audio" in items_by_modality :
502
+ mm_inputs ["audio" ] = items_by_modality ["audio" ] # A list of audios
503
+ elif "video" in items_by_modality :
504
+ mm_inputs ["video" ] = items_by_modality ["video" ] # A list of videos
505
+ return mm_inputs
477
506
478
507
def create_parser (self ) -> "BaseMultiModalContentParser" :
479
508
return MultiModalContentParser (self )
@@ -482,13 +511,31 @@ def create_parser(self) -> "BaseMultiModalContentParser":
482
511
class AsyncMultiModalItemTracker (BaseMultiModalItemTracker [Awaitable [object ]]):
483
512
484
513
async def all_mm_data (self ) -> Optional [MultiModalDataDict ]:
485
- if self ._items_by_modality :
486
- return {
514
+ if not self ._items_by_modality :
515
+ return None
516
+ mm_inputs = {}
517
+ items_by_modality = {
487
518
modality : await asyncio .gather (* items )
488
519
for modality , items in self ._items_by_modality .items ()
489
520
}
490
521
491
- return None
522
+ if "image" in items_by_modality and "image_embeds" in items_by_modality :
523
+ raise ValueError (
524
+ "Mixing raw image and embedding inputs is not allowed" )
525
+
526
+ if "image_embeds" in items_by_modality :
527
+ image_embeds_lst = items_by_modality ["image_embeds" ]
528
+ if len (image_embeds_lst ) > 1 :
529
+ raise ValueError (
530
+ "Only one message can have {'type': 'image_embeds'}" )
531
+ mm_inputs ["image" ] = image_embeds_lst [0 ]
532
+ elif "image" in items_by_modality :
533
+ mm_inputs ["image" ] = items_by_modality ["image" ] # A list of images
534
+ elif "audio" in items_by_modality :
535
+ mm_inputs ["audio" ] = items_by_modality ["audio" ] # A list of audios
536
+ elif "video" in items_by_modality :
537
+ mm_inputs ["video" ] = items_by_modality ["video" ] # A list of videos
538
+ return mm_inputs
492
539
493
540
def create_parser (self ) -> "BaseMultiModalContentParser" :
494
541
return AsyncMultiModalContentParser (self )
@@ -513,6 +560,11 @@ def mm_placeholder_counts(self) -> dict[str, int]:
513
560
def parse_image (self , image_url : str ) -> None :
514
561
raise NotImplementedError
515
562
563
+ @abstractmethod
564
+ def parse_image_embeds (self ,
565
+ image_embeds : Union [str , dict [str , str ]]) -> None :
566
+ raise NotImplementedError
567
+
516
568
@abstractmethod
517
569
def parse_audio (self , audio_url : str ) -> None :
518
570
raise NotImplementedError
@@ -543,6 +595,21 @@ def parse_image(self, image_url: str) -> None:
543
595
placeholder = self ._tracker .add ("image" , image )
544
596
self ._add_placeholder (placeholder )
545
597
598
+ def parse_image_embeds (self ,
599
+ image_embeds : Union [str , dict [str , str ]]) -> None :
600
+ if isinstance (image_embeds , dict ):
601
+ embeds = {
602
+ k : self ._connector .fetch_image_embedding (v )
603
+ for k , v in image_embeds .items ()
604
+ }
605
+ placeholder = self ._tracker .add ("image_embeds" , embeds )
606
+
607
+ if isinstance (image_embeds , str ):
608
+ embedding = self ._connector .fetch_image_embedding (image_embeds )
609
+ placeholder = self ._tracker .add ("image_embeds" , embedding )
610
+
611
+ self ._add_placeholder (placeholder )
612
+
546
613
def parse_audio (self , audio_url : str ) -> None :
547
614
audio = self ._connector .fetch_audio (audio_url )
548
615
@@ -579,6 +646,25 @@ def parse_image(self, image_url: str) -> None:
579
646
placeholder = self ._tracker .add ("image" , image_coro )
580
647
self ._add_placeholder (placeholder )
581
648
649
+ def parse_image_embeds (self ,
650
+ image_embeds : Union [str , dict [str , str ]]) -> None :
651
+ future : asyncio .Future [Union [str , dict [str , str ]]] = asyncio .Future ()
652
+
653
+ if isinstance (image_embeds , dict ):
654
+ embeds = {
655
+ k : self ._connector .fetch_image_embedding (v )
656
+ for k , v in image_embeds .items ()
657
+ }
658
+ future .set_result (embeds )
659
+
660
+ if isinstance (image_embeds , str ):
661
+ embedding = self ._connector .\
662
+ fetch_image_embedding (image_embeds )
663
+ future .set_result (embedding )
664
+
665
+ placeholder = self ._tracker .add ("image_embeds" , future )
666
+ self ._add_placeholder (placeholder )
667
+
582
668
def parse_audio (self , audio_url : str ) -> None :
583
669
audio_coro = self ._connector .fetch_audio_async (audio_url )
584
670
@@ -684,6 +770,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
684
770
# No need to validate using Pydantic again
685
771
_TextParser = partial (cast , ChatCompletionContentPartTextParam )
686
772
_ImageParser = partial (cast , ChatCompletionContentPartImageParam )
773
+ _ImageEmbedsParser = partial (cast , ChatCompletionContentPartImageEmbedsParam )
687
774
_AudioParser = partial (cast , ChatCompletionContentPartAudioParam )
688
775
_InputAudioParser = partial (cast , ChatCompletionContentPartInputAudioParam )
689
776
_RefusalParser = partial (cast , ChatCompletionContentPartRefusalParam )
@@ -700,6 +787,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
700
787
lambda part : _TextParser (part ).get ("text" , "" ),
701
788
"image_url" :
702
789
lambda part : _ImageParser (part ).get ("image_url" , {}).get ("url" , "" ),
790
+ "image_embeds" :
791
+ lambda part : _ImageEmbedsParser (part ).get ("image_embeds" , {}),
703
792
"audio_url" :
704
793
lambda part : _AudioParser (part ).get ("audio_url" , {}).get ("url" , "" ),
705
794
"input_audio" :
@@ -769,6 +858,7 @@ def _parse_chat_message_content_mm_part(
769
858
770
859
771
860
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text" , "refusal" , "image_url" ,
861
+ "image_embeds" ,
772
862
"audio_url" , "input_audio" , "video_url" )
773
863
774
864
@@ -843,7 +933,10 @@ def _parse_chat_message_content_part(
843
933
str_content = cast (str , content )
844
934
mm_parser .parse_image (str_content )
845
935
return {'type' : 'image' } if wrap_dicts else None
846
-
936
+ if part_type == "image_embeds" :
937
+ content = cast (Union [str , dict [str , str ]], content )
938
+ mm_parser .parse_image_embeds (content )
939
+ return {'type' : 'image' } if wrap_dicts else None
847
940
if part_type == "audio_url" :
848
941
str_content = cast (str , content )
849
942
mm_parser .parse_audio (str_content )
0 commit comments