fix for RuntimeError: Type not yet supported: typing.Literal['sgd', 'adamw', 'adam'] #63

Merged 2 commits on Sep 24, 2024
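For context: Typer builds the CLI by converting each parameter annotation into a click parameter type, and the release in use when this was filed had no conversion for typing.Literal, so registering train() with Literal-annotated options fails with the RuntimeError quoted in the title. A minimal sketch of the failure mode (the toy app below is illustrative, not part of this repository):

```python
# Hypothetical minimal reproduction, assuming a typer version without Literal
# support: assembling the command raises
# "RuntimeError: Type not yet supported: typing.Literal['sgd', 'adamw', 'adam']".
from typing import Annotated, Literal

import typer

app = typer.Typer()


@app.command()
def train(
    optimizer: Annotated[
        Literal["sgd", "adamw", "adam"],  # annotation typer cannot convert
        typer.Option("--optimizer", help="Optimizer to use for training"),
    ] = "adamw",
) -> None:
    print(f"Training with optimizer={optimizer}")


if __name__ == "__main__":
    app()  # the RuntimeError is raised while typer builds the click command
```

The patch therefore loosens the CLI annotations to plain str and adds `# type: ignore` on the corresponding Configuration arguments, presumably because Configuration still declares the narrower Literal types.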
maestro/trainer/models/florence_2/entrypoint.py (20 changes: 10 additions & 10 deletions)
@@ -1,5 +1,5 @@
 import dataclasses
-from typing import Annotated, Literal, Optional, Union
+from typing import Annotated, Optional

 import rich
 import torch
@@ -16,7 +16,7 @@
     DEFAULT_FLORENCE2_MODEL_REVISION,
     DEVICE,
 )
-from maestro.trainer.models.florence_2.core import Configuration, LoraInitLiteral
+from maestro.trainer.models.florence_2.core import Configuration
 from maestro.trainer.models.florence_2.core import evaluate as florence2_evaluate
 from maestro.trainer.models.florence_2.core import train as florence2_train

@@ -70,15 +70,15 @@ def train(
         typer.Option("--epochs", help="Number of training epochs"),
     ] = 10,
     optimizer: Annotated[
-        Literal["sgd", "adamw", "adam"],
+        str,
         typer.Option("--optimizer", help="Optimizer to use for training"),
     ] = "adamw",
     lr: Annotated[
         float,
         typer.Option("--lr", help="Learning rate for the optimizer"),
     ] = 1e-5,
     lr_scheduler: Annotated[
-        Literal["linear", "cosine", "polynomial"],
+        str,
         typer.Option("--lr_scheduler", help="Learning rate scheduler"),
     ] = "linear",
     batch_size: Annotated[
@@ -110,15 +110,15 @@ def train(
         typer.Option("--lora_dropout", help="Dropout probability for LoRA layers"),
     ] = 0.05,
     bias: Annotated[
-        Literal["none", "all", "lora_only"],
+        str,
         typer.Option("--bias", help="Which bias to train"),
     ] = "none",
     use_rslora: Annotated[
         bool,
         typer.Option("--use_rslora/--no_use_rslora", help="Whether to use RSLoRA"),
     ] = True,
     init_lora_weights: Annotated[
-        Union[bool, LoraInitLiteral],
+        str,
         typer.Option("--init_lora_weights", help="How to initialize LoRA weights"),
     ] = "gaussian",
     output_dir: Annotated[
@@ -138,19 +138,19 @@
         device=torch.device(device),
         cache_dir=cache_dir,
         epochs=epochs,
-        optimizer=optimizer,
+        optimizer=optimizer,  # type: ignore
         lr=lr,
-        lr_scheduler=lr_scheduler,
+        lr_scheduler=lr_scheduler,  # type: ignore
         batch_size=batch_size,
         val_batch_size=val_batch_size,
         num_workers=num_workers,
         val_num_workers=val_num_workers,
         lora_r=lora_r,
         lora_alpha=lora_alpha,
         lora_dropout=lora_dropout,
-        bias=bias,
+        bias=bias,  # type: ignore
         use_rslora=use_rslora,
-        init_lora_weights=init_lora_weights,
+        init_lora_weights=init_lora_weights,  # type: ignore
         output_dir=output_dir,
         metrics=metric_objects,
     )
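Because the relaxed str annotations no longer restrict the accepted values at parse time, invalid strings now reach Configuration unchecked. A small runtime guard could restore the old behaviour; the sketch below is illustrative only (the helper name and allowed-value sets are assumptions, not part of this PR):

```python
# Illustrative runtime validation for the now-plain `str` CLI options.
# typer.BadParameter renders as the usual "Invalid value ..." CLI error.
import typer

ALLOWED_OPTIMIZERS = {"sgd", "adamw", "adam"}
ALLOWED_SCHEDULERS = {"linear", "cosine", "polynomial"}


def require_choice(value: str, allowed: set[str], option: str) -> str:
    """Return value unchanged if it is an allowed choice, else raise a CLI error."""
    if value not in allowed:
        raise typer.BadParameter(f"{option} must be one of {sorted(allowed)}, got {value!r}")
    return value


# Example use near the top of train(), before Configuration is built:
# optimizer = require_choice(optimizer, ALLOWED_OPTIMIZERS, "--optimizer")
# lr_scheduler = require_choice(lr_scheduler, ALLOWED_SCHEDULERS, "--lr_scheduler")
```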