Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaoting committed Feb 22, 2025
1 parent 5b4c2f0 commit 2830d1d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ def diagonal_gaussian_distribution_sample(self, latent_dist: ms.Tensor) -> ms.Te

sample = ops.randn_like(mean, dtype=mean.dtype)
if self.enable_sequence_parallelism:
sample = self.broadcast(sample)
sample = self.broadcast((sample,))[0]
x = mean + std * sample

return x
Expand All @@ -824,14 +824,14 @@ def construct(self, videos, text_input_ids_or_prompt_embeds, image_rotary_emb=No
# Sample noise that will be added to the latents
noise = ops.randn_like(model_input, dtype=model_input.dtype)
if self.enable_sequence_parallelism:
noise = self.broadcast(noise)
noise = self.broadcast((noise,))[0]
batch_size, num_frames, num_channels, height, width = model_input.shape

# Sample a random timestep for each image
timesteps = ops.randint(0, self.scheduler_num_train_timesteps, (batch_size,), dtype=ms.int64)

if self.enable_sequence_parallelism:
timesteps = self.broadcast(timesteps)
timesteps = self.broadcast((timesteps,))[0]

# Rotary embeds is Prepared in dataset.
if self.use_rotary_positional_embeddings:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ def diagonal_gaussian_distribution_sample(self, latent_dist: ms.Tensor) -> ms.Te

sample = ops.randn_like(mean, dtype=mean.dtype)
if self.enable_sequence_parallelism:
sample = self.broadcast(sample)
sample = self.broadcast((sample,))[0]
x = mean + std * sample

return x
Expand All @@ -848,13 +848,13 @@ def construct(self, videos, text_input_ids_or_prompt_embeds, image_rotary_emb=No
# Sample noise that will be added to the latents
noise = ops.randn_like(model_input, dtype=model_input.dtype)
if self.enable_sequence_parallelism:
noise = self.broadcast(noise)
noise = self.broadcast((noise,))[0]
batch_size, num_frames, num_channels, height, width = model_input.shape

# Sample a random timestep for each image
timesteps = ops.randint(0, self.scheduler_num_train_timesteps, (batch_size,), dtype=ms.int64)
if self.enable_sequence_parallelism:
timesteps = self.broadcast(timesteps)
timesteps = self.broadcast((timesteps,))[0]

# Rotary embeds is Prepared in dataset.
if self.use_rotary_positional_embeddings:
Expand Down

0 comments on commit 2830d1d

Please sign in to comment.