Precomputation folder name based on model name #196

Merged
merged 4 commits into main from precompute-folder-name on Jan 8, 2025

Conversation

a-r-r-o-w (Owner)

Currently, it is not possible to train two models on the same dataset with precomputation enabled, because both runs write to the same precomputation folder under data_root. This PR adds model_name to the folder name so that the name clash does not happen.

Tested on a smol run for LTX:

Script
#!/bin/bash

# export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
# export TORCHDYNAMO_VERBOSE=1
export WANDB_MODE="disabled"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0

GPU_IDS="2"

DATA_ROOT="/raid/aryan/video-dataset-disney"
CAPTION_COLUMN="prompts2.txt"
VIDEO_COLUMN="videos2.txt"

# Model arguments
model_cmd="--model_name ltx_video \
  --pretrained_model_name_or_path a-r-r-o-w/LTX-Video-diffusers"

# Dataset arguments
dataset_cmd="--data_root $DATA_ROOT \
  --video_column $VIDEO_COLUMN \
  --caption_column $CAPTION_COLUMN \
  --id_token afkx \
  --video_resolution_buckets 49x512x768 \
  --caption_dropout_p 0.00 \
  --precompute_conditions"

# Dataloader arguments
dataloader_cmd="--dataloader_num_workers 0"

# Diffusion arguments
diffusion_cmd="--flow_weighting_scheme logit_normal"

# Training arguments
training_cmd="--training_type lora \
  --mixed_precision bf16 \
  --seed 42 \
  --batch_size 1 \
  --train_steps 10 \
  --rank 128 \
  --lora_alpha 128 \
  --target_modules to_q to_k to_v to_out.0 \
  --gradient_accumulation_steps 1 \
  --gradient_checkpointing \
  --checkpointing_steps 500 \
  --checkpointing_limit 2 \
  --enable_slicing \
  --enable_tiling"

# Optimizer arguments
optimizer_cmd="--optimizer adamw \
  --lr 1e-4 \
  --lr_scheduler cosine_with_restarts \
  --lr_warmup_steps 100 \
  --lr_num_cycles 1 \
  --beta1 0.9 \
  --beta2 0.95 \
  --weight_decay 1e-4 \
  --epsilon 1e-8 \
  --max_grad_norm 1.0"

# Validation arguments
validation_cmd="--validation_prompts \"afkx A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions.@@@49x480x768:::A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage@@@49x480x768\" \
  --num_validation_videos 1 \
  --validation_steps 5"

# Miscellaneous arguments
miscellaneous_cmd="--tracker_name finetrainers-ltxv \
  --output_dir /raid/aryan/ltx-video \
  --nccl_timeout 1800 \
  --report_to wandb"

cmd="accelerate launch --config_file accelerate_configs/uncompiled_1.yaml --gpu_ids $GPU_IDS train.py \
  $model_cmd \
  $dataset_cmd \
  $dataloader_cmd \
  $diffusion_cmd \
  $training_cmd \
  $optimizer_cmd \
  $validation_cmd \
  $miscellaneous_cmd"

echo "Running command: $cmd"
eval $cmd
echo -ne "-------------------- Finished executing script --------------------\n\n"

a-r-r-o-w requested a review from sayakpaul on January 8, 2025 at 02:38
@@ -30,6 +29,8 @@
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from tqdm import tqdm

import wandb

a-r-r-o-w (Owner, Author)

Is there a way to make ruff ignore this? If I have a wandb folder in the root directory from some training run and then run make style, it just moves the import down here for some reason.
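A plausible explanation, not confirmed in this thread: with a wandb/ directory sitting in the repo root, ruff's isort-compatible sorter can classify wandb as a first-party module and regroup the import with the local ones. One possible fix is to pin the classification in ruff's isort settings, e.g. known-third-party = ["wandb"] under [tool.ruff.lint.isort] in pyproject.toml.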

@@ -252,8 +252,14 @@ def collate_fn(batch):
"Caption dropout is not supported with precomputation yet. This will be supported in the future."
)

conditions_dir = Path(self.args.data_root) / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME
latents_dir = Path(self.args.data_root) / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME
conditions_dir = (

sayakpaul (Collaborator)

(nits):

  1. If we use the same model type but a different checkpoint, it might still lead to inconsistencies (in case we have different dtypes for the VAE, etc.). So I would prefer to also append the checkpoint here besides model_name.
  2. I would assign Path(self.args.data_root) / f"{self.args.model_name}_{PRECOMPUTED_DIR_NAME}" to a variable and reuse it.

a-r-r-o-w (Owner, Author)

oh okay, yeah, makes sense, will update

a-r-r-o-w (Owner, Author)

We could maybe look into custom component-based naming in the future, because at the moment only loading ALL components from a single model_id is supported.

With the latest commit, the folder created looks something like: /raid/aryan/video-dataset-disney/ltx_video_a-r-r-o-w-LTX-Video-diffusers_precomputed/. Is this what you meant?

sayakpaul (Collaborator)

This works for me, thanks!
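
For reference, a minimal sketch of the final naming scheme. The helper and constant values below are hypothetical stand-ins (only the folder pattern ltx_video_a-r-r-o-w-LTX-Video-diffusers_precomputed above comes from this thread); the actual implementation is in this PR's diff.

from pathlib import Path

# Hypothetical stand-ins for the actual constants in finetrainers.
PRECOMPUTED_DIR_NAME = "precomputed"
PRECOMPUTED_CONDITIONS_DIR_NAME = "conditions"
PRECOMPUTED_LATENTS_DIR_NAME = "latents"

def precomputation_dirs(data_root, model_name, model_id):
    # Hub ids like "a-r-r-o-w/LTX-Video-diffusers" contain "/", which cannot
    # appear in a single directory name, so flatten it first.
    cleaned_model_id = model_id.replace("/", "-")
    # Build one per-model base path and reuse it for both subfolders (nit 2 above).
    precomputation_dir = Path(data_root) / f"{model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
    conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
    latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
    return conditions_dir, latents_dir

conditions_dir, latents_dir = precomputation_dirs(
    "/raid/aryan/video-dataset-disney", "ltx_video", "a-r-r-o-w/LTX-Video-diffusers"
)
# -> .../ltx_video_a-r-r-o-w-LTX-Video-diffusers_precomputed/conditions (and /latents)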

a-r-r-o-w merged commit dbffc80 into main on Jan 8, 2025
1 check passed
a-r-r-o-w deleted the precompute-folder-name branch on January 8, 2025 at 03:15