From 6cda8f76c9887461fa13980df853f4f6d46d38b7 Mon Sep 17 00:00:00 2001 From: townwish4git Date: Fri, 14 Feb 2025 14:49:35 +0800 Subject: [PATCH] fix(tensorboradx): fix invalid `write_to_disk` for `add_scalars` --- .../training/cogvideox_text_to_video_lora.py | 7 ++++--- .../training/cogvideox_text_to_video_sft.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/diffusers/cogvideox_factory/training/cogvideox_text_to_video_lora.py b/examples/diffusers/cogvideox_factory/training/cogvideox_text_to_video_lora.py index 4a54e019e8..5479cf0069 100644 --- a/examples/diffusers/cogvideox_factory/training/cogvideox_text_to_video_lora.py +++ b/examples/diffusers/cogvideox_factory/training/cogvideox_text_to_video_lora.py @@ -575,9 +575,10 @@ def optimizer_state_filter(param_name: str): logs = {"loss": loss.item(), "lr": last_lr.item()} progress_bar.set_postfix(**logs) - for tracker_name, tracker in trackers.items(): - if tracker_name == "tensorboard": - tracker.add_scalars("train", logs, global_step) + if is_master(args): + for tracker_name, tracker in trackers.items(): + if tracker_name == "tensorboard": + tracker.add_scalars("train", logs, global_step) if global_step >= args.max_train_steps: break diff --git a/examples/diffusers/cogvideox_factory/training/cogvideox_text_to_video_sft.py b/examples/diffusers/cogvideox_factory/training/cogvideox_text_to_video_sft.py index 3ae43fed3a..d7b64ee80a 100644 --- a/examples/diffusers/cogvideox_factory/training/cogvideox_text_to_video_sft.py +++ b/examples/diffusers/cogvideox_factory/training/cogvideox_text_to_video_sft.py @@ -554,9 +554,10 @@ def optimizer_state_filter(param_name: str): logs = {"loss": loss.item(), "lr": last_lr.item()} progress_bar.set_postfix(**logs) - for tracker_name, tracker in trackers.items(): - if tracker_name == "tensorboard": - tracker.add_scalars("train", logs, global_step) + if is_master(args): + for tracker_name, tracker in trackers.items(): + if tracker_name == "tensorboard": + tracker.add_scalars("train", logs, global_step) if global_step >= args.max_train_steps: break