Skip to content

Commit

Permalink
CogView4 ModelSpec (#297)
Browse files Browse the repository at this point in the history
* update

* update

* update

* fix

* add comment explaining shifted_sigmas

* update
  • Loading branch information
a-r-r-o-w authored Mar 7, 2025
1 parent 7285ccc commit cad195e
Show file tree
Hide file tree
Showing 14 changed files with 840 additions and 10 deletions.
162 changes: 162 additions & 0 deletions examples/training/sft/cogview4/raider_white_tarot/train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#!/bin/bash

set -e -x

# export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
# export TORCHDYNAMO_VERBOSE=1
export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export FINETRAINERS_LOG_LEVEL="DEBUG"

# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences!
# BACKEND="accelerate"
BACKEND="ptd"

# In this setting, I'm using 2 GPUs on a 4-GPU node for training
NUM_GPUS=2
CUDA_VISIBLE_DEVICES="2,3"

# Check the JSON files for the expected JSON format
TRAINING_DATASET_CONFIG="examples/training/sft/cogview4/raider_white_tarot/training.json"
VALIDATION_DATASET_FILE="examples/training/sft/cogview4/raider_white_tarot/validation.json"

# Depending on how many GPUs you have available, choose your degree of parallelism and technique!
DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1"
DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1"
DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1"
FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1"
FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1"
HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1"

# Parallel arguments
parallel_cmd=(
$DDP_2
)

# Model arguments
model_cmd=(
--model_name "cogview4"
--pretrained_model_name_or_path "THUDM/CogView4-6B"
)

# Dataset arguments
# Here, we know that the dataset size if about ~80 images. In `training.json`, we duplicate the same
# dataset 3 times for multi-resolution training. This gives us a total of about 240 images. Since
# we're using 2 GPUs for training, we can split the data into 120 images per GPU and precompute
# all embeddings at once, instead of doing it on-the-fly which would be slower (the ideal usecase
# of not using `--precomputation_once` is when you're training on large datasets)
dataset_cmd=(
--dataset_config $TRAINING_DATASET_CONFIG
--dataset_shuffle_buffer_size 32
--precomputation_items 120
--precomputation_once
)

# Dataloader arguments
dataloader_cmd=(
--dataloader_num_workers 0
)

# Diffusion arguments
diffusion_cmd=(
--flow_weighting_scheme "logit_normal"
)

# Training arguments
# We target just the attention projections layers for LoRA training here.
# You can modify as you please and target any layer (regex is supported)
training_cmd=(
--training_type "lora"
--seed 42
--batch_size 1
--train_steps 5000
--rank 32
--lora_alpha 32
--target_modules "transformer_blocks.*(to_q|to_k|to_v|to_out.0)"
--gradient_accumulation_steps 1
--gradient_checkpointing
--checkpointing_steps 1000
--checkpointing_limit 2
# --resume_from_checkpoint 3000
--enable_slicing
--enable_tiling
)

# Optimizer arguments
optimizer_cmd=(
--optimizer "adamw"
--lr 3e-5
--lr_scheduler "constant_with_warmup"
--lr_warmup_steps 1000
--lr_num_cycles 1
--beta1 0.9
--beta2 0.99
--weight_decay 1e-4
--epsilon 1e-8
--max_grad_norm 1.0
)

# Validation arguments
validation_cmd=(
--validation_dataset_file "$VALIDATION_DATASET_FILE"
--validation_steps 500
)

# Miscellaneous arguments
miscellaneous_cmd=(
--tracker_name "finetrainers-cogview4"
--output_dir "/raid/aryan/cogview4"
--init_timeout 600
--nccl_timeout 600
--report_to "wandb"
)

# Execute the training script
if [ "$BACKEND" == "accelerate" ]; then

ACCELERATE_CONFIG_FILE=""
if [ "$NUM_GPUS" == 1 ]; then
ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"
elif [ "$NUM_GPUS" == 2 ]; then
ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml"
elif [ "$NUM_GPUS" == 4 ]; then
ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml"
elif [ "$NUM_GPUS" == 8 ]; then
ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml"
fi

accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \
"${parallel_cmd[@]}" \
"${model_cmd[@]}" \
"${dataset_cmd[@]}" \
"${dataloader_cmd[@]}" \
"${diffusion_cmd[@]}" \
"${training_cmd[@]}" \
"${optimizer_cmd[@]}" \
"${validation_cmd[@]}" \
"${miscellaneous_cmd[@]}"

elif [ "$BACKEND" == "ptd" ]; then

export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES

torchrun \
--standalone \
--nnodes=1 \
--nproc_per_node=$NUM_GPUS \
--rdzv_backend c10d \
--rdzv_endpoint="localhost:0" \
train.py \
"${parallel_cmd[@]}" \
"${model_cmd[@]}" \
"${dataset_cmd[@]}" \
"${dataloader_cmd[@]}" \
"${diffusion_cmd[@]}" \
"${training_cmd[@]}" \
"${optimizer_cmd[@]}" \
"${validation_cmd[@]}" \
"${miscellaneous_cmd[@]}"
fi

echo -ne "-------------------- Finished executing script --------------------\n\n"
34 changes: 34 additions & 0 deletions examples/training/sft/cogview4/raider_white_tarot/training.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"datasets": [
{
"data_root": "multimodalart/1920-raider-waite-tarot-public-domain",
"dataset_type": "image",
"id_token": "TRTCRD",
"image_resolution_buckets": [
[1280, 720]
],
"reshape_mode": "bicubic",
"remove_common_llm_caption_prefixes": true
},
{
"data_root": "multimodalart/1920-raider-waite-tarot-public-domain",
"dataset_type": "image",
"id_token": "TRTCRD",
"image_resolution_buckets": [
[512, 512]
],
"reshape_mode": "center_crop",
"remove_common_llm_caption_prefixes": true
},
{
"data_root": "multimodalart/1920-raider-waite-tarot-public-domain",
"dataset_type": "image",
"id_token": "TRTCRD",
"image_resolution_buckets": [
[768, 768]
],
"reshape_mode": "center_crop",
"remove_common_llm_caption_prefixes": true
}
]
}
68 changes: 68 additions & 0 deletions examples/training/sft/cogview4/raider_white_tarot/validation.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
{
"data": [
{
"caption": "TRTCRD a trtcrd of a knight mounting a running horse wearing an armor and holding a staff, \"knight of wands\"",
"image_path": null,
"video_path": null,
"num_inference_steps": 50,
"height": 1280,
"width": 720
},
{
"caption": "TRTCRD a trtcrd of a woman sitting on a throne, wearing a crown and holding a trophee, \"queen of cups\"",
"image_path": null,
"video_path": null,
"num_inference_steps": 50,
"height": 1280,
"width": 720
},
{
"caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse",
"image_path": null,
"video_path": null,
"num_inference_steps": 50,
"height": 1280,
"width": 720
},
{
"caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles",
"image_path": null,
"video_path": null,
"num_inference_steps": 50,
"height": 1280,
"width": 720
},
{
"caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse",
"image_path": null,
"video_path": null,
"num_inference_steps": 50,
"height": 512,
"width": 512
},
{
"caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles",
"image_path": null,
"video_path": null,
"num_inference_steps": 50,
"height": 512,
"width": 512
},
{
"caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse",
"image_path": null,
"video_path": null,
"num_inference_steps": 50,
"height": 768,
"width": 768
},
{
"caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles",
"image_path": null,
"video_path": null,
"num_inference_steps": 50,
"height": 768,
"width": 768
}
]
}
14 changes: 10 additions & 4 deletions finetrainers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@

