File tree 1 file changed +8
-3
lines changed
src/diffusers/pipelines/wan
1 file changed +8
-3
lines changed Original file line number Diff line number Diff line change @@ -220,8 +220,13 @@ def _get_t5_prompt_embeds(
220
220
221
221
return prompt_embeds
222
222
223
- def encode_image (self , image : PipelineImageInput ):
224
- image = self .image_processor (images = image , return_tensors = "pt" ).to (self .device )
223
+ def encode_image (
224
+ self ,
225
+ image : PipelineImageInput ,
226
+ device : Optional [torch .device ] = None ,
227
+ ):
228
+ device = device or self ._execution_device
229
+ image = self .image_processor (images = image , return_tensors = "pt" ).to (device )
225
230
image_embeds = self .image_encoder (** image , output_hidden_states = True )
226
231
return image_embeds .hidden_states [- 2 ]
227
232
@@ -587,7 +592,7 @@ def __call__(
587
592
if negative_prompt_embeds is not None :
588
593
negative_prompt_embeds = negative_prompt_embeds .to (transformer_dtype )
589
594
590
- image_embeds = self .encode_image (image )
595
+ image_embeds = self .encode_image (image , device )
591
596
image_embeds = image_embeds .repeat (batch_size , 1 , 1 )
592
597
image_embeds = image_embeds .to (transformer_dtype )
593
598
You can’t perform that action at this time.
0 commit comments