From f1a131f2211d04eaa257f32607a45e79af261baf Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 29 Jan 2025 14:14:02 +0530 Subject: [PATCH] fix: model card info. --- finetrainers/trainer.py | 7 +++++++ finetrainers/utils/hub_utils.py | 15 +++++++++++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py index 153e4069..56c4866d 100644 --- a/finetrainers/trainer.py +++ b/finetrainers/trainer.py @@ -863,6 +863,13 @@ def validate(self, step: int, final_validation: bool = False) -> None: if num_validation_samples == 0: logger.warning("No validation samples found. Skipping validation.") + if accelerator.is_main_process: + save_model_card( + args=self.args, + repo_id=self.state.repo_id, + videos=None, + validation_prompts=None, + ) return self.transformer.eval() diff --git a/finetrainers/utils/hub_utils.py b/finetrainers/utils/hub_utils.py index ea1a16eb..ef865407 100644 --- a/finetrainers/utils/hub_utils.py +++ b/finetrainers/utils/hub_utils.py @@ -28,17 +28,20 @@ def save_model_card( } ) + training_type = "Full" if args.training_type == "full-finetune" else "LoRA" model_description = f""" -# LoRA Finetune +# {training_type} Finetune ## Model description -This is a lora finetune of model: `{args.pretrained_model_name_or_path}`. +This is a {training_type.lower()} finetune of model: `{args.pretrained_model_name_or_path}`. The model was trained using [`finetrainers`](https://github.com/a-r-r-o-w/finetrainers). +`id_token` used: {args.id_token} (if it's not `None`, it should be used in the prompts.) + ## Download model [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. @@ -53,7 +56,7 @@ def save_model_card( For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. """ - if wandb.run.url: + if wandb.run and wandb.run.url: model_description += f""" Find out the wandb run URL and training configurations [here]({wandb.run.url}). """ @@ -69,9 +72,13 @@ def save_model_card( "text-to-video", "diffusers-training", "diffusers", - "lora", + "finetrainers", "template:sd-lora", ] + if training_type == "Full": + tags.append("full-finetune") + else: + tags.append("lora") model_card = populate_model_card(model_card, tags=tags) model_card.save(os.path.join(args.output_dir, "README.md"))