Skip to content

Commit

Permalink
Merge pull request #173 from roboflow/feature/qwen_2_5_vl_object_dete…
Browse files Browse the repository at this point in the history
…ction_support

Qwen2.5-VL object detection fine-tuning support
  • Loading branch information
SkalskiP authored Feb 25, 2025
2 parents 1e408d3 + ef51c20 commit f130403
Show file tree
Hide file tree
Showing 10 changed files with 161 additions and 31 deletions.
4 changes: 2 additions & 2 deletions maestro/trainer/common/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ class COCODataset(Dataset, BaseDetectionDataset):
login()
dataset = download_dataset("universe.roboflow.com/huyifei/tft-id/1", "coco")
ds = COOCDataset(
annotations_path=f"{dataset.location}/test/_annotations.jsonl",
ds = COCODataset(
annotations_path=f"{dataset.location}/test/_annotations.coco.json",
images_directory_path=f"{dataset.location}/test"
)
len(ds)
Expand Down
2 changes: 1 addition & 1 deletion maestro/trainer/models/florence_2/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def load_model(
model_id_or_path: str = DEFAULT_FLORENCE2_MODEL_ID,
revision: str = DEFAULT_FLORENCE2_MODEL_REVISION,
device: str | torch.device = "auto",
optimization_strategy: OptimizationStrategy = OptimizationStrategy.LORA,
optimization_strategy: OptimizationStrategy = OptimizationStrategy.NONE,
cache_dir: Optional[str] = None,
) -> tuple[AutoProcessor, AutoModelForCausalLM]:
"""Loads a Florence 2 model and its associated processor.
Expand Down
15 changes: 15 additions & 0 deletions maestro/trainer/models/florence_2/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from maestro.trainer.common.utils.device import device_is_available, parse_device_spec
from maestro.trainer.common.utils.path import create_new_run_directory
from maestro.trainer.common.utils.seed import ensure_reproducibility
from maestro.trainer.logger import get_maestro_logger
from maestro.trainer.models.florence_2.checkpoints import (
DEFAULT_FLORENCE2_MODEL_ID,
DEFAULT_FLORENCE2_MODEL_REVISION,
Expand All @@ -38,6 +39,8 @@
from maestro.trainer.models.florence_2.inference import predict_with_inputs
from maestro.trainer.models.florence_2.loaders import evaluation_collate_fn, train_collate_fn

logger = get_maestro_logger()


@dataclass()
class Florence2Configuration:
Expand Down Expand Up @@ -162,6 +165,12 @@ def validation_step(self, batch, batch_idx):
device=self.config.device,
max_new_tokens=self.config.max_new_tokens,
)

if batch_idx == 0:
logger.info(f"sample valid prefix: {prefixes[0]}")
logger.info(f"sample valid suffix: {suffixes[0]}")
logger.info(f"sample generated suffix: {generated_suffixes[0]}")

for metric in self.config.metrics:
if isinstance(metric, MeanAveragePrecisionMetric):
predictions_list = []
Expand Down Expand Up @@ -250,6 +259,11 @@ def train(config: Florence2Configuration | dict) -> None:
detections_to_prefix_formatter=detections_to_prefix_formatter,
detections_to_suffix_formatter=detections_to_suffix_formatter,
)

_, train_entry = train_loader.dataset[0]
logger.info(f"sample train prefix: {train_entry['prefix']}")
logger.info(f"sample train suffix: {train_entry['suffix']}")

pl_module = Florence2Trainer(
processor=processor, model=model, train_loader=train_loader, valid_loader=valid_loader, config=config
)
Expand All @@ -259,6 +273,7 @@ def train(config: Florence2Configuration | dict) -> None:
max_epochs=config.epochs,
accumulate_grad_batches=config.accumulate_grad_batches,
check_val_every_n_epoch=1,
limit_val_batches=1,
log_every_n_steps=10,
callbacks=[save_checkpoint_callback],
)
Expand Down
2 changes: 1 addition & 1 deletion maestro/trainer/models/paligemma_2/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def load_model(
model_id_or_path: str = DEFAULT_PALIGEMMA2_MODEL_ID,
revision: str = DEFAULT_PALIGEMMA2_MODEL_REVISION,
device: str | torch.device = "auto",
optimization_strategy: OptimizationStrategy = OptimizationStrategy.LORA,
optimization_strategy: OptimizationStrategy = OptimizationStrategy.NONE,
cache_dir: Optional[str] = None,
) -> tuple[PaliGemmaProcessor, PaliGemmaForConditionalGeneration]:
"""Loads a PaliGemma 2 model and its associated processor.
Expand Down
13 changes: 13 additions & 0 deletions maestro/trainer/models/paligemma_2/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from maestro.trainer.common.utils.device import device_is_available, parse_device_spec
from maestro.trainer.common.utils.path import create_new_run_directory
from maestro.trainer.common.utils.seed import ensure_reproducibility
from maestro.trainer.logger import get_maestro_logger
from maestro.trainer.models.paligemma_2.checkpoints import (
DEFAULT_PALIGEMMA2_MODEL_ID,
DEFAULT_PALIGEMMA2_MODEL_REVISION,
Expand All @@ -27,6 +28,8 @@
from maestro.trainer.models.paligemma_2.inference import predict_with_inputs
from maestro.trainer.models.paligemma_2.loaders import evaluation_collate_fn, train_collate_fn

logger = get_maestro_logger()


@dataclass()
class PaliGemma2Configuration:
Expand Down Expand Up @@ -161,6 +164,12 @@ def validation_step(self, batch, batch_idx):
device=self.config.device,
max_new_tokens=self.config.max_new_tokens,
)

if batch_idx == 0:
logger.info(f"sample valid prefix: {prefixes[0]}")
logger.info(f"sample valid suffix: {suffixes[0]}")
logger.info(f"sample generated suffix: {generated_suffixes[0]}")

for metric in self.config.metrics:
result = metric.compute(predictions=generated_suffixes, targets=suffixes)
for key, value in result.items():
Expand Down Expand Up @@ -214,6 +223,10 @@ def train(config: PaliGemma2Configuration | dict) -> None:
test_num_workers=config.val_num_workers,
)

