Skip to content

Commit 5d970a4

Browse files
authored
WanI2V encode_image (#11164)
* WanI2V encode_image
1 parent de6a88c commit 5d970a4

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,13 @@ def _get_t5_prompt_embeds(
220220

221221
return prompt_embeds
222222

223-
def encode_image(self, image: PipelineImageInput):
224-
image = self.image_processor(images=image, return_tensors="pt").to(self.device)
223+
def encode_image(
224+
self,
225+
image: PipelineImageInput,
226+
device: Optional[torch.device] = None,
227+
):
228+
device = device or self._execution_device
229+
image = self.image_processor(images=image, return_tensors="pt").to(device)
225230
image_embeds = self.image_encoder(**image, output_hidden_states=True)
226231
return image_embeds.hidden_states[-2]
227232

@@ -587,7 +592,7 @@ def __call__(
587592
if negative_prompt_embeds is not None:
588593
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
589594

590-
image_embeds = self.encode_image(image)
595+
image_embeds = self.encode_image(image, device)
591596
image_embeds = image_embeds.repeat(batch_size, 1, 1)
592597
image_embeds = image_embeds.to(transformer_dtype)
593598

0 commit comments

Comments
 (0)