Skip to content

Commit

Permalink
ready for lora pr
Browse files Browse the repository at this point in the history
  • Loading branch information
BrianChen1129 committed Jan 14, 2025
1 parent db92841 commit a44972b
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions fastvideo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,20 @@ def train_one_step(
dtype=latents.dtype,
)
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
training_guidance = torch.tensor([1000.0],
device=noisy_model_input.device,
dtype=torch.bfloat16)
with torch.autocast("cuda", torch.bfloat16):
model_pred = transformer(
noisy_model_input,
encoder_hidden_states,
timesteps,
encoder_attention_mask, # B, L
training_guidance,
return_dict=False,
)[0]
with torch.autocast("cuda", dtype=torch.bfloat16):
input_kwargs = {
"hidden_states": noisy_model_input,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timesteps,
"encoder_attention_mask": encoder_attention_mask, # B, L
"return_dict": False,
}
if 'hunyuan' in model_type:
input_kwargs["guidance"] = torch.tensor(
[1000.0],
device=noisy_model_input.device,
dtype=torch.bfloat16)
model_pred = transformer(**input_kwargs)[0]

if precondition_outputs:
model_pred = noisy_model_input - model_pred * sigmas
Expand Down

0 comments on commit a44972b

Please sign in to comment.