from .models import ModelSpecification
from .models.cogvideox import CogVideoXModelSpecification
from .models.cogview4 import CogView4ModelSpecification
from .models.hunyuan_video import HunyuanVideoModelSpecification
from .models.ltx_video import LTXVideoModelSpecification
from .models.wan import WanModelSpecification


class ModelType(str, Enum):
COGVIDEOX = "cogvideox"
COGVIEW4 = "cogview4"
HUNYUAN_VIDEO = "hunyuan_video"
LTX_VIDEO = "ltx_video"
WAN = "wan"
Expand All @@ -21,6 +23,14 @@ class TrainingType(str, Enum):


SUPPORTED_MODEL_CONFIGS = {
ModelType.COGVIDEOX: {
TrainingType.LORA: CogVideoXModelSpecification,
TrainingType.FULL_FINETUNE: CogVideoXModelSpecification,
},
ModelType.COGVIEW4: {
TrainingType.LORA: CogView4ModelSpecification,
TrainingType.FULL_FINETUNE: CogView4ModelSpecification,
},
ModelType.HUNYUAN_VIDEO: {
TrainingType.LORA: HunyuanVideoModelSpecification,
TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification,
Expand All @@ -29,10 +39,6 @@ class TrainingType(str, Enum):
TrainingType.LORA: LTXVideoModelSpecification,
TrainingType.FULL_FINETUNE: LTXVideoModelSpecification,
},
ModelType.COGVIDEOX: {
TrainingType.LORA: CogVideoXModelSpecification,
TrainingType.FULL_FINETUNE: CogVideoXModelSpecification,
},
ModelType.WAN: {
TrainingType.LORA: WanModelSpecification,
TrainingType.FULL_FINETUNE: WanModelSpecification,
Expand Down
26 changes: 25 additions & 1 deletion finetrainers/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,11 +600,17 @@ def __iter__(self):
for sample in iter(self.dataset):
if self.dataset_type == "image":
if self.image_resolution_buckets:
sample["_original_num_frames"] = 1
sample["_original_height"] = sample["image"].size(1)
sample["_original_width"] = sample["image"].size(2)
sample["image"] = FF.resize_to_nearest_bucket_image(
sample["image"], self.image_resolution_buckets, self.reshape_mode
)
elif self.dataset_type == "video":
if self.video_resolution_buckets:
sample["_original_num_frames"] = sample["video"].size(0)
sample["_original_height"] = sample["video"].size(2)
sample["_original_width"] = sample["video"].size(3)
sample["video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
sample["video"], self.video_resolution_buckets, self.reshape_mode
)
Expand Down Expand Up @@ -751,9 +757,27 @@ def _initialize_hub_dataset(dataset_name: str, dataset_type: str, infinite: bool
return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
else:

has_tar_files = any(file.endswith(".tar") for file in repo_file_list)
if has_tar_files:
return _initialize_webdataset(dataset_name, dataset_type, infinite)

# TODO(aryan): handle parquet
# TODO(aryan): This should be improved
caption_files = [pathlib.Path(file).name for file in repo_file_list if file.endswith(".txt")]
if len(caption_files) < MAX_PRECOMPUTABLE_ITEMS_LIMIT:
try:
dataset_root = snapshot_download(dataset_name, repo_type="dataset")
if dataset_type == "image":
dataset = ImageFolderDataset(dataset_root, infinite=infinite)
else:
dataset = VideoFolderDataset(dataset_root, infinite=infinite)
return dataset
except Exception:
pass

raise ValueError(f"Could not load dataset {dataset_name} from the HF Hub")


def _initialize_data_caption_file_dataset_from_hub(
dataset_name: str, dataset_type: str, infinite: bool = False
Expand Down
2 changes: 1 addition & 1 deletion finetrainers/functional/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def resize_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tenso


def bicubic_resize_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
return F.interpolate(image, size=size, mode="bicubic", align_corners=False)
return F.interpolate(image.unsqueeze(0), size=size, mode="bicubic", align_corners=False)[0]


def find_nearest_resolution_image(image: torch.Tensor, resolution_buckets: List[Tuple[int, int]]) -> Tuple[int, int]:
Expand Down
1 change: 1 addition & 0 deletions finetrainers/models/cogview4/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .base_specification import CogView4ModelSpecification
Loading

0 comments on commit cad195e

Please sign in to comment.