Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CogView4 ModelSpec #297

Merged
merged 6 commits into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions examples/training/sft/cogview4/raider_white_tarot/train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#!/bin/bash

set -e -x

# export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
# export TORCHDYNAMO_VERBOSE=1
export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export FINETRAINERS_LOG_LEVEL="DEBUG"

# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences!
# BACKEND="accelerate"
BACKEND="ptd"

# In this setting, I'm using 2 GPUs on a 4-GPU node for training
NUM_GPUS=2
CUDA_VISIBLE_DEVICES="2,3"

# Check the JSON files for the expected JSON format
TRAINING_DATASET_CONFIG="examples/training/sft/cogview4/raider_white_tarot/training.json"
VALIDATION_DATASET_FILE="examples/training/sft/cogview4/raider_white_tarot/validation.json"

# Depending on how many GPUs you have available, choose your degree of parallelism and technique!
DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1"
DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1"
DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1"
FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1"
FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1"
HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1"

# Parallel arguments
parallel_cmd=(
$DDP_2
)

# Model arguments
model_cmd=(
--model_name "cogview4"
--pretrained_model_name_or_path "THUDM/CogView4-6B"
)

# Dataset arguments
# Here, we know that the dataset size if about ~80 images. In `training.json`, we duplicate the same
# dataset 3 times for multi-resolution training. This gives us a total of about 240 images. Since
# we're using 2 GPUs for training, we can split the data into 120 images per GPU and precompute
# all embeddings at once, instead of doing it on-the-fly which would be slower (the ideal usecase
# of not using `--precomputation_once` is when you're training on large datasets)
dataset_cmd=(
--dataset_config $TRAINING_DATASET_CONFIG
--dataset_shuffle_buffer_size 32
--precomputation_items 120
--precomputation_once
)

# Dataloader arguments
dataloader_cmd=(
--dataloader_num_workers 0
)

# Diffusion arguments
diffusion_cmd=(
--flow_weighting_scheme "logit_normal"
)

# Training arguments
# We target just the attention projections layers for LoRA training here.
# You can modify as you please and target any layer (regex is supported)
training_cmd=(
--training_type "lora"
--seed 42
--batch_size 1
--train_steps 5000
--rank 32
--lora_alpha 32
--target_modules "transformer_blocks.*(to_q|to_k|to_v|to_out.0)"
--gradient_accumulation_steps 1
--gradient_checkpointing
--checkpointing_steps 1000
--checkpointing_limit 2
# --resume_from_checkpoint 3000
--enable_slicing
--enable_tiling
)

# Optimizer arguments
optimizer_cmd=(
--optimizer "adamw"
--lr 3e-5
--lr_scheduler "constant_with_warmup"
--lr_warmup_steps 1000
--lr_num_cycles 1
--beta1 0.9
--beta2 0.99
--weight_decay 1e-4
--epsilon 1e-8
--max_grad_norm 1.0
)

# Validation arguments
validation_cmd=(
--validation_dataset_file "$VALIDATION_DATASET_FILE"
--validation_steps 500
)

# Miscellaneous arguments
miscellaneous_cmd=(
--tracker_name "finetrainers-cogview4"
--output_dir "/raid/aryan/cogview4"
--init_timeout 600
--nccl_timeout 600
--report_to "wandb"
)

# Execute the training script
if [ "$BACKEND" == "accelerate" ]; then

ACCELERATE_CONFIG_FILE=""
if [ "$NUM_GPUS" == 1 ]; then
ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"
elif [ "$NUM_GPUS" == 2 ]; then
ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml"
elif [ "$NUM_GPUS" == 4 ]; then
ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml"
elif [ "$NUM_GPUS" == 8 ]; then
ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml"
fi

accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \
"${parallel_cmd[@]}" \
"${model_cmd[@]}" \
"${dataset_cmd[@]}" \
"${dataloader_cmd[@]}" \
"${diffusion_cmd[@]}" \
"${training_cmd[@]}" \
"${optimizer_cmd[@]}" \
"${validation_cmd[@]}" \
"${miscellaneous_cmd[@]}"

elif [ "$BACKEND" == "ptd" ]; then

export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES

torchrun \
--standalone \
--nnodes=1 \
--nproc_per_node=$NUM_GPUS \
--rdzv_backend c10d \
--rdzv_endpoint="localhost:0" \
train.py \
"${parallel_cmd[@]}" \
"${model_cmd[@]}" \
"${dataset_cmd[@]}" \
"${dataloader_cmd[@]}" \
"${diffusion_cmd[@]}" \
"${training_cmd[@]}" \
"${optimizer_cmd[@]}" \
"${validation_cmd[@]}" \
"${miscellaneous_cmd[@]}"
fi

