
Commit

fix test
a-r-r-o-w committed Mar 3, 2025
1 parent da0d1eb commit 1155d0c
Showing 4 changed files with 37 additions and 11 deletions.
9 changes: 8 additions & 1 deletion finetrainers/data/dataset.py
@@ -6,6 +6,8 @@
 import datasets.data_files
 import datasets.distributed
 import datasets.exceptions
+import huggingface_hub
+import huggingface_hub.errors
 import numpy as np
 import PIL.Image
 import torch
@@ -692,7 +694,12 @@ def initialize_dataset(
     # 4. If there is a dataset name, we use the ImageWebDataset or VideoWebDataset class.
     assert dataset_type in ["image", "video"]

-    if repo_exists(dataset_name_or_root, repo_type="dataset"):
+    try:
+        does_repo_exist_on_hub = repo_exists(dataset_name_or_root, repo_type="dataset")
+    except huggingface_hub.errors.HFValidationError:
+        does_repo_exist_on_hub = False
+
+    if does_repo_exist_on_hub:
         return _initialize_hub_dataset(dataset_name_or_root, dataset_type, infinite)
     else:
         return _initialize_local_dataset(dataset_name_or_root, dataset_type, infinite)
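The try/except matters because `repo_exists` validates the repo id format before querying the Hub, so a local filesystem path (which is a valid input for this function, meaning "use a local dataset") raises `HFValidationError` rather than returning `False`. That is standard `huggingface_hub` behaviour, not something introduced by this commit. A minimal sketch of the guarded check, with a placeholder path:

```python
# Sketch of the behaviour the try/except accounts for (path is a placeholder).
import huggingface_hub.errors
from huggingface_hub import repo_exists


def looks_like_hub_dataset(name_or_path: str) -> bool:
    try:
        return repo_exists(name_or_path, repo_type="dataset")
    except huggingface_hub.errors.HFValidationError:
        # e.g. "/home/user/datasets/my-videos" is not a valid "namespace/name"
        # repo id, so treat it as a local dataset root instead of a Hub repo.
        return False


print(looks_like_hub_dataset("/home/user/datasets/my-videos"))  # False, no exception
```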
2 changes: 1 addition & 1 deletion finetrainers/utils/state_checkpoint.py
@@ -145,7 +145,7 @@ def _find_latest_checkpoint_dir(self) -> Union[pathlib.Path, None]:
         return checkpoints[-1] if len(checkpoints) > 0 else None

     def _purge_stale_checkpoints(self) -> None:
-        if self.checkpointing_limit <= 0:
+        if self.checkpointing_limit is None or self.checkpointing_limit <= 0:
             return
         checkpoints = sorted(
             self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True
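The added `is None` check covers the case where no checkpointing limit is configured (assumed here to be represented by `None`); without it, `None <= 0` raises a `TypeError` in Python 3. A minimal sketch of the guarded condition:

```python
# Minimal sketch, assuming checkpointing_limit is None when no limit is set.
from typing import Optional


def purge_is_disabled(checkpointing_limit: Optional[int]) -> bool:
    # The `is None` check must come first: evaluating `None <= 0` raises
    # "TypeError: '<=' not supported between instances of 'NoneType' and 'int'".
    return checkpointing_limit is None or checkpointing_limit <= 0


assert purge_is_disabled(None)   # no limit configured -> never purge
assert purge_is_disabled(0)      # non-positive limit -> never purge
assert not purge_is_disabled(3)  # keep only the 3 most recent checkpoints
```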
16 changes: 8 additions & 8 deletions tests/README.md
@@ -6,16 +6,16 @@ TODO(aryan): everything here needs to be improved.

 ```
 # world_size=1 tests
-torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_training.py -k test___dp_degree_1___batch_size_1
-torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_training.py -k test___dp_degree_1___batch_size_2
+torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_1___batch_size_1
+torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_1___batch_size_2
 # world_size=2 tests
-torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_training.py -k test___dp_degree_2___batch_size_1
-torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_training.py -k test___dp_degree_2___batch_size_2
-torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_training.py -k test___dp_shards_2___batch_size_1
-torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_training.py -k test___dp_shards_2___batch_size_2
-torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_training.py -k test___tp_degree_2___batch_size_2
+torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_2___batch_size_1
+torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_2___batch_size_2
+torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_shards_2___batch_size_1
+torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_shards_2___batch_size_2
+torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___tp_degree_2___batch_size_2
 # world_size=4 tests
-torchrun --nnodes=1 --nproc_per_node 4 -m pytest -s tests/trainer/test_sft_training.py -k test___dp_degree_2___dp_shards_2___batch_size_1
+torchrun --nnodes=1 --nproc_per_node 4 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_2___dp_shards_2___batch_size_1
 ```
21 changes: 20 additions & 1 deletion tests/trainer/test_sft_trainer.py
@@ -1,5 +1,6 @@
 # torchrun --nnodes=1 --nproc_per_node=1 -m pytest -s tests/trainer/test_sft_trainer.py

+import json
 import os
 import pathlib
 import sys
@@ -17,6 +18,7 @@
 sys.path.append(str(project_root))

 from finetrainers import BaseArgs, SFTTrainer, TrainingType, get_logger  # noqa
+from finetrainers.trainer.sft_trainer.config import SFTLowRankConfig, SFTFullRankConfig  # noqa

 from ..models.dummy.base_specification import DummyLTXVideoModelSpecification  # noqa

@@ -47,12 +49,28 @@ def setUp(self):
                 prompt = f"A cat ruling the world - {i}"
                 f.write(f'{i}.mp4,"{prompt}"\n')

+        dataset_config = {
+            "datasets": [
+                {
+                    "data_root": self.tmpdir.name,
+                    "dataset_type": "video",
+                    "id_token": "TEST",
+                    "video_resolution_buckets": [[self.num_frames, self.height, self.width]],
+                    "reshape_mode": "bicubic",
+                }
+            ]
+        }
+
+        self.dataset_config_filename = pathlib.Path(self.tmpdir.name) / "dataset_config.json"
+        with open(self.dataset_config_filename.as_posix(), "w") as f:
+            json.dump(dataset_config, f)
+
     def tearDown(self):
         self.tmpdir.cleanup()

     def get_base_args(self) -> BaseArgs:
         args = BaseArgs()
-        args.data_root = self.tmpdir.name
+        args.dataset_config = self.dataset_config_filename.as_posix()
         args.train_steps = 10
         args.batch_size = 1
         args.gradient_checkpointing = True
@@ -76,6 +94,7 @@ def get_args(self) -> BaseArgs:
         args.parallel_backend = "ptd"
         args.training_type = TrainingType.LORA
         args.rank = 4
+        args.lora_alpha = 4
         args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
         return args
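For reference, a standalone sketch of what the updated setUp now does: it serializes the same dataset config structure to a JSON file, which the test then hands to the trainer via `args.dataset_config` instead of the removed `args.data_root`. The path and resolution-bucket values below are placeholders, not the test's actual ones:

```python
# Illustrative standalone version of the new setUp logic (placeholder values).
import json
import pathlib
import tempfile

tmpdir = tempfile.mkdtemp()
num_frames, height, width = 9, 64, 64  # placeholder resolution bucket

dataset_config = {
    "datasets": [
        {
            "data_root": tmpdir,
            "dataset_type": "video",
            "id_token": "TEST",
            "video_resolution_buckets": [[num_frames, height, width]],
            "reshape_mode": "bicubic",
        }
    ]
}

# Write the config next to the data, then point the trainer at this file.
config_path = pathlib.Path(tmpdir) / "dataset_config.json"
with open(config_path, "w") as f:
    json.dump(dataset_config, f, indent=2)

print(config_path.read_text())
```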