_, train_entry = train_loader.dataset[0]
logger.info(f"sample train prefix: {train_entry['prefix']}")
logger.info(f"sample train suffix: {train_entry['suffix']}")

pl_module = PaliGemma2Trainer(
processor=processor, model=model, train_loader=train_loader, valid_loader=valid_loader, config=config
)
Expand Down
4 changes: 3 additions & 1 deletion maestro/trainer/models/qwen_2_5_vl/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def load_model(
model_id_or_path: str = DEFAULT_QWEN2_5_VL_MODEL_ID,
revision: str = DEFAULT_QWEN2_5_VL_MODEL_REVISION,
device: str | torch.device = "auto",
optimization_strategy: OptimizationStrategy = OptimizationStrategy.LORA,
optimization_strategy: OptimizationStrategy = OptimizationStrategy.NONE,
cache_dir: Optional[str] = None,
min_pixels: int = 256 * 28 * 28,
max_pixels: int = 1280 * 28 * 28,
Expand Down Expand Up @@ -53,7 +53,9 @@ def load_model(
cache_dir=cache_dir,
min_pixels=min_pixels,
max_pixels=max_pixels,
use_fast=True,
)
processor.tokenizer.padding_side = "left"

if optimization_strategy in {OptimizationStrategy.LORA, OptimizationStrategy.QLORA}:
lora_config = LoraConfig(
Expand Down
88 changes: 77 additions & 11 deletions maestro/trainer/models/qwen_2_5_vl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,40 @@

import dacite
import lightning
import numpy as np
import supervision as sv
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor

from maestro.trainer.common.callbacks import SaveCheckpoint
from maestro.trainer.common.datasets.core import create_data_loaders, resolve_dataset_path
from maestro.trainer.common.metrics import BaseMetric, MetricsTracker, parse_metrics, save_metric_plots
from maestro.trainer.common.metrics import (
BaseMetric,
MeanAveragePrecisionMetric,
MetricsTracker,
parse_metrics,
save_metric_plots,
)
from maestro.trainer.common.training import MaestroTrainer
from maestro.trainer.common.utils.device import device_is_available, parse_device_spec
from maestro.trainer.common.utils.path import create_new_run_directory
from maestro.trainer.common.utils.seed import ensure_reproducibility
from maestro.trainer.logger import get_maestro_logger
from maestro.trainer.models.qwen_2_5_vl.checkpoints import (
DEFAULT_QWEN2_5_VL_MODEL_ID,
DEFAULT_QWEN2_5_VL_MODEL_REVISION,
OptimizationStrategy,
load_model,
save_model,
)
from maestro.trainer.models.qwen_2_5_vl.detection import detections_to_prefix_formatter, detections_to_suffix_formatter
from maestro.trainer.models.qwen_2_5_vl.inference import predict_with_inputs
from maestro.trainer.models.qwen_2_5_vl.loaders import evaluation_collate_fn, train_collate_fn

logger = get_maestro_logger()


@dataclass()
class Qwen25VLConfiguration:
Expand Down Expand Up @@ -137,7 +149,8 @@ def training_step(self, batch, batch_idx):
return loss

def validation_step(self, batch, batch_idx):
input_ids, attention_mask, pixel_values, image_grid_thw, prefixes, suffixes = batch
(input_ids, attention_mask, pixel_values, image_grid_thw, images, prefixes, suffixes) = batch
image_grid_thw_cpu = image_grid_thw.cpu()
generated_suffixes = predict_with_inputs(
model=self.model,
processor=self.processor,
Expand All @@ -148,16 +161,60 @@ def validation_step(self, batch, batch_idx):
device=self.config.device,
)

if batch_idx == 0:
logger.info(f"sample valid prefix: {prefixes[0]}")
logger.info(f"sample valid suffix: {suffixes[0]}")
logger.info(f"sample generated suffix: {generated_suffixes[0]}")

for metric in self.config.metrics:
result = metric.compute(predictions=generated_suffixes, targets=suffixes)
for key, value in result.items():
self.valid_metrics_tracker.register(
metric=key,
epoch=self.current_epoch,
step=batch_idx,
value=value,
)
self.log(key, value, prog_bar=True, logger=True)
if isinstance(metric, MeanAveragePrecisionMetric):
predictions_list = []
targets_list = []

for i, image in enumerate(images):
image_w, image_h = image.size
input_h = image_grid_thw_cpu[i][1] * 14
input_w = image_grid_thw_cpu[i][2] * 14

predictions = sv.Detections.from_vlm(
vlm=sv.VLM.QWEN_2_5_VL,
result=generated_suffixes[i],
input_wh=(input_w, input_h),
resolution_wh=(image_w, image_h),
)
predictions.class_id = np.full(len(predictions), fill_value=-1)
predictions.confidence = np.full(len(predictions), fill_value=1.0)

targets = sv.Detections.from_vlm(
vlm=sv.VLM.QWEN_2_5_VL,
result=suffixes[i],
input_wh=(input_w, input_h),
resolution_wh=(image_w, image_h),
)
targets.class_id = np.full(len(targets), fill_value=-1)

predictions_list.append(predictions)
targets_list.append(targets)

result = metric.compute(predictions=predictions_list, targets=targets_list)
for key, value in result.items():
self.valid_metrics_tracker.register(
metric=key,
epoch=self.current_epoch,
step=batch_idx,
value=value,
)
self.log(key, value, prog_bar=True, logger=True, batch_size=self.config.val_batch_size)
else:
result = metric.compute(predictions=generated_suffixes, targets=suffixes)
for key, value in result.items():
self.valid_metrics_tracker.register(
metric=key,
epoch=self.current_epoch,
step=batch_idx,
value=value,
)
self.log(key, value, prog_bar=True, logger=True, batch_size=self.config.val_batch_size)

def configure_optimizers(self):
optimizer = AdamW(self.model.parameters(), lr=self.config.lr)
Expand Down Expand Up @@ -209,7 +266,16 @@ def train(config: Qwen25VLConfiguration | dict) -> None:
test_batch_size=config.val_batch_size,
test_collect_fn=partial(evaluation_collate_fn, processor=processor, system_message=config.system_message),
test_num_workers=config.val_num_workers,
detections_to_prefix_formatter=detections_to_prefix_formatter,
detections_to_suffix_formatter=partial(
detections_to_suffix_formatter, min_pixels=config.min_pixels, max_pixels=config.max_pixels
),
)
_, train_entry = train_loader.dataset[0]

logger.info(f"sample train prefix: {train_entry['prefix']}")
logger.info(f"sample train suffix: {train_entry['suffix']}")

pl_module = Qwen25VLTrainer(
processor=processor, model=model, train_loader=train_loader, valid_loader=valid_loader, config=config
)
Expand Down
35 changes: 35 additions & 0 deletions maestro/trainer/models/qwen_2_5_vl/detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np
from qwen_vl_utils import smart_resize


def detections_to_suffix_formatter(
    xyxy: np.ndarray,
    class_id: np.ndarray,
    classes: list[str],
    resolution_wh: tuple[int, int],
    min_pixels: int,
    max_pixels: int,
) -> str:
    """Render detections as a fenced JSON list of boxes and labels.

    Box coordinates are rescaled from the original image resolution to the
    ``(height, width)`` returned by ``smart_resize`` for that image, then
    truncated to integers.

    Args:
        xyxy: Boxes as ``(x_min, y_min, x_max, y_max)`` rows, in pixels of the
            original image. Assumed to have shape ``(N, 4)`` - TODO confirm
            against the caller.
        class_id: Per-box indices into ``classes``.
        classes: Class names; ``classes[int(class_id[i])]`` labels box ``i``.
        resolution_wh: ``(width, height)`` of the original image.
        min_pixels: Forwarded to ``smart_resize``.
        max_pixels: Forwarded to ``smart_resize``.

    Returns:
        A string of the form ``"```json\\n[\\n...\\n]\\n```"`` with one
        tab-indented ``{"bbox_2d": [...], "label": "..."}`` object per
        detection, separated by ``",\\n"``.
    """
    # Local import keeps this block self-contained.
    import json

    image_w, image_h = resolution_wh
    input_h, input_w = smart_resize(height=image_h, width=image_w, min_pixels=min_pixels, max_pixels=max_pixels)

    # Normalize to [0, 1] first, then scale to the resized resolution. The
    # operation order matches the original so integer truncation is unchanged.
    xyxy = xyxy / [image_w, image_h, image_w, image_h]
    xyxy = xyxy * [input_w, input_h, input_w, input_h]
    xyxy = xyxy.astype(int)

    detection_lines = []
    for cid, box in zip(class_id, xyxy):
        # Serialize the label as a JSON string so quotes, backslashes and
        # control characters in class names cannot break the emitted JSON.
        # ensure_ascii=False leaves non-ASCII class names untouched.
        label_json = json.dumps(classes[int(cid)], ensure_ascii=False)
        bbox_str = ", ".join(str(num) for num in box.tolist())
        detection_lines.append(f'\t{{"bbox_2d": [{bbox_str}], "label": {label_json}}}')

    joined_detections = ",\n".join(detection_lines)
    return f"```json\n[\n{joined_detections}\n]\n```"


def detections_to_prefix_formatter(
    xyxy: np.ndarray, class_id: np.ndarray, classes: list[str], resolution_wh: tuple[int, int]
) -> str:
    """Build the detection prompt that lists every class name.

    Only ``classes`` influences the result; ``xyxy``, ``class_id`` and
    ``resolution_wh`` are accepted but not read.

    Returns:
        The prompt sentence with the class names joined by ``", "``.
    """
    class_listing = ", ".join(classes)
    prompt_parts = (
        "Outline the position of ",
        class_listing,
        ". Output all the coordinates in JSON format.",
    )
    return "".join(prompt_parts)
25 changes: 12 additions & 13 deletions maestro/trainer/models/qwen_2_5_vl/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,20 @@ def train_collate_fn(
for image_token_id in image_tokens:
labels[labels == image_token_id] = -100

for i, conversation in enumerate(conversations):
# Ensure there is an assistant turn to separate out.
if len(conversation) < 2:
continue # Nothing to mask if there's no assistant turn.
sysuser_conv = conversation[:-1] # All turns except the assistant's turn.
sysuser_text = processor.apply_chat_template(conversation=sysuser_conv, tokenize=False)
sysuser_img, _ = process_vision_info(sysuser_conv)
sysuser_inputs = processor(
text=[sysuser_text],
images=[sysuser_img],
for conversation_index, complete_conversation in enumerate(conversations):
if len(complete_conversation) < 2:
continue
system_user_conversation = complete_conversation[:-1]
system_user_text = processor.apply_chat_template(conversation=system_user_conversation, tokenize=False)
system_user_image, _ = process_vision_info(system_user_conversation)
system_user_model_inputs = processor(
text=[system_user_text],
images=[system_user_image],
return_tensors="pt",
padding=True,
)
sysuser_len = sysuser_inputs["input_ids"].shape[1]
labels[i, :sysuser_len] = -100
system_user_input_length = system_user_model_inputs["input_ids"].shape[1]
labels[conversation_index, :system_user_input_length] = -100

input_ids = model_inputs["input_ids"]
attention_mask = model_inputs["attention_mask"]
Expand Down Expand Up @@ -108,4 +107,4 @@ def evaluation_collate_fn(
pixel_values = model_inputs["pixel_values"]
image_grid_thw = model_inputs["image_grid_thw"]

return input_ids, attention_mask, pixel_values, image_grid_thw, prefixes, suffixes
return (input_ids, attention_mask, pixel_values, image_grid_thw, images, prefixes, suffixes)
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "maestro"
version = "1.1.0rc2"
version = "1.1.0rc3"
description = "Streamline the fine-tuning process for vision-language models like PaliGemma 2, Florence-2, and Qwen2.5-VL."
authors = [
{name = "Piotr Skalski", email = "[email protected]"}
Expand Down Expand Up @@ -40,7 +40,7 @@ dependencies = [
"roboflow>=1.1.0",
"dacite>=1.9.1",
"lightning>=2.4.0",
"supervision>=0.20.0,<0.26.0",
"supervision>=0.26.0rc4",
"requests>=2.31.0,<=2.32.3",
"typer>=0.12.5",
"evaluate>=0.4.3",
Expand Down

0 comments on commit f130403

Please sign in to comment.