Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
a-r-r-o-w committed Mar 6, 2025
1 parent a803b4f commit 7d0795c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion finetrainers/models/cogview4/base_specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def validation(
"output_type": "pil",
}
generation_kwargs = get_non_null_items(generation_kwargs)
image = pipeline(**generation_kwargs).frames[0]
image = pipeline(**generation_kwargs).images[0]
return [data.ImageArtifact(value=image)]

def _save_lora_weights(
Expand Down
10 changes: 6 additions & 4 deletions finetrainers/trainer/sft_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import math
import os
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union

Expand Down Expand Up @@ -651,28 +652,29 @@ def _validate(self, step: int, final_validation: bool = False) -> None:
# TODO(aryan): Currently, we only support WandB so we've hardcoded it here. Needs to be revisited.
for index, (key, artifact) in enumerate(list(artifacts.items())):
assert isinstance(artifact, (data.ImageArtifact, data.VideoArtifact))

time_, rank, ext = time.time(), parallel_backend.rank, artifact.file_extension
filename = "validation-" if not final_validation else "final-"
filename += f"{step}-{parallel_backend.rank}-{index}-{prompt_filename}.{artifact.file_extension}"
filename += f"{step}-{rank}-{index}-{prompt_filename}-{time_}.{ext}"
output_filename = os.path.join(self.args.output_dir, filename)

if parallel_backend.is_main_process and artifact.file_extension == "mp4":
main_process_prompts_to_filenames[PROMPT] = filename

caption = f"{PROMPT} | (filename: {output_filename})"
if artifact.type == "image" and artifact.value is not None:
logger.debug(
f"Saving image from rank={parallel_backend.rank} to {output_filename}",
local_main_process_only=False,
)
artifact.value.save(output_filename)
all_processes_artifacts.append(wandb.Image(output_filename, caption=caption))
all_processes_artifacts.append(wandb.Image(output_filename, caption=PROMPT))
elif artifact.type == "video" and artifact.value is not None:
logger.debug(
f"Saving video from rank={parallel_backend.rank} to {output_filename}",
local_main_process_only=False,
)
export_to_video(artifact.value, output_filename, fps=EXPORT_FPS)
all_processes_artifacts.append(wandb.Video(output_filename, caption=caption))
all_processes_artifacts.append(wandb.Video(output_filename, caption=PROMPT))

# 3. Cleanup & log artifacts
parallel_backend.wait_for_everyone()
Expand Down

0 comments on commit 7d0795c

Please sign in to comment.