echo -ne "-------------------- Finished executing script --------------------\n\n"
34 changes: 34 additions & 0 deletions examples/training/sft/cogview4/raider_white_tarot/training.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"datasets": [
{
"data_root": "multimodalart/1920-raider-waite-tarot-public-domain",
"dataset_type": "image",
"id_token": "TRTCRD",
"image_resolution_buckets": [
[1280, 720]
],
"reshape_mode": "bicubic",
"remove_common_llm_caption_prefixes": true
},
{
"data_root": "multimodalart/1920-raider-waite-tarot-public-domain",
"dataset_type": "image",
"id_token": "TRTCRD",
"image_resolution_buckets": [
[512, 512]
],
"reshape_mode": "center_crop",
"remove_common_llm_caption_prefixes": true
},
{
"data_root": "multimodalart/1920-raider-waite-tarot-public-domain",
"dataset_type": "image",
"id_token": "TRTCRD",
"image_resolution_buckets": [
[768, 768]
],
"reshape_mode": "center_crop",
"remove_common_llm_caption_prefixes": true
}
]
}
68 changes: 68 additions & 0 deletions examples/training/sft/cogview4/raider_white_tarot/validation.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
{
"data": [
{
"caption": "TRTCRD a trtcrd of a knight mounting a running horse wearing an armor and holding a staff, \"knight of wands\"",
"image_path": null,
"video_path": null,
"num_inference_steps": 50,
"height": 1280,
"width": 720
},
{
"caption": "TRTCRD a trtcrd of a woman sitting on a throne, wearing a crown and holding a trophee, \"queen of cups\"",
"image_path": null,
"video_path": null,
"num_inference_steps": 50,
"height": 1280,
"width": 720
},
{
"caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse",
"image_path": null,
"video_path": null,
"num_inference_steps": 50,
"height": 1280,
"width": 720
},
{
"caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles",
"image_path": null,
"video_path": null,
"num_inference_steps": 50,
"height": 1280,
"width": 720
},
{
"caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse",
"image_path": null,
"video_path": null,
"num_inference_steps": 50,
"height": 512,
"width": 512
},
{
"caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles",
"image_path": null,
"video_path": null,
"num_inference_steps": 50,
"height": 512,
"width": 512
},
{
"caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse",
"image_path": null,
"video_path": null,
"num_inference_steps": 50,
"height": 768,
"width": 768
},
{
"caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles",
"image_path": null,
"video_path": null,
"num_inference_steps": 50,
"height": 768,
"width": 768
}
]
}
14 changes: 10 additions & 4 deletions finetrainers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@

from .models import ModelSpecification
from .models.cogvideox import CogVideoXModelSpecification
from .models.cogview4 import CogView4ModelSpecification
from .models.hunyuan_video import HunyuanVideoModelSpecification
from .models.ltx_video import LTXVideoModelSpecification
from .models.wan import WanModelSpecification


class ModelType(str, Enum):
COGVIDEOX = "cogvideox"
COGVIEW4 = "cogview4"
HUNYUAN_VIDEO = "hunyuan_video"
LTX_VIDEO = "ltx_video"
WAN = "wan"
Expand All @@ -21,6 +23,14 @@ class TrainingType(str, Enum):


SUPPORTED_MODEL_CONFIGS = {
ModelType.COGVIDEOX: {
TrainingType.LORA: CogVideoXModelSpecification,
TrainingType.FULL_FINETUNE: CogVideoXModelSpecification,
},
ModelType.COGVIEW4: {
TrainingType.LORA: CogView4ModelSpecification,
TrainingType.FULL_FINETUNE: CogView4ModelSpecification,
},
ModelType.HUNYUAN_VIDEO: {
TrainingType.LORA: HunyuanVideoModelSpecification,
TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification,
Expand All @@ -29,10 +39,6 @@ class TrainingType(str, Enum):
TrainingType.LORA: LTXVideoModelSpecification,
TrainingType.FULL_FINETUNE: LTXVideoModelSpecification,
},
ModelType.COGVIDEOX: {
TrainingType.LORA: CogVideoXModelSpecification,
TrainingType.FULL_FINETUNE: CogVideoXModelSpecification,
},
ModelType.WAN: {
TrainingType.LORA: WanModelSpecification,
TrainingType.FULL_FINETUNE: WanModelSpecification,
Expand Down
26 changes: 25 additions & 1 deletion finetrainers/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,11 +600,17 @@ def __iter__(self):
for sample in iter(self.dataset):
if self.dataset_type == "image":
if self.image_resolution_buckets:
sample["_original_num_frames"] = 1
sample["_original_height"] = sample["image"].size(1)
sample["_original_width"] = sample["image"].size(2)
sample["image"] = FF.resize_to_nearest_bucket_image(
sample["image"], self.image_resolution_buckets, self.reshape_mode
)
elif self.dataset_type == "video":
if self.video_resolution_buckets:
sample["_original_num_frames"] = sample["video"].size(0)
sample["_original_height"] = sample["video"].size(2)
sample["_original_width"] = sample["video"].size(3)
sample["video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
sample["video"], self.video_resolution_buckets, self.reshape_mode
)
Expand Down Expand Up @@ -751,9 +757,27 @@ def _initialize_hub_dataset(dataset_name: str, dataset_type: str, infinite: bool
return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
else:

has_tar_files = any(file.endswith(".tar") for file in repo_file_list)
if has_tar_files:
return _initialize_webdataset(dataset_name, dataset_type, infinite)

# TODO(aryan): handle parquet
# TODO(aryan): This should be improved
caption_files = [pathlib.Path(file).name for file in repo_file_list if file.endswith(".txt")]
if len(caption_files) < MAX_PRECOMPUTABLE_ITEMS_LIMIT:
try:
dataset_root = snapshot_download(dataset_name, repo_type="dataset")
if dataset_type == "image":
dataset = ImageFolderDataset(dataset_root, infinite=infinite)
else:
dataset = VideoFolderDataset(dataset_root, infinite=infinite)
return dataset
except Exception:
pass

raise ValueError(f"Could not load dataset {dataset_name} from the HF Hub")


def _initialize_data_caption_file_dataset_from_hub(
dataset_name: str, dataset_type: str, infinite: bool = False
Expand Down
2 changes: 1 addition & 1 deletion finetrainers/functional/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def resize_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tenso


def bicubic_resize_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
return F.interpolate(image, size=size, mode="bicubic", align_corners=False)
return F.interpolate(image.unsqueeze(0), size=size, mode="bicubic", align_corners=False)[0]


def find_nearest_resolution_image(image: torch.Tensor, resolution_buckets: List[Tuple[int, int]]) -> Tuple[int, int]:
Expand Down
1 change: 1 addition & 0 deletions finetrainers/models/cogview4/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .base_specification import CogView4ModelSpecification
Loading