diff --git a/.github/workflows/integration_tests_workflows_x86.yml b/.github/workflows/integration_tests_workflows_x86.yml
index 28c151100..31cc23c43 100644
--- a/.github/workflows/integration_tests_workflows_x86.yml
+++ b/.github/workflows/integration_tests_workflows_x86.yml
@@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
- timeout-minutes: 10
+ timeout-minutes: 15
steps:
- name: 🛎️ Checkout
uses: actions/checkout@v4
@@ -30,6 +30,6 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install --upgrade setuptools
- pip install --extra-index-url https://download.pytorch.org/whl/cpu -r requirements/_requirements.txt -r requirements/requirements.cpu.txt -r requirements/requirements.sdk.http.txt -r requirements/requirements.test.unit.txt -r requirements/requirements.http.txt -r requirements/requirements.yolo_world.txt -r requirements/requirements.doctr.txt -r requirements/requirements.sam.txt
+ pip install --extra-index-url https://download.pytorch.org/whl/cpu -r requirements/_requirements.txt -r requirements/requirements.cpu.txt -r requirements/requirements.sdk.http.txt -r requirements/requirements.test.unit.txt -r requirements/requirements.http.txt -r requirements/requirements.yolo_world.txt -r requirements/requirements.doctr.txt -r requirements/requirements.sam.txt -r requirements/requirements.transformers.txt
- name: 🧪 Integration Tests of Workflows
- run: ROBOFLOW_API_KEY=${{ secrets.API_KEY }} python -m pytest tests/workflows/integration_tests
+ run: ROBOFLOW_API_KEY=${{ secrets.API_KEY }} SKIP_FLORENCE2_TEST=FALSE python -m pytest tests/workflows/integration_tests
diff --git a/docker/dockerfiles/Dockerfile.onnx.gpu.dev b/docker/dockerfiles/Dockerfile.onnx.gpu.dev
new file mode 100644
index 000000000..e483abdff
--- /dev/null
+++ b/docker/dockerfiles/Dockerfile.onnx.gpu.dev
@@ -0,0 +1,82 @@
+FROM nvcr.io/nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 as base
+
+WORKDIR /app
+
+RUN rm -rf /var/lib/apt/lists/* && apt-get clean && apt-get update -y && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+ ffmpeg \
+ libxext6 \
+ libopencv-dev \
+ uvicorn \
+ python3-pip \
+ git \
+ libgdal-dev \
+ wget \
+ && rm -rf /var/lib/apt/lists/*
+
+COPY requirements/requirements.sam.txt \
+ requirements/requirements.clip.txt \
+ requirements/requirements.http.txt \
+ requirements/requirements.gpu.txt \
+ requirements/requirements.waf.txt \
+ requirements/requirements.gaze.txt \
+ requirements/requirements.doctr.txt \
+ requirements/requirements.groundingdino.txt \
+ requirements/requirements.cogvlm.txt \
+ requirements/requirements.yolo_world.txt \
+ requirements/_requirements.txt \
+ requirements/requirements.transformers.txt \
+ requirements/requirements.pali.flash_attn.txt \
+ requirements/requirements.sdk.http.txt \
+ requirements/requirements.cli.txt \
+ ./
+
+RUN python3 -m pip install -U pip
+RUN python3 -m pip install --extra-index-url https://download.pytorch.org/whl/cu118 \
+ -r _requirements.txt \
+ -r requirements.sam.txt \
+ -r requirements.clip.txt \
+ -r requirements.http.txt \
+ -r requirements.gpu.txt \
+ -r requirements.waf.txt \
+ -r requirements.gaze.txt \
+ -r requirements.groundingdino.txt \
+ -r requirements.doctr.txt \
+ -r requirements.cogvlm.txt \
+ -r requirements.yolo_world.txt \
+ -r requirements.transformers.txt \
+ -r requirements.sdk.http.txt \
+ -r requirements.cli.txt \
+ jupyterlab \
+ --upgrade \
+ && rm -rf ~/.cache/pip
+
+# Install setup.py requirements for flash_attn
+RUN python3 -m pip install packaging==24.1 && rm -rf ~/.cache/pip
+
+# Install flash_attn required for Paligemma and Florence2
+RUN python3 -m pip install -r requirements.pali.flash_attn.txt --no-build-isolation && rm -rf ~/.cache/pip
+
+FROM scratch
+COPY --from=base / /
+
+WORKDIR /app/
+COPY inference inference
+COPY inference_sdk inference_sdk
+COPY inference_cli inference_cli
+ENV PYTHONPATH=/app/
+COPY docker/config/gpu_http.py gpu_http.py
+
+ENV PYTHONPATH=/app/
+ENV VERSION_CHECK_MODE=continuous
+ENV PROJECT=roboflow-platform
+ENV NUM_WORKERS=1
+ENV HOST=0.0.0.0
+ENV PORT=9001
+ENV WORKFLOWS_STEP_EXECUTION_MODE=local
+ENV WORKFLOWS_MAX_CONCURRENT_STEPS=1
+ENV API_LOGGING_ENABLED=True
+ENV LMM_ENABLED=True
+ENV CORE_MODEL_SAM2_ENABLED=True
+ENV CORE_MODEL_OWLV2_ENABLED=True
+
+ENTRYPOINT uvicorn gpu_http:app --workers $NUM_WORKERS --host $HOST --port $PORT
\ No newline at end of file
diff --git a/docs/workflows/blocks.md b/docs/workflows/blocks.md
index ed2f6eccc..2be691ce4 100644
--- a/docs/workflows/blocks.md
+++ b/docs/workflows/blocks.md
@@ -13,6 +13,8 @@ hide:
+
+
@@ -33,6 +35,7 @@ hide:
+
@@ -58,6 +61,7 @@ hide:
+
@@ -77,6 +81,9 @@ hide:
+
+
+
diff --git a/docs/workflows/gallery_index.md b/docs/workflows/gallery_index.md
index f9dd8a1a6..81f47cc4c 100644
--- a/docs/workflows/gallery_index.md
+++ b/docs/workflows/gallery_index.md
@@ -6,9 +6,9 @@ Browse through the various categories to find inspiration and ideas for building
- Workflows with multiple models
- Workflows enhanced by Roboflow Platform
+ - Basic Workflows
- Workflows with classical Computer Vision methods
- Workflows with Visual Language Models
- - Basic Workflows
- Workflows with dynamic Python Blocks
- Workflows with data transformations
- Workflows with flow control
diff --git a/docs/workflows/kinds.md b/docs/workflows/kinds.md
index 22482dfd8..5458c3c11 100644
--- a/docs/workflows/kinds.md
+++ b/docs/workflows/kinds.md
@@ -37,36 +37,36 @@ for the presence of a mask in the input.
## Kinds declared in Roboflow plugins
+* [`instance_segmentation_prediction`](/workflows/kinds/instance_segmentation_prediction): Prediction with detected bounding boxes and segmentation masks in form of sv.Detections(...) object
+* [`list_of_values`](/workflows/kinds/list_of_values): List of values of any type
+* [`prediction_type`](/workflows/kinds/prediction_type): String value with type of prediction
+* [`zone`](/workflows/kinds/zone): Definition of polygon zone
+* [`image_keypoints`](/workflows/kinds/image_keypoints): Image keypoints detected by classical Computer Vision method
+* [`serialised_payloads`](/workflows/kinds/serialised_payloads): Serialised element that is usually accepted by sink
+* [`detection`](/workflows/kinds/detection): Single element of detections-based prediction (like `object_detection_prediction`)
* [`bar_code_detection`](/workflows/kinds/bar_code_detection): Prediction with barcode detection
-* [`language_model_output`](/workflows/kinds/language_model_output): LLM / VLM output
+* [`video_metadata`](/workflows/kinds/video_metadata): Video image metadata
+* [`rgb_color`](/workflows/kinds/rgb_color): RGB color
+* [`float`](/workflows/kinds/float): Float value
* [`top_class`](/workflows/kinds/top_class): String value representing top class predicted by classification model
-* [`prediction_type`](/workflows/kinds/prediction_type): String value with type of prediction
-* [`object_detection_prediction`](/workflows/kinds/object_detection_prediction): Prediction with detected bounding boxes in form of sv.Detections(...) object
-* [`qr_code_detection`](/workflows/kinds/qr_code_detection): Prediction with QR code detection
* [`image_metadata`](/workflows/kinds/image_metadata): Dictionary with image metadata required by supervision
+* [`numpy_array`](/workflows/kinds/numpy_array): Numpy array
+* [`roboflow_model_id`](/workflows/kinds/roboflow_model_id): Roboflow model id
+* [`roboflow_api_key`](/workflows/kinds/roboflow_api_key): Roboflow API key
+* [`integer`](/workflows/kinds/integer): Integer value
+* [`boolean`](/workflows/kinds/boolean): Boolean flag
+* [`language_model_output`](/workflows/kinds/language_model_output): LLM / VLM output
+* [`qr_code_detection`](/workflows/kinds/qr_code_detection): Prediction with QR code detection
+* [`point`](/workflows/kinds/point): Single point in 2D
* [`float_zero_to_one`](/workflows/kinds/float_zero_to_one): `float` value in range `[0.0, 1.0]`
+* [`dictionary`](/workflows/kinds/dictionary): Dictionary
* [`parent_id`](/workflows/kinds/parent_id): Identifier of parent for step output
-* [`keypoint_detection_prediction`](/workflows/kinds/keypoint_detection_prediction): Prediction with detected bounding boxes and detected keypoints in form of sv.Detections(...) object
-* [`float`](/workflows/kinds/float): Float value
-* [`*`](/workflows/kinds/*): Equivalent of any element
* [`contours`](/workflows/kinds/contours): List of numpy arrays where each array represents contour points
-* [`boolean`](/workflows/kinds/boolean): Boolean flag
-* [`detection`](/workflows/kinds/detection): Single element of detections-based prediction (like `object_detection_prediction`)
-* [`roboflow_project`](/workflows/kinds/roboflow_project): Roboflow project name
-* [`dictionary`](/workflows/kinds/dictionary): Dictionary
-* [`numpy_array`](/workflows/kinds/numpy_array): Numpy array
-* [`roboflow_api_key`](/workflows/kinds/roboflow_api_key): Roboflow API key
* [`string`](/workflows/kinds/string): String value
-* [`roboflow_model_id`](/workflows/kinds/roboflow_model_id): Roboflow model id
-* [`list_of_values`](/workflows/kinds/list_of_values): List of values of any types
-* [`instance_segmentation_prediction`](/workflows/kinds/instance_segmentation_prediction): Prediction with detected bounding boxes and segmentation masks in form of sv.Detections(...) object
+* [`object_detection_prediction`](/workflows/kinds/object_detection_prediction): Prediction with detected bounding boxes in form of sv.Detections(...) object
+* [`roboflow_project`](/workflows/kinds/roboflow_project): Roboflow project name
* [`image`](/workflows/kinds/image): Image in workflows
-* [`video_metadata`](/workflows/kinds/video_metadata): Video image metadata
-* [`serialised_payloads`](/workflows/kinds/serialised_payloads): Serialised element that is usually accepted by sink
-* [`integer`](/workflows/kinds/integer): Integer value
-* [`rgb_color`](/workflows/kinds/rgb_color): RGB color
+* [`*`](/workflows/kinds/*): Equivalent of any element
* [`classification_prediction`](/workflows/kinds/classification_prediction): Predictions from classifier
-* [`image_keypoints`](/workflows/kinds/image_keypoints): Image keypoints detected by classical Computer Vision method
-* [`point`](/workflows/kinds/point): Single point in 2D
-* [`zone`](/workflows/kinds/zone): Definition of polygon zone
+* [`keypoint_detection_prediction`](/workflows/kinds/keypoint_detection_prediction): Prediction with detected bounding boxes and detected keypoints in form of sv.Detections(...) object
diff --git a/inference/core/entities/responses/inference.py b/inference/core/entities/responses/inference.py
index 298641cea..592a99787 100644
--- a/inference/core/entities/responses/inference.py
+++ b/inference/core/entities/responses/inference.py
@@ -291,7 +291,9 @@ class MultiLabelClassificationInferenceResponse(
class LMMInferenceResponse(CvInferenceResponse):
- response: str = Field(description="Text generated by PaliGemma")
+ response: Union[str, dict] = Field(
+ description="Text/structured response generated by model"
+ )
class FaceDetectionPrediction(ObjectDetectionPrediction):
diff --git a/inference/core/version.py b/inference/core/version.py
index 2ffeef758..a54b5b3ab 100644
--- a/inference/core/version.py
+++ b/inference/core/version.py
@@ -1,4 +1,4 @@
-__version__ = "0.19.0"
+__version__ = "0.20.0"
if __name__ == "__main__":
diff --git a/inference/core/workflows/core_steps/formatters/vlm_as_detector/v1.py b/inference/core/workflows/core_steps/formatters/vlm_as_detector/v1.py
index 3dbb7cf3d..88056660d 100644
--- a/inference/core/workflows/core_steps/formatters/vlm_as_detector/v1.py
+++ b/inference/core/workflows/core_steps/formatters/vlm_as_detector/v1.py
@@ -1,6 +1,8 @@
+import hashlib
import json
import logging
import re
+from functools import partial
from typing import Dict, List, Literal, Optional, Tuple, Type, Union
from uuid import uuid4
@@ -87,20 +89,37 @@ class BlockManifest(WorkflowBlockManifest):
description="The string with raw classification prediction to parse.",
examples=[["$steps.lmm.output"]],
)
- classes: Union[
- WorkflowParameterSelector(kind=[LIST_OF_VALUES_KIND]),
- StepOutputSelector(kind=[LIST_OF_VALUES_KIND]),
- List[str],
+ classes: Optional[
+ Union[
+ WorkflowParameterSelector(kind=[LIST_OF_VALUES_KIND]),
+ StepOutputSelector(kind=[LIST_OF_VALUES_KIND]),
+ List[str],
+ ]
] = Field(
description="List of all classes used by the model, required to "
"generate mapping between class name and class id.",
examples=[["$steps.lmm.classes", "$inputs.classes", ["class_a", "class_b"]]],
+ json_schema_extra={
+ "relevant_for": {
+ "model_type": {
+ "values": ["google-gemini", "anthropic-claude"],
+ "required": True,
+ },
+ }
+ },
)
- model_type: Literal["google-gemini", "anthropic-claude"] = Field(
+ model_type: Literal["google-gemini", "anthropic-claude", "florence-2"] = Field(
description="Type of the model that generated prediction",
- examples=[["google-gemini", "anthropic-claude"]],
+ examples=[["google-gemini", "anthropic-claude", "florence-2"]],
)
- task_type: Literal["object-detection"]
+ task_type: Literal[
+ "object-detection",
+ "object-detection-and-caption",
+ "open-vocabulary-object-detection",
+ "phrase-grounded-object-detection",
+ "region-proposal",
+ "ocr-with-text-detection",
+ ]
@model_validator(mode="after")
def validate(self) -> "BlockManifest":
@@ -108,6 +127,11 @@ def validate(self) -> "BlockManifest":
raise ValueError(
f"Could not parse result of task {self.task_type} for model {self.model_type}"
)
+ if self.model_type != "florence-2" and self.classes is None:
+ raise ValueError(
+ "Must pass list of classes to this block when using gemini or claude"
+ )
+
return self
@classmethod
@@ -135,7 +159,7 @@ def run(
self,
image: WorkflowImageData,
vlm_output: str,
- classes: List[str],
+ classes: Optional[List[str]],
model_type: str,
task_type: str,
) -> BlockResult:
@@ -255,7 +279,88 @@ def scale_confidence(value: float) -> float:
return min(max(float(value), 0.0), 1.0)
+def parse_florence2_object_detection_response(
+ image: WorkflowImageData,
+ parsed_data: dict,
+ classes: Optional[List[str]],
+ inference_id: str,
+ florence_task_type: str,
+):
+ image_height, image_width = image.numpy_image.shape[:2]
+ detections = sv.Detections.from_lmm(
+ "florence_2",
+ result={florence_task_type: parsed_data},
+ resolution_wh=(image_width, image_height),
+ )
+ detections.class_id = np.array([0] * len(detections))
+ if florence_task_type == "<REGION_PROPOSAL>":
+ detections.data["class_name"] = np.array(["roi"] * len(detections))
+ if florence_task_type in {"", ""}:
+ unique_class_names = set(detections.data.get("class_name", []))
+ class_name_to_id = {
+ name: get_4digit_from_md5(name) for name in unique_class_names
+ }
+ class_ids = [
+ class_name_to_id.get(name, -1)
+ for name in detections.data.get("class_name", ["unknown"] * len(detections))
+ ]
+ detections.class_id = np.array(class_ids)
+ if florence_task_type in "<OPEN_VOCABULARY_DETECTION>":
+ class_name_to_id = {name: idx for idx, name in enumerate(classes)}
+ class_ids = [
+ class_name_to_id.get(name, -1)
+ for name in detections.data.get("class_name", ["unknown"] * len(detections))
+ ]
+ detections.class_id = np.array(class_ids)
+ dimensions = np.array([[image_height, image_width]] * len(detections))
+ detection_ids = np.array([str(uuid4()) for _ in range(len(detections))])
+ inference_ids = np.array([inference_id] * len(detections))
+ prediction_type = np.array(["object-detection"] * len(detections))
+ detections.data.update(
+ {
+ INFERENCE_ID_KEY: inference_ids,
+ DETECTION_ID_KEY: detection_ids,
+ PREDICTION_TYPE_KEY: prediction_type,
+ IMAGE_DIMENSIONS_KEY: dimensions,
+ }
+ )
+ detections.confidence = np.array([1.0 for _ in detections])
+ return attach_parents_coordinates_to_sv_detections(
+ detections=detections, image=image
+ )
+
+
+def get_4digit_from_md5(input_string):
+ md5_hash = hashlib.md5(input_string.encode("utf-8"))
+ hex_digest = md5_hash.hexdigest()
+ integer_value = int(hex_digest[:9], 16)
+ return integer_value % 10000
+
+
REGISTERED_PARSERS = {
("google-gemini", "object-detection"): parse_gemini_object_detection_response,
("anthropic-claude", "object-detection"): parse_gemini_object_detection_response,
+ ("florence-2", "object-detection"): partial(
+ parse_florence2_object_detection_response, florence_task_type=""
+ ),
+ ("florence-2", "open-vocabulary-object-detection"): partial(
+ parse_florence2_object_detection_response,
+ florence_task_type="",
+ ),
+ ("florence-2", "object-detection-and-caption"): partial(
+ parse_florence2_object_detection_response,
+ florence_task_type="",
+ ),
+ ("florence-2", "phrase-grounded-object-detection"): partial(
+ parse_florence2_object_detection_response,
+ florence_task_type="",
+ ),
+ ("florence-2", "region-proposal"): partial(
+ parse_florence2_object_detection_response,
+ florence_task_type="",
+ ),
+ ("florence-2", "ocr-with-text-detection"): partial(
+ parse_florence2_object_detection_response,
+ florence_task_type="",
+ ),
}
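A note on the Florence-2 parsers registered above: Florence-2 detection tasks return free-form labels rather than indices into a fixed class list, so `get_4digit_from_md5` derives a stable pseudo class id from each label. A minimal standalone sketch of that mapping (the `stable_class_id` name is illustrative, not part of the diff):

```python
# Hedged sketch of the md5-based class-id trick used in vlm_as_detector/v1.py:
# hash the label, take the first 9 hex digits as an integer, and reduce modulo
# 10000 so the same label always maps to the same id in the range 0-9999.
import hashlib


def stable_class_id(label: str) -> int:
    digest = hashlib.md5(label.encode("utf-8")).hexdigest()
    return int(digest[:9], 16) % 10000


assert stable_class_id("dog") == stable_class_id("dog")  # deterministic across runs
assert 0 <= stable_class_id("licence plate") < 10000
```

Collisions between different labels are possible but harmless here, since the human-readable `class_name` field is kept alongside the numeric id.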
diff --git a/inference/core/workflows/core_steps/loader.py b/inference/core/workflows/core_steps/loader.py
index 049646b90..5b59fff09 100644
--- a/inference/core/workflows/core_steps/loader.py
+++ b/inference/core/workflows/core_steps/loader.py
@@ -85,6 +85,9 @@
from inference.core.workflows.core_steps.models.foundation.cog_vlm.v1 import (
CogVLMBlockV1,
)
+from inference.core.workflows.core_steps.models.foundation.florence2.v1 import (
+ Florence2BlockV1,
+)
from inference.core.workflows.core_steps.models.foundation.google_gemini.v1 import (
GoogleGeminiBlockV1,
)
@@ -343,6 +346,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]:
AntropicClaudeBlockV1,
LineCounterBlockV1,
PolygonZoneVisualizationBlockV1,
+ Florence2BlockV1,
]
diff --git a/inference/core/workflows/core_steps/models/foundation/anthropic_claude/v1.py b/inference/core/workflows/core_steps/models/foundation/anthropic_claude/v1.py
index 221b1f591..885390486 100644
--- a/inference/core/workflows/core_steps/models/foundation/anthropic_claude/v1.py
+++ b/inference/core/workflows/core_steps/models/foundation/anthropic_claude/v1.py
@@ -127,7 +127,10 @@ class BlockManifest(WorkflowBlockManifest):
examples=[{"my_key": "description"}, "$inputs.output_structure"],
json_schema_extra={
"relevant_for": {
- "task_type": {"values": TASKS_REQUIRING_CLASSES, "required": True},
+ "task_type": {
+ "values": TASKS_REQUIRING_OUTPUT_STRUCTURE,
+ "required": True,
+ },
},
},
)
@@ -140,7 +143,7 @@ class BlockManifest(WorkflowBlockManifest):
json_schema_extra={
"relevant_for": {
"task_type": {
- "values": TASKS_REQUIRING_OUTPUT_STRUCTURE,
+ "values": TASKS_REQUIRING_CLASSES,
"required": True,
},
},
diff --git a/inference/core/workflows/core_steps/models/foundation/florence2/__init__.py b/inference/core/workflows/core_steps/models/foundation/florence2/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/inference/core/workflows/core_steps/models/foundation/florence2/v1.py b/inference/core/workflows/core_steps/models/foundation/florence2/v1.py
new file mode 100644
index 000000000..22d2c0adb
--- /dev/null
+++ b/inference/core/workflows/core_steps/models/foundation/florence2/v1.py
@@ -0,0 +1,517 @@
+import json
+from typing import List, Literal, Optional, Type, TypeVar, Union
+
+import numpy as np
+import supervision as sv
+from pydantic import ConfigDict, Field, model_validator
+
+from inference.core.entities.requests.inference import LMMInferenceRequest
+from inference.core.managers.base import ModelManager
+from inference.core.workflows.core_steps.common.entities import StepExecutionMode
+from inference.core.workflows.execution_engine.entities.base import (
+ Batch,
+ OutputDefinition,
+ WorkflowImageData,
+)
+from inference.core.workflows.execution_engine.entities.types import (
+ DICTIONARY_KIND,
+ INSTANCE_SEGMENTATION_PREDICTION_KIND,
+ KEYPOINT_DETECTION_PREDICTION_KIND,
+ LANGUAGE_MODEL_OUTPUT_KIND,
+ LIST_OF_VALUES_KIND,
+ OBJECT_DETECTION_PREDICTION_KIND,
+ STRING_KIND,
+ ImageInputField,
+ StepOutputImageSelector,
+ StepOutputSelector,
+ WorkflowImageSelector,
+ WorkflowParameterSelector,
+)
+from inference.core.workflows.prototypes.block import (
+ BlockResult,
+ WorkflowBlock,
+ WorkflowBlockManifest,
+)
+
+T = TypeVar("T")
+K = TypeVar("K")
+
+DETECTIONS_CLASS_NAME_FIELD = "class_name"
+DETECTION_ID_FIELD = "detection_id"
+
+LONG_DESCRIPTION = """
+**Dedicated inference server required (GPU recommended) - you may want to use dedicated deployment**
+
+This Workflow block introduces **Florence 2**, a Visual Language Model (VLM) capable of performing a
+wide range of tasks, including:
+
+* Object Detection
+
+* Instance Segmentation
+
+* Image Captioning
+
+* Optical Character Recognition (OCR)
+
+* and more...
+
+
+Below is a comprehensive list of tasks supported by the model, along with descriptions on
+how to utilize their outputs within the Workflows ecosystem:
+
+**Task Descriptions:**
+
+* `ocr`: Performs Optical Character Recognition on the entire image. The `raw_output` contains the extracted text
+without any structure.
+
+* `ocr-with-text-detection`: Detects text regions in the image, and then performs OCR on each detected region.
+Both `raw_output` and `classes` should be connected to the VLM as a Detector block, producing `sv.Detections(...)`
+where labels represent the recognized text.
+
+* `caption` / `detailed-caption` / `more-detailed-caption`: Generates image captions. The resulting caption is
+available in `raw_output`.
+
+* `object-detection`: Detects visible objects in the image. The model uses a predefined set of classes. To use
+the output, plug `raw_output` and `classes` into the VLM Detector block, yielding `sv.Detections(...)`.
+
+* `open-vocabulary-object-detection`: Allows control over the set of objects being detected by specifying a list of
+classes as input. Ensure that the classes are visible in the image for optimal performance. The outputs
+(`raw_output` and `classes`) can be connected to the VLM Detector block to obtain `sv.Detections(...)`.
+
+* `phrase-grounded-object-detection`: Provides a textual prompt describing the objects of interest (multiple
+objects can be specified). Florence will detect the corresponding objects. Connect `raw_output` and `classes`
+to the VLM Detector block to get `sv.Detections(...)`.
+
+* `phrase-grounded-instance-segmentation`: Uses a textual prompt to detect and segment objects of interest.
+Florence will perform instance segmentation based on the specified descriptions.
+
+* `detection-grounded-instance-segmentation`: Performs semantic segmentation within a provided region of interest
+(RoI). The RoI can be specified as an input parameter or generated by an upstream detection model. If the latter,
+the `grounding_selection_mode` must be defined to choose a single bounding box as grounding, as Florence 2 currently
+performs best when only one box is provided.
+
+* `detection-grounded-classification`: Classifies a specific region of interest (RoI). Similar to instance
+segmentation, the RoI can be input directly or generated by an upstream model. Use `grounding_selection_mode` to
+specify how to select a single bounding box for grounding.
+
+* `detection-grounded-caption`: Generates a caption for a specific region of interest (RoI). The RoI can be
+provided as input or come from an upstream detection model. As with other tasks, use `grounding_selection_mode`
+to define how to select a single bounding box.
+
+* `detection-grounded-ocr`: Performs OCR on a specific region of interest (RoI). RoI can be manually provided or
+sourced from an upstream detection model. Again, use `grounding_selection_mode` to choose one bounding box for grounding.
+
+* `region-proposal`: Automatically proposes regions of interest in the image. `raw_output` and `classes`
+should be plugged into the VLM Detector block to generate `sv.Detections(...)`.
+"""
+
+TASK_TYPE_TO_FLORENCE_TASK = {
+ "ocr": "",
+ "ocr-with-text-detection": "",
+ "caption": "",
+ "detailed-caption": "",
+ "more-detailed-caption": "",
+ "object-detection-and-caption": "",
+ "object-detection": "",
+ "open-vocabulary-object-detection": "",
+ "phrase-grounded-object-detection": "",
+ "phrase-grounded-instance-segmentation": "",
+ "detection-grounded-instance-segmentation": "",
+ "detection-grounded-classification": "",
+ "detection-grounded-caption": "",
+ "detection-grounded-ocr": "",
+ "region-proposal": "",
+}
+TaskType = Literal[tuple(TASK_TYPE_TO_FLORENCE_TASK.keys())]
+GroundingSelectionMode = Literal[
+ "first",
+ "last",
+ "biggest",
+ "smallest",
+ "most-confident",
+ "least-confident",
+]
+
+TASKS_REQUIRING_PROMPT = {
+ "phrase-grounded-object-detection",
+ "phrase-grounded-instance-segmentation",
+}
+TASKS_REQUIRING_CLASSES = {
+ "open-vocabulary-object-detection",
+}
+TASKS_REQUIRING_DETECTION_GROUNDING = {
+ "detection-grounded-instance-segmentation",
+ "detection-grounded-classification",
+ "detection-grounded-caption",
+ "detection-grounded-ocr",
+}
+LOC_BINS = 1000
+
+TASKS_TO_EXTRACT_LABELS_AS_CLASSES = {
+ "",
+ "",
+ "",
+ "",
+}
+
+
+class BlockManifest(WorkflowBlockManifest):
+ model_config = ConfigDict(
+ json_schema_extra={
+ "name": "Florence-2 Model",
+ "version": "v1",
+ "short_description": "Run Florence-2 on an image",
+ "long_description": LONG_DESCRIPTION,
+ "license": "Apache-2.0",
+ "block_type": "model",
+ "search_keywords": ["Florence", "Florence-2", "Microsoft"],
+ },
+ protected_namespaces=(),
+ )
+ type: Literal["roboflow_core/florence_2@v1"]
+ images: Union[WorkflowImageSelector, StepOutputImageSelector] = ImageInputField
+ model_version: Union[
+ WorkflowParameterSelector(kind=[STRING_KIND]),
+ Literal["florence-2-base", "florence-2-large"],
+ ] = Field(
+ default="florence-2-base",
+ description="Model to be used",
+ examples=["florence-2-base"],
+ )
+ task_type: TaskType = Field(
+ description="Task type to be performed by model. "
+ "Value determines required parameters and output response."
+ )
+ prompt: Optional[Union[WorkflowParameterSelector(kind=[STRING_KIND]), str]] = Field(
+ default=None,
+ description="Text prompt to the Florence-2 model",
+ examples=["my prompt", "$inputs.prompt"],
+ json_schema_extra={
+ "relevant_for": {
+ "task_type": {"values": TASKS_REQUIRING_PROMPT, "required": True},
+ },
+ },
+ )
+ classes: Optional[
+ Union[WorkflowParameterSelector(kind=[LIST_OF_VALUES_KIND]), List[str]]
+ ] = Field(
+ default=None,
+ description="List of classes to be used",
+ examples=[["class-a", "class-b"], "$inputs.classes"],
+ json_schema_extra={
+ "relevant_for": {
+ "task_type": {
+ "values": TASKS_REQUIRING_CLASSES,
+ "required": True,
+ },
+ },
+ },
+ )
+ grounding_detection: Optional[
+ Union[
+ List[int],
+ List[float],
+ StepOutputSelector(
+ kind=[
+ OBJECT_DETECTION_PREDICTION_KIND,
+ INSTANCE_SEGMENTATION_PREDICTION_KIND,
+ KEYPOINT_DETECTION_PREDICTION_KIND,
+ ]
+ ),
+ WorkflowParameterSelector(kind=[LIST_OF_VALUES_KIND]),
+ ]
+ ] = Field(
+ default=None,
+ description="Detection to ground Florence-2 model. May be statically provided bounding box "
+ "`[left_top_x, left_top_y, right_bottom_x, right_bottom_y]` or result of object-detection model. "
+ "If the latter is true, one box will be selected based on `grounding_selection_mode`.",
+ examples=["$steps.detection.predictions", [10, 20, 30, 40]],
+ json_schema_extra={
+ "relevant_for": {
+ "task_type": {
+ "values": TASKS_REQUIRING_DETECTION_GROUNDING,
+ "required": True,
+ },
+ },
+ },
+ )
+ grounding_selection_mode: GroundingSelectionMode = Field(
+ default="first",
+ description="",
+ examples=["first", "most-confident"],
+ json_schema_extra={
+ "relevant_for": {
+ "task_type": {
+ "values": TASKS_REQUIRING_DETECTION_GROUNDING,
+ "required": True,
+ },
+ },
+ },
+ )
+
+ @classmethod
+ def accepts_batch_input(cls) -> bool:
+ return True
+
+ @model_validator(mode="after")
+ def validate(self) -> "BlockManifest":
+ if self.task_type in TASKS_REQUIRING_PROMPT and self.prompt is None:
+ raise ValueError(
+ f"`prompt` parameter required to be set for task `{self.task_type}`"
+ )
+ if self.task_type in TASKS_REQUIRING_CLASSES and not self.classes:
+ raise ValueError(
+ f"`classes` parameter required to be set for task `{self.task_type}`"
+ )
+ if (
+ self.task_type in TASKS_REQUIRING_DETECTION_GROUNDING
+ and not self.grounding_detection
+ ):
+ raise ValueError(
+ f"`grounding_detection` parameter required to be set for task `{self.task_type}`"
+ )
+ return self
+
+ @classmethod
+ def describe_outputs(cls) -> List[OutputDefinition]:
+ return [
+ OutputDefinition(
+ name="raw_output", kind=[STRING_KIND, LANGUAGE_MODEL_OUTPUT_KIND]
+ ),
+ OutputDefinition(name="parsed_output", kind=[DICTIONARY_KIND]),
+ OutputDefinition(name="classes", kind=[LIST_OF_VALUES_KIND]),
+ ]
+
+ @classmethod
+ def get_execution_engine_compatibility(cls) -> Optional[str]:
+ return ">=1.0.0,<2.0.0"
+
+
+class Florence2BlockV1(WorkflowBlock):
+
+ def __init__(
+ self,
+ model_manager: ModelManager,
+ api_key: Optional[str],
+ step_execution_mode: StepExecutionMode,
+ ):
+ self._model_manager = model_manager
+ self._api_key = api_key
+ self._step_execution_mode = step_execution_mode
+
+ @classmethod
+ def get_init_parameters(cls) -> List[str]:
+ return ["model_manager", "api_key", "step_execution_mode"]
+
+ @classmethod
+ def get_manifest(cls) -> Type[WorkflowBlockManifest]:
+ return BlockManifest
+
+ def run(
+ self,
+ images: Batch[WorkflowImageData],
+ model_version: str,
+ task_type: TaskType,
+ prompt: Optional[str],
+ classes: Optional[List[str]],
+ grounding_detection: Optional[
+ Union[Batch[sv.Detections], List[int], List[float]]
+ ],
+ grounding_selection_mode: GroundingSelectionMode,
+ ) -> BlockResult:
+ if self._step_execution_mode is StepExecutionMode.LOCAL:
+ return self.run_locally(
+ images=images,
+ task_type=task_type,
+ model_version=model_version,
+ prompt=prompt,
+ classes=classes,
+ grounding_detection=grounding_detection,
+ grounding_selection_mode=grounding_selection_mode,
+ )
+ elif self._step_execution_mode is StepExecutionMode.REMOTE:
+ raise NotImplementedError(
+ "Remote execution is not supported for florence2. Run a local or dedicated inference server to use this block (GPU recommended)."
+ )
+ else:
+ raise ValueError(
+ f"Unknown step execution mode: {self._step_execution_mode}"
+ )
+
+ def run_locally(
+ self,
+ images: Batch[WorkflowImageData],
+ model_version: str,
+ task_type: TaskType,
+ prompt: Optional[str],
+ classes: Optional[List[str]],
+ grounding_detection: Optional[
+ Union[Batch[sv.Detections], List[int], List[float]]
+ ],
+ grounding_selection_mode: GroundingSelectionMode,
+ ) -> BlockResult:
+ requires_detection_grounding = task_type in TASKS_REQUIRING_DETECTION_GROUNDING
+ task_type = TASK_TYPE_TO_FLORENCE_TASK[task_type]
+ inference_images = [
+ i.to_inference_format(numpy_preferred=False) for i in images
+ ]
+ prompts = [prompt] * len(images)
+ if classes is not None:
+ prompts = ["".join(classes)] * len(images)
+ else:
+ classes = []
+ if grounding_detection is not None:
+ prompts = prepare_detection_grounding_prompts(
+ images=images,
+ grounding_detection=grounding_detection,
+ grounding_selection_mode=grounding_selection_mode,
+ )
+ self._model_manager.add_model(
+ model_id=model_version,
+ api_key=self._api_key,
+ )
+ predictions = []
+ for image, single_prompt in zip(inference_images, prompts):
+ if single_prompt is None and requires_detection_grounding:
+ # no grounding bbox found - empty result returned
+ predictions.append(
+ {"raw_output": None, "parsed_output": None, "classes": None}
+ )
+ continue
+ request = LMMInferenceRequest(
+ api_key=self._api_key,
+ model_id=model_version,
+ image=image,
+ source="workflow-execution",
+ prompt=task_type + (single_prompt or ""),
+ )
+ prediction = self._model_manager.infer_from_request_sync(
+ model_id=model_version, request=request
+ )
+ prediction_data = prediction.response[task_type]
+ if task_type in TASKS_TO_EXTRACT_LABELS_AS_CLASSES:
+ classes = prediction_data.get("labels", [])
+ predictions.append(
+ {
+ "raw_output": json.dumps(prediction_data),
+ "parsed_output": (
+ prediction_data if isinstance(prediction_data, dict) else None
+ ),
+ "classes": classes,
+ }
+ )
+ return predictions
+
+
+def prepare_detection_grounding_prompts(
+ images: Batch[WorkflowImageData],
+ grounding_detection: Union[Batch[sv.Detections], List[float], List[int]],
+ grounding_selection_mode: GroundingSelectionMode,
+) -> List[Optional[str]]:
+ if isinstance(grounding_detection, list):
+ return _prepare_grounding_bounding_box_from_coordinates(
+ images=images,
+ bounding_box=grounding_detection,
+ )
+ return [
+ _prepare_grounding_bounding_box_from_detections(
+ image=image.numpy_image,
+ detections=detections,
+ grounding_selection_mode=grounding_selection_mode,
+ )
+ for image, detections in zip(images, grounding_detection)
+ ]
+
+
+def _prepare_grounding_bounding_box_from_coordinates(
+ images: Batch[WorkflowImageData], bounding_box: Union[List[float], List[int]]
+) -> List[str]:
+ return [
+ _extract_bbox_coordinates_as_location_prompt(
+ image=image.numpy_image, bounding_box=bounding_box
+ )
+ for image in images
+ ]
+
+
+def _prepare_grounding_bounding_box_from_detections(
+ image: np.ndarray,
+ detections: sv.Detections,
+ grounding_selection_mode: GroundingSelectionMode,
+) -> Optional[str]:
+ if len(detections) == 0:
+ return None
+ height, width = image.shape[:2]
+ if grounding_selection_mode not in COORDINATES_EXTRACTION:
+ raise ValueError(
+ f"Unknown grounding selection mode: {grounding_selection_mode}"
+ )
+ extraction_function = COORDINATES_EXTRACTION[grounding_selection_mode]
+ left_top_x, left_top_y, right_bottom_x, right_bottom_y = extraction_function(
+ detections
+ )
+ left_top_x = _coordinate_to_loc(value=left_top_x / width)
+ left_top_y = _coordinate_to_loc(value=left_top_y / height)
+ right_bottom_x = _coordinate_to_loc(value=right_bottom_x / width)
+ right_bottom_y = _coordinate_to_loc(value=right_bottom_y / height)
+ return f"<loc_{left_top_x}><loc_{left_top_y}><loc_{right_bottom_x}><loc_{right_bottom_y}>"
+
+
+COORDINATES_EXTRACTION = {
+ "first": lambda detections: detections.xyxy[0].tolist(),
+ "last": lambda detections: detections.xyxy[0].tolist(),
+ "biggest": lambda detections: detections.xyxy[np.argmax(detections.area)].tolist(),
+ "smallest": lambda detections: detections.xyxy[np.argmin(detections.area)].tolist(),
+ "most-confident": lambda detections: detections.xyxy[
+ np.argmax(detections.confidence)
+ ].tolist(),
+ "least-confident": lambda detections: detections.xyxy[
+ np.argmin(detections.confidence)
+ ].tolist(),
+}
+
+
+def _extract_bbox_coordinates_as_location_prompt(
+ image: np.ndarray,
+ bounding_box: Union[List[float], List[int]],
+) -> str:
+ height, width = image.shape[:2]
+ coordinates = bounding_box[:4]
+ if len(coordinates) != 4:
+ raise ValueError(
+ "Could not extract 4 coordinates of bounding box to perform detection "
+ "grounded Florence 2 prediction."
+ )
+ left_top_x, left_top_y, right_bottom_x, right_bottom_y = coordinates
+ if all(isinstance(c, float) for c in coordinates):
+ left_top_x = _coordinate_to_loc(value=left_top_x)
+ left_top_y = _coordinate_to_loc(value=left_top_y)
+ right_bottom_x = _coordinate_to_loc(value=right_bottom_x)
+ right_bottom_y = _coordinate_to_loc(value=right_bottom_y)
+ return f"<loc_{left_top_x}><loc_{left_top_y}><loc_{right_bottom_x}><loc_{right_bottom_y}>"
+ if all(isinstance(c, int) for c in coordinates):
+ left_top_x = _coordinate_to_loc(value=left_top_x / width)
+ left_top_y = _coordinate_to_loc(value=left_top_y / height)
+ right_bottom_x = _coordinate_to_loc(value=right_bottom_x / width)
+ right_bottom_y = _coordinate_to_loc(value=right_bottom_y / height)
+ return f"<loc_{left_top_x}><loc_{left_top_y}><loc_{right_bottom_x}><loc_{right_bottom_y}>"
+ raise ValueError(
+ "Provided coordinates in mixed format - coordinates must be all integers or all floats in range [0.0-1.0]"
+ )
+
+
+def _coordinate_to_loc(value: float) -> int:
+ loc_bin = round(_scale_value(value=value, min_value=0.0, max_value=1.0) * LOC_BINS)
+ return _scale_value( # to make sure 0-999 cutting out 1000 on 1.0
+ value=loc_bin,
+ min_value=0,
+ max_value=LOC_BINS - 1,
+ )
+
+
+def _scale_value(
+ value: Union[int, float],
+ min_value: Union[int, float],
+ max_value: Union[int, float],
+) -> Union[int, float]:
+ return max(min(value, max_value), min_value)
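For readers unfamiliar with Florence-2 region prompts, the grounding helpers above encode a bounding box as four location tokens: each coordinate is normalized to `[0.0, 1.0]`, scaled into 1000 bins, and clamped to `0-999`. A minimal sketch of that encoding, assuming the region prompt is four `<loc_{bin}>` tokens in `x1, y1, x2, y2` order (helper names are illustrative, not part of the diff):

```python
# Hedged sketch of the loc-token encoding used for detection-grounded tasks.
LOC_BINS = 1000


def to_loc_bin(value: float) -> int:
    # clamp to [0.0, 1.0], scale into bins, and cap at 999 so 1.0 never yields 1000
    value = max(min(value, 1.0), 0.0)
    return min(round(value * LOC_BINS), LOC_BINS - 1)


def bbox_to_region_prompt(xyxy, width: int, height: int) -> str:
    x1, y1, x2, y2 = xyxy
    bins = [
        to_loc_bin(x1 / width),
        to_loc_bin(y1 / height),
        to_loc_bin(x2 / width),
        to_loc_bin(y2 / height),
    ]
    return "".join(f"<loc_{b}>" for b in bins)


# e.g. a box in a 640x480 image
print(bbox_to_region_prompt([64, 48, 320, 240], width=640, height=480))
# -> <loc_100><loc_100><loc_500><loc_500>
```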
diff --git a/inference/core/workflows/core_steps/models/foundation/google_gemini/v1.py b/inference/core/workflows/core_steps/models/foundation/google_gemini/v1.py
index 84fb77bc5..8fc8c3271 100644
--- a/inference/core/workflows/core_steps/models/foundation/google_gemini/v1.py
+++ b/inference/core/workflows/core_steps/models/foundation/google_gemini/v1.py
@@ -135,7 +135,10 @@ class BlockManifest(WorkflowBlockManifest):
examples=[{"my_key": "description"}, "$inputs.output_structure"],
json_schema_extra={
"relevant_for": {
- "task_type": {"values": TASKS_REQUIRING_CLASSES, "required": True},
+ "task_type": {
+ "values": TASKS_REQUIRING_OUTPUT_STRUCTURE,
+ "required": True,
+ },
},
},
)
@@ -148,7 +151,7 @@ class BlockManifest(WorkflowBlockManifest):
json_schema_extra={
"relevant_for": {
"task_type": {
- "values": TASKS_REQUIRING_OUTPUT_STRUCTURE,
+ "values": TASKS_REQUIRING_CLASSES,
"required": True,
},
},
diff --git a/inference/models/florence2/florence2.py b/inference/models/florence2/florence2.py
index 0d35bc263..f53d380a5 100644
--- a/inference/models/florence2/florence2.py
+++ b/inference/models/florence2/florence2.py
@@ -2,15 +2,29 @@
from typing import Any, Dict
import torch
+from PIL.Image import Image
from transformers import AutoModelForCausalLM
+from inference.core.entities.responses.inference import LMMInferenceResponse
+from inference.core.models.types import PreprocessReturnMetadata
from inference.models.florence2.utils import import_class_from_file
from inference.models.transformers import LoRATransformerModel, TransformerModel
-class Florence2(TransformerModel):
+class Florence2Processing:
+ def predict(self, image_in: Image, prompt="", history=None, **kwargs):
+ (decoded,) = super().predict(image_in, prompt, history, **kwargs)
+ parsed_answer = self.processor.post_process_generation(
+ decoded, task=prompt.split(">")[0] + ">", image_size=image_in.size
+ )
+
+ return (parsed_answer,)
+
+
+class Florence2(Florence2Processing, TransformerModel):
transformers_class = AutoModelForCausalLM
default_dtype = torch.float32
+ skip_special_tokens = False
def initialize_model(self):
self.transformers_class = import_class_from_file(
@@ -33,7 +47,7 @@ def prepare_generation_params(
}
-class LoRAFlorence2(LoRATransformerModel):
+class LoRAFlorence2(Florence2Processing, LoRATransformerModel):
load_base_from_roboflow = True
transformers_class = AutoModelForCausalLM
default_dtype = torch.float32
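The `Florence2Processing.predict` override above runs the decoded generation through the processor's `post_process_generation`, which is part of the Hugging Face Florence-2 processor shipped with the checkpoint via `trust_remote_code`. A rough standalone sketch of the same flow; loading `microsoft/Florence-2-base` directly and the image path are assumptions made for illustration, since inside `inference` the weights come from the model manager:

```python
# Hedged sketch of the decode -> post_process_generation flow wrapped by
# Florence2Processing.predict. Checkpoint id and image path are illustrative.
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "microsoft/Florence-2-base"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

image = Image.open("dog.jpg")
task = "<OD>"
inputs = processor(text=task, images=image, return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
# keep special tokens: the <loc_...> tokens are needed by the post-processor,
# which is why the diff flips skip_special_tokens to False for Florence-2
decoded = processor.batch_decode(generated, skip_special_tokens=False)[0]
parsed = processor.post_process_generation(decoded, task=task, image_size=image.size)
# parsed looks like {"<OD>": {"bboxes": [[x1, y1, x2, y2], ...], "labels": [...]}}
```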
diff --git a/inference/models/transformers/transformers.py b/inference/models/transformers/transformers.py
index fa2c8616e..ca5614448 100644
--- a/inference/models/transformers/transformers.py
+++ b/inference/models/transformers/transformers.py
@@ -52,6 +52,7 @@ class TransformerModel(RoboflowInferenceModel):
default_dtype = torch.float16
generation_includes_input = False
needs_hf_token = False
+ skip_special_tokens = True
def __init__(
self, model_id, *args, dtype=None, huggingface_token=HUGGINGFACE_TOKEN, **kwargs
@@ -119,12 +120,17 @@ def predict(self, image_in: Image.Image, prompt="", history=None, **kwargs):
preprocessed_inputs=model_inputs
)
generation = self.model.generate(
- **prepared_inputs, max_new_tokens=100, do_sample=False
+ **prepared_inputs,
+ max_new_tokens=1000,
+ do_sample=False,
+ early_stopping=False,
)
generation = generation[0]
if self.generation_includes_input:
generation = generation[input_len:]
- decoded = self.processor.decode(generation, skip_special_tokens=True)
+ decoded = self.processor.decode(
+ generation, skip_special_tokens=self.skip_special_tokens
+ )
return (decoded,)
diff --git a/tests/inference/models_predictions_tests/test_florence2.py b/tests/inference/models_predictions_tests/test_florence2.py
index 574410145..5df7ef7a3 100644
--- a/tests/inference/models_predictions_tests/test_florence2.py
+++ b/tests/inference/models_predictions_tests/test_florence2.py
@@ -10,4 +10,4 @@ def test_florence2_caption(
) -> None:
model = Florence2("florence-pretrains/1")
response = model.infer(example_image, prompt="<CAPTION>")[0].response
- assert response == "a close up of a dog looking over a fence"
+ assert response == {"<CAPTION>": "a close up of a dog looking over a fence"}
diff --git a/tests/workflows/integration_tests/execution/conftest.py b/tests/workflows/integration_tests/execution/conftest.py
index 71f19749c..4c92193a2 100644
--- a/tests/workflows/integration_tests/execution/conftest.py
+++ b/tests/workflows/integration_tests/execution/conftest.py
@@ -59,3 +59,9 @@ def left_scissors_right_scissors() -> np.ndarray:
return cv2.imread(
os.path.join(ROCK_PAPER_SCISSORS_ASSETS, "left_scissors_right_scissors.jpg")
)
+
+
+def bool_env(val):
+ if isinstance(val, bool):
+ return val
+ return val.lower() in ["true", "1", "t", "y", "yes"]
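For context, `bool_env` is how the new Florence-2 tests below decide whether to run: they default to skipping unless `SKIP_FLORENCE2_TEST=FALSE` is exported, which is exactly what the CI workflow change at the top of this diff does. A small usage sketch:

```python
# Usage sketch: bool_env passes booleans through and treats the listed truthy
# strings case-insensitively, so the default of True keeps tests skipped unless
# SKIP_FLORENCE2_TEST is explicitly set to a falsy value such as "FALSE".
import os

from tests.workflows.integration_tests.execution.conftest import bool_env

assert bool_env(True) is True
assert bool_env("FALSE") is False
assert bool_env("yes") is True
skip_florence2 = bool_env(os.getenv("SKIP_FLORENCE2_TEST", True))
```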
diff --git a/tests/workflows/integration_tests/execution/test_workflow_with_bounding_rectangle.py b/tests/workflows/integration_tests/execution/test_workflow_with_bounding_rectangle.py
index 71ed9072f..2290b7a5f 100644
--- a/tests/workflows/integration_tests/execution/test_workflow_with_bounding_rectangle.py
+++ b/tests/workflows/integration_tests/execution/test_workflow_with_bounding_rectangle.py
@@ -27,10 +27,14 @@
"type": "roboflow_core/bounding_rect@v1",
"name": "bounding_rect",
"predictions": "$steps.detection.predictions",
- }
+ },
],
"outputs": [
- {"type": "JsonField", "name": "result", "selector": "$steps.bounding_rect.detections_with_rect"}
+ {
+ "type": "JsonField",
+ "name": "result",
+ "selector": "$steps.bounding_rect.detections_with_rect",
+ }
],
}
@@ -70,15 +74,35 @@ def test_rectangle_bounding_workflow(
# then
assert len(result) == 1, "One set of outputs expected"
assert "result" in result[0], "Output must contain key 'result'"
- assert isinstance(result[0]["result"], sv.Detections), "Output must be instance of sv.Detections"
+ assert isinstance(
+ result[0]["result"], sv.Detections
+ ), "Output must be instance of sv.Detections"
assert len(result[0]["result"]) == 2, "Two dogs on the image"
- assert "rect" in result[0]["result"].data, "'rect' data field must expected to be found in result"
- assert "width" in result[0]["result"].data, "'width' data field must expected to be found in result"
- assert "height" in result[0]["result"].data, "'height' data field must expected to be found in result"
- assert "angle" in result[0]["result"].data, "'angle' data field must expected to be found in result"
+ assert (
+ "rect" in result[0]["result"].data
+ ), "'rect' data field expected to be found in result"
+ assert (
+ "width" in result[0]["result"].data
+ ), "'width' data field expected to be found in result"
+ assert (
+ "height" in result[0]["result"].data
+ ), "'height' data field expected to be found in result"
+ assert (
+ "angle" in result[0]["result"].data
+ ), "'angle' data field expected to be found in result"
- assert np.allclose(result[0]["result"]["rect"][0], np.array([[322.0, 402.0], [325.0, 224.0], [586.0, 228.0], [583.0, 406.0]]))
- assert np.allclose(result[0]["result"]["rect"][1], np.array([[219.0, 82.0], [352.0, 57.0], [409.0, 363.0], [276.0, 388.0]]))
- assert np.allclose(result[0]["result"]["width"], np.array([261.5, 311.25]), atol=0.1)
- assert np.allclose(result[0]["result"]["height"], np.array([178.4, 135.2]), atol=0.1)
+ assert np.allclose(
+ result[0]["result"]["rect"][0],
+ np.array([[322.0, 402.0], [325.0, 224.0], [586.0, 228.0], [583.0, 406.0]]),
+ )
+ assert np.allclose(
+ result[0]["result"]["rect"][1],
+ np.array([[219.0, 82.0], [352.0, 57.0], [409.0, 363.0], [276.0, 388.0]]),
+ )
+ assert np.allclose(
+ result[0]["result"]["width"], np.array([261.5, 311.25]), atol=0.1
+ )
+ assert np.allclose(
+ result[0]["result"]["height"], np.array([178.4, 135.2]), atol=0.1
+ )
assert np.allclose(result[0]["result"]["angle"], np.array([0.826, 79.5]), atol=0.1)
diff --git a/tests/workflows/integration_tests/execution/test_workflow_with_florence2.py b/tests/workflows/integration_tests/execution/test_workflow_with_florence2.py
new file mode 100644
index 000000000..35e3c9d93
--- /dev/null
+++ b/tests/workflows/integration_tests/execution/test_workflow_with_florence2.py
@@ -0,0 +1,585 @@
+import copy
+import json
+import os
+
+import numpy as np
+import pytest
+
+from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS
+from inference.core.managers.base import ModelManager
+from inference.core.workflows.core_steps.common.entities import StepExecutionMode
+from inference.core.workflows.execution_engine.core import ExecutionEngine
+from inference.core.workflows.execution_engine.entities.base import WorkflowImageData
+from tests.workflows.integration_tests.execution.conftest import bool_env
+from tests.workflows.integration_tests.execution.workflows_gallery_collector.decorators import (
+ add_to_workflows_gallery,
+)
+
+FLORENCE2_GROUNDED_CLASSIFICATION_WORKFLOW_DEFINITION = {
+ "version": "1.0",
+ "inputs": [
+ {"type": "InferenceImage", "name": "image"},
+ {"type": "WorkflowParameter", "name": "confidence", "default_value": 0.4},
+ ],
+ "steps": [
+ {
+ "type": "roboflow_core/roboflow_object_detection_model@v1",
+ "name": "model_1",
+ "images": "$inputs.image",
+ "model_id": "yolov8n-640",
+ "confidence": "$inputs.confidence",
+ },
+ {
+ "type": "roboflow_core/florence_2@v1",
+ "name": "model",
+ "images": "$inputs.image",
+ "task_type": "detection-grounded-classification",
+ "grounding_detection": "$steps.model_1.predictions",
+ "grounding_selection_mode": "most-confident",
+ },
+ ],
+ "outputs": [
+ {
+ "type": "JsonField",
+ "name": "model_predictions",
+ "coordinates_system": "own",
+ "selector": "$steps.model.*",
+ }
+ ],
+}
+
+
+@add_to_workflows_gallery(
+ category="Workflows with Visual Language Models",
+ use_case_title="Florence 2 - grounded classification",
+ use_case_description="""
+**THIS EXAMPLE CAN ONLY BE RUN LOCALLY OR USING DEDICATED DEPLOYMENT**
+
+In this example, we use an object detection model to find regions of interest in the
+input image, which are later classified by the Florence 2 model. With Workflows it is possible
+to pass `grounding_detection` as an input for all of the tasks named `detection-grounded-*`.
+
+The grounding detection can either be an input parameter or the output of a detection model.
+If the latter is the case, one should choose `grounding_selection_mode` - since Florence 2 only
+supports a single bounding box as grounding, when multiple detections are provided the block
+will select one based on that parameter.
+ """,
+ workflow_definition=FLORENCE2_GROUNDED_CLASSIFICATION_WORKFLOW_DEFINITION,
+ workflow_name_in_app="florence-2-detection-grounded-classification",
+)
+@pytest.mark.skipif(
+ bool_env(os.getenv("SKIP_FLORENCE2_TEST", True)), reason="Skipping Florence 2 test"
+)
+def test_florence2_grounded_classification(
+ model_manager: ModelManager,
+ dogs_image: np.ndarray,
+ roboflow_api_key: str,
+) -> None:
+ # given
+ workflow_init_parameters = {
+ "workflows_core.model_manager": model_manager,
+ "workflows_core.api_key": roboflow_api_key,
+ "workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
+ }
+ execution_engine = ExecutionEngine.init(
+ workflow_definition=FLORENCE2_GROUNDED_CLASSIFICATION_WORKFLOW_DEFINITION,
+ init_parameters=workflow_init_parameters,
+ max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
+ )
+
+ # when
+ result = execution_engine.run(
+ runtime_parameters={
+ "image": dogs_image,
+ }
+ )
+
+ assert isinstance(result, list), "Expected list to be delivered"
+ assert len(result) == 1, "Expected 1 element in the output for one input image"
+ assert set(result[0].keys()) == {
+ "model_predictions",
+ }, "Expected all declared outputs to be delivered"
+
+ assert json.loads(result[0]["model_predictions"]["raw_output"]).startswith(
+ "dog"
+ ), "Expected dog to be output by florence2"
+
+
+@pytest.mark.skipif(
+ bool_env(os.getenv("SKIP_FLORENCE2_TEST", True)), reason="Skipping Florence 2 test"
+)
+def test_florence2_grounded_classification_when_no_grounding_available(
+ model_manager: ModelManager,
+ dogs_image: np.ndarray,
+ roboflow_api_key: str,
+) -> None:
+ # given
+ workflow_init_parameters = {
+ "workflows_core.model_manager": model_manager,
+ "workflows_core.api_key": roboflow_api_key,
+ "workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
+ }
+ execution_engine = ExecutionEngine.init(
+ workflow_definition=FLORENCE2_GROUNDED_CLASSIFICATION_WORKFLOW_DEFINITION,
+ init_parameters=workflow_init_parameters,
+ max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
+ )
+
+ # when
+ result = execution_engine.run(
+ runtime_parameters={
+ "image": dogs_image,
+ "confidence": 0.99,
+ }
+ )
+
+ assert isinstance(result, list), "Expected list to be delivered"
+ assert len(result) == 1, "Expected 1 element in the output for one input image"
+ assert set(result[0].keys()) == {
+ "model_predictions",
+ }, "Expected all declared outputs to be delivered"
+ assert result[0]["model_predictions"]["raw_output"] is None, "Expected no output"
+ assert result[0]["model_predictions"]["parsed_output"] is None, "Expected no output"
+ assert result[0]["model_predictions"]["classes"] is None, "Expected no output"
+
+
+FLORENCE2_GROUNDED_INSTANCE_SEGMENTATION_WORKFLOW_DEFINITION = {
+ "version": "1.0",
+ "inputs": [{"type": "InferenceImage", "name": "image"}],
+ "steps": [
+ {
+ "type": "roboflow_core/roboflow_object_detection_model@v1",
+ "name": "model_1",
+ "images": "$inputs.image",
+ "model_id": "yolov8n-640",
+ },
+ {
+ "type": "roboflow_core/florence_2@v1",
+ "name": "model",
+ "images": "$inputs.image",
+ "task_type": "detection-grounded-instance-segmentation",
+ "grounding_detection": "$steps.model_1.predictions",
+ "grounding_selection_mode": "most-confident",
+ },
+ ],
+ "outputs": [
+ {
+ "type": "JsonField",
+ "name": "model_predictions",
+ "coordinates_system": "own",
+ "selector": "$steps.model.*",
+ }
+ ],
+}
+
+
+@add_to_workflows_gallery(
+ category="Workflows with Visual Language Models",
+ use_case_title="Florence 2 - grounded segmentation",
+ use_case_description="""
+**THIS EXAMPLE CAN ONLY BE RUN LOCALLY OR USING DEDICATED DEPLOYMENT**
+
+In this example, we use an object detection model to find regions of interest in the
+input image and run segmentation of the selected region with Florence 2. With Workflows it is
+possible to pass `grounding_detection` as an input for all of the tasks named
+`detection-grounded-*`.
+
+The grounding detection can either be an input parameter or the output of a detection model.
+If the latter is the case, one should choose `grounding_selection_mode` - since Florence 2 only
+supports a single bounding box as grounding, when multiple detections are provided the block
+will select one based on that parameter.
+ """,
+ workflow_definition=FLORENCE2_GROUNDED_INSTANCE_SEGMENTATION_WORKFLOW_DEFINITION,
+ workflow_name_in_app="florence-2-detection-grounded-segmentation",
+)
+@pytest.mark.skipif(
+ bool_env(os.getenv("SKIP_FLORENCE2_TEST", True)), reason="Skipping Florence 2 test"
+)
+def test_florence2_grounded_instance_segmentation(
+ model_manager: ModelManager,
+ dogs_image: np.ndarray,
+ roboflow_api_key: str,
+) -> None:
+ # given
+ workflow_init_parameters = {
+ "workflows_core.model_manager": model_manager,
+ "workflows_core.api_key": roboflow_api_key,
+ "workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
+ }
+ execution_engine = ExecutionEngine.init(
+ workflow_definition=FLORENCE2_GROUNDED_INSTANCE_SEGMENTATION_WORKFLOW_DEFINITION,
+ init_parameters=workflow_init_parameters,
+ max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
+ )
+
+ # when
+ result = execution_engine.run(
+ runtime_parameters={
+ "image": dogs_image,
+ }
+ )
+
+ assert isinstance(result, list), "Expected list to be delivered"
+ assert len(result) == 1, "Expected 1 element in the output for one input image"
+ assert set(result[0].keys()) == {
+ "model_predictions",
+ }, "Expected all declared outputs to be delivered"
+ assert isinstance(
+ result[0]["model_predictions"]["parsed_output"]["polygons"], list
+ ), "Expected list of polygons in the output"
+
+
+FLORENCE2_GROUNDED_INSTANCE_SEGMENTATION_GROUNDED_BY_INPUT_WORKFLOW_DEFINITION = {
+ "version": "1.0",
+ "inputs": [
+ {"type": "InferenceImage", "name": "image"},
+ {"type": "WorkflowParameter", "name": "bounding_box"},
+ ],
+ "steps": [
+ {
+ "type": "roboflow_core/florence_2@v1",
+ "name": "model",
+ "images": "$inputs.image",
+ "task_type": "detection-grounded-instance-segmentation",
+ "grounding_detection": "$inputs.bounding_box",
+ },
+ ],
+ "outputs": [
+ {
+ "type": "JsonField",
+ "name": "model_predictions",
+ "coordinates_system": "own",
+ "selector": "$steps.model.*",
+ }
+ ],
+}
+
+
+@pytest.mark.skipif(
+ bool_env(os.getenv("SKIP_FLORENCE2_TEST", True)), reason="Skipping Florence 2 test"
+)
+def test_florence2_instance_segmentation_grounded_by_input(
+ model_manager: ModelManager,
+ dogs_image: np.ndarray,
+ roboflow_api_key: str,
+) -> None:
+ # given
+ workflow_init_parameters = {
+ "workflows_core.model_manager": model_manager,
+ "workflows_core.api_key": roboflow_api_key,
+ "workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
+ }
+ execution_engine = ExecutionEngine.init(
+ workflow_definition=FLORENCE2_GROUNDED_INSTANCE_SEGMENTATION_WORKFLOW_DEFINITION,
+ init_parameters=workflow_init_parameters,
+ max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
+ )
+
+ # when
+ result = execution_engine.run(
+ runtime_parameters={
+ "image": dogs_image,
+ "bounding_box": [],
+ }
+ )
+
+ assert isinstance(result, list), "Expected list to be delivered"
+ assert len(result) == 1, "Expected 1 element in the output for one input image"
+ assert set(result[0].keys()) == {
+ "model_predictions",
+ }, "Expected all declared outputs to be delivered"
+ assert isinstance(
+ result[0]["model_predictions"]["parsed_output"]["polygons"], list
+ ), "Expected list of polygons in the output"
+
+
+FLORENCE2_GROUNDED_CAPTION_WORKFLOW_DEFINITION = {
+ "version": "1.0",
+ "inputs": [{"type": "InferenceImage", "name": "image"}],
+ "steps": [
+ {
+ "type": "roboflow_core/roboflow_object_detection_model@v1",
+ "name": "model_1",
+ "images": "$inputs.image",
+ "model_id": "yolov8n-640",
+ },
+ {
+ "type": "roboflow_core/florence_2@v1",
+ "name": "model",
+ "images": "$inputs.image",
+ "task_type": "detection-grounded-caption",
+ "grounding_detection": "$steps.model_1.predictions",
+ "grounding_selection_mode": "most-confident",
+ },
+ ],
+ "outputs": [
+ {
+ "type": "JsonField",
+ "name": "model_predictions",
+ "coordinates_system": "own",
+ "selector": "$steps.model.*",
+ }
+ ],
+}
+
+
+@add_to_workflows_gallery(
+ category="Workflows with Visual Language Models",
+ use_case_title="Florence 2 - grounded captioning",
+ use_case_description="""
+**THIS EXAMPLE CAN ONLY BE RUN LOCALLY OR USING DEDICATED DEPLOYMENT**
+
+In this example, we use an object detection model to find regions of interest in the
+input image and run captioning of the selected region with Florence 2. With Workflows it is
+possible to pass `grounding_detection` as an input for all of the tasks named
+`detection-grounded-*`.
+
+The grounding detection can either be an input parameter or the output of a detection model.
+If the latter is the case, one should choose `grounding_selection_mode` - since Florence 2 only
+supports a single bounding box as grounding, when multiple detections are provided the block
+will select one based on that parameter.
+ """,
+ workflow_definition=FLORENCE2_GROUNDED_CAPTION_WORKFLOW_DEFINITION,
+ workflow_name_in_app="florence-2-detection-grounded-caption",
+)
+@pytest.mark.skipif(
+ bool_env(os.getenv("SKIP_FLORENCE2_TEST", True)), reason="Skipping Florence 2 test"
+)
+def test_florence2_grounded_caption(
+ model_manager: ModelManager,
+ dogs_image: np.ndarray,
+ roboflow_api_key: str,
+) -> None:
+ # given
+ workflow_init_parameters = {
+ "workflows_core.model_manager": model_manager,
+ "workflows_core.api_key": roboflow_api_key,
+ "workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
+ }
+ execution_engine = ExecutionEngine.init(
+ workflow_definition=FLORENCE2_GROUNDED_CAPTION_WORKFLOW_DEFINITION,
+ init_parameters=workflow_init_parameters,
+ max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
+ )
+
+ # when
+ result = execution_engine.run(
+ runtime_parameters={
+ "image": dogs_image,
+ }
+ )
+
+ assert isinstance(result, list), "Expected list to be delivered"
+ assert len(result) == 1, "Expected 1 element in the output for one input image"
+ assert set(result[0].keys()) == {
+ "model_predictions",
+ }, "Expected all declared outputs to be delivered"
+
+ assert json.loads(result[0]["model_predictions"]["raw_output"]).startswith(
+ "dog"
+ ), "Expected dog to be output by florence2"
+
+
+FLORENCE_OBJECT_DETECTION_WORKFLOW = {
+ "version": "1.0",
+ "inputs": [
+ {"type": "InferenceImage", "name": "image"},
+ {"type": "WorkflowParameter", "name": "classes"},
+ ],
+ "steps": [
+ {
+ "type": "roboflow_core/florence_2@v1",
+ "name": "model",
+ "images": "$inputs.image",
+ "task_type": "open-vocabulary-object-detection",
+ "classes": "$inputs.classes",
+ },
+ {
+ "type": "roboflow_core/vlm_as_detector@v1",
+ "name": "vlm_as_detector",
+ "image": "$inputs.image",
+ "vlm_output": "$steps.model.raw_output",
+ "classes": "$steps.model.classes",
+ "model_type": "florence-2",
+ "task_type": "open-vocabulary-object-detection",
+ },
+ {
+ "type": "roboflow_core/bounding_box_visualization@v1",
+ "name": "bounding_box_visualization",
+ "image": "$inputs.image",
+ "predictions": "$steps.vlm_as_detector.predictions",
+ },
+ ],
+ "outputs": [
+ {
+ "type": "JsonField",
+ "name": "predictions",
+ "selector": "$steps.vlm_as_detector.predictions",
+ },
+ {
+ "type": "JsonField",
+ "name": "bounding_box_visualization",
+ "coordinates_system": "own",
+ "selector": "$steps.bounding_box_visualization.image",
+ },
+ ],
+}
+
+
+@add_to_workflows_gallery(
+ category="Workflows with Visual Language Models",
+ use_case_title="Florence 2 - object detection",
+ use_case_description="""
+**THIS EXAMPLE CAN ONLY BE RUN LOCALLY OR USING DEDICATED DEPLOYMENT**
+
+In this example, we use Florence 2 as a zero-shot object detection model, specifically
+performing open-vocabulary detection. The input parameter `classes` can be used to
+provide the list of objects that the model should find. Beware that Florence 2 tends to
+look for every class provided in your list - if you select a class that is not visible
+in the image, you can expect either a large bounding box covering the whole image, or
+multiple bounding boxes over one of the detected instances, with auxiliary boxes
+carrying meaningless labels for the objects you specified in the class list.
+ """,
+    workflow_definition=FLORENCE_OBJECT_DETECTION_WORKFLOW,
+    workflow_name_in_app="florence-2-object-detection",
+)
+@pytest.mark.skipif(
+ bool_env(os.getenv("SKIP_FLORENCE2_TEST", True)), reason="Skipping Florence 2 test"
+)
+def test_florence2_object_detection(
+ model_manager: ModelManager,
+ dogs_image: np.ndarray,
+ roboflow_api_key: str,
+) -> None:
+ # given
+ workflow_init_parameters = {
+ "workflows_core.model_manager": model_manager,
+ "workflows_core.api_key": roboflow_api_key,
+ "workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
+ }
+ execution_engine = ExecutionEngine.init(
+ workflow_definition=FLORENCE_OBJECT_DETECTION_WORKFLOW,
+ init_parameters=workflow_init_parameters,
+ max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
+ )
+
+ # when
+ result = execution_engine.run(
+ runtime_parameters={"image": dogs_image, "classes": ["dog"]}
+ )
+
+ assert isinstance(result, list), "Expected list to be delivered"
+ assert len(result) == 1, "Expected 1 element in the output for one input image"
+ assert set(result[0].keys()) == {
+ "predictions",
+ "bounding_box_visualization",
+ }, "Expected all declared outputs to be delivered"
+ assert len(result[0]["predictions"]) == 2, "Expected two predictions"
+ assert result[0]["predictions"].data["class_name"].tolist() == [
+ "dog",
+ "dog",
+ ], "Expected two dogs to be found"
+
+
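+# A minimal sketch (an assumption, not part of the original tests) of guarding against
+# the caveat described in the use case above: keep only detections whose label is in the
+# requested class list before relying on them downstream. `class_name` is populated by
+# the vlm_as_detector formatter, as asserted in its unit tests.
+def keep_only_requested_classes(detections, requested_classes):
+    keep_mask = np.isin(detections.data["class_name"], requested_classes)
+    return detections[keep_mask]
+
+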
+FLORENCE_OD_TASK_TYPES = [
+ "object-detection",
+ "open-vocabulary-object-detection",
+ "object-detection-and-caption",
+ "phrase-grounded-object-detection",
+ "region-proposal",
+ "ocr-with-text-detection",
+]
+
+
+FLORENCE_VLM_AS_DET_VISUALIZE_DEF = {
+ "version": "1.0",
+ "inputs": [{"type": "InferenceImage", "name": "image"}],
+ "steps": [
+ {
+ "type": "roboflow_core/florence_2@v1",
+ "name": "model",
+ "images": "$inputs.image",
+ "task_type": "object-detection",
+ },
+ {
+ "type": "roboflow_core/vlm_as_detector@v1",
+ "name": "vlm_as_detector",
+ "image": "$inputs.image",
+ "vlm_output": "$steps.model.raw_output",
+ "classes": "$steps.model.classes",
+ "model_type": "florence-2",
+ "task_type": "object-detection",
+ },
+ {
+ "type": "roboflow_core/bounding_box_visualization@v1",
+ "name": "bounding_box_visualization",
+ "image": "$inputs.image",
+ "predictions": "$steps.vlm_as_detector.predictions",
+ },
+ ],
+ "outputs": [
+ {
+ "type": "JsonField",
+ "name": "model_predictions",
+ "coordinates_system": "own",
+ "selector": "$steps.model.*",
+ },
+ {
+ "type": "JsonField",
+ "name": "vlm_as_detector",
+ "coordinates_system": "own",
+ "selector": "$steps.vlm_as_detector.*",
+ },
+ {
+ "type": "JsonField",
+ "name": "bounding_box_visualization",
+ "coordinates_system": "own",
+ "selector": "$steps.bounding_box_visualization.image",
+ },
+ ],
+}
+
+
+def make_visualize_workflow(task_type):
+ wf_def = copy.deepcopy(FLORENCE_VLM_AS_DET_VISUALIZE_DEF)
+ if task_type == "phrase-grounded-object-detection":
+ wf_def["steps"][0]["prompt"] = "dog"
+ elif task_type == "open-vocabulary-object-detection":
+ wf_def["steps"][0]["classes"] = ["dog"]
+ wf_def["steps"][0]["task_type"] = task_type
+ wf_def["steps"][1]["task_type"] = task_type
+ return wf_def
+
+
+@pytest.mark.skipif(
+ bool_env(os.getenv("SKIP_FLORENCE2_TEST", True)), reason="Skipping Florence 2 test"
+)
+@pytest.mark.parametrize("task_type", FLORENCE_OD_TASK_TYPES)
+def test_florence_visualization_with_vlm_as_detector_all_variations(
+ task_type,
+ model_manager: ModelManager,
+ dogs_image: np.ndarray,
+ roboflow_api_key: str,
+) -> None:
+ # given
+ workflow_init_parameters = {
+ "workflows_core.model_manager": model_manager,
+ "workflows_core.api_key": roboflow_api_key,
+ "workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
+ }
+ execution_engine = ExecutionEngine.init(
+ workflow_definition=make_visualize_workflow(task_type),
+ init_parameters=workflow_init_parameters,
+ max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
+ )
+
+ # when
+ result = execution_engine.run(
+ runtime_parameters={
+ "image": dogs_image,
+ }
+ )
+
+ assert isinstance(result, list), "Expected list to be delivered"
+ assert len(result) == 1, "Expected 1 element in the output for one input image"
+ assert isinstance(result[0]["bounding_box_visualization"], WorkflowImageData)
diff --git a/tests/workflows/unit_tests/core_steps/formatters/test_vlm_as_detector.py b/tests/workflows/unit_tests/core_steps/formatters/test_vlm_as_detector.py
index 10f013c26..d5eada03d 100644
--- a/tests/workflows/unit_tests/core_steps/formatters/test_vlm_as_detector.py
+++ b/tests/workflows/unit_tests/core_steps/formatters/test_vlm_as_detector.py
@@ -161,3 +161,244 @@ def test_run_method_for_invalid_json() -> None:
assert result["error_status"] is True
assert result["predictions"] is None
assert len(result["inference_id"]) > 0
+
+
+def test_formatter_for_florence2_object_detection() -> None:
+ # given
+ block = VLMAsDetectorBlockV1()
+ image = WorkflowImageData(
+ numpy_image=np.zeros((192, 168, 3), dtype=np.uint8),
+ parent_metadata=ImageParentMetadata(parent_id="parent"),
+ )
+ vlm_output = """
+{"bboxes": [[434.0, 30.848499298095703, 760.4000244140625, 530.4144897460938], [0.4000000059604645, 96.13949584960938, 528.4000244140625, 564.5574951171875]], "labels": ["cat", "dog"]}
+"""
+
+ # when
+ result = block.run(
+ image=image,
+ vlm_output=vlm_output,
+ classes=["cat", "dog"],
+ model_type="florence-2",
+ task_type="object-detection",
+ )
+
+ # then
+ assert result["error_status"] is False
+ assert isinstance(result["predictions"], sv.Detections)
+ assert len(result["inference_id"]) > 0
+ assert np.allclose(
+ result["predictions"].xyxy,
+ np.array([[434, 30.848, 760.4, 530.41], [0.4, 96.139, 528.4, 564.56]]),
+ atol=1e-1,
+ ), "Expected coordinates to be the same as given in raw input"
+ assert result["predictions"].class_id.tolist() == [7725, 5324]
+ assert np.allclose(result["predictions"].confidence, np.array([1.0, 1.0]))
+ assert result["predictions"].data["class_name"].tolist() == ["cat", "dog"]
+ assert "class_name" in result["predictions"].data
+ assert "image_dimensions" in result["predictions"].data
+ assert "prediction_type" in result["predictions"].data
+ assert "parent_coordinates" in result["predictions"].data
+ assert "parent_dimensions" in result["predictions"].data
+ assert "root_parent_coordinates" in result["predictions"].data
+ assert "root_parent_dimensions" in result["predictions"].data
+ assert "parent_id" in result["predictions"].data
+ assert "root_parent_id" in result["predictions"].data
+
+
+def test_formatter_for_florence2_open_vocabulary_object_detection() -> None:
+ # given
+ block = VLMAsDetectorBlockV1()
+ image = WorkflowImageData(
+ numpy_image=np.zeros((192, 168, 3), dtype=np.uint8),
+ parent_metadata=ImageParentMetadata(parent_id="parent"),
+ )
+ vlm_output = """
+{"bboxes": [[434.0, 30.848499298095703, 760.4000244140625, 530.4144897460938], [0.4000000059604645, 96.13949584960938, 528.4000244140625, 564.5574951171875]], "bboxes_labels": ["cat", "dog"]}
+"""
+
+ # when
+ result = block.run(
+ image=image,
+ vlm_output=vlm_output,
+ classes=["cat", "dog"],
+ model_type="florence-2",
+ task_type="open-vocabulary-object-detection",
+ )
+
+ # then
+ assert result["error_status"] is False
+ assert isinstance(result["predictions"], sv.Detections)
+ assert len(result["inference_id"]) > 0
+ assert np.allclose(
+ result["predictions"].xyxy,
+ np.array([[434, 30.848, 760.4, 530.41], [0.4, 96.139, 528.4, 564.56]]),
+ atol=1e-1,
+ ), "Expected coordinates to be the same as given in raw input"
+ assert result["predictions"].class_id.tolist() == [0, 1]
+ assert np.allclose(result["predictions"].confidence, np.array([1.0, 1.0]))
+ assert result["predictions"].data["class_name"].tolist() == ["cat", "dog"]
+ assert "class_name" in result["predictions"].data
+ assert "image_dimensions" in result["predictions"].data
+ assert "prediction_type" in result["predictions"].data
+ assert "parent_coordinates" in result["predictions"].data
+ assert "parent_dimensions" in result["predictions"].data
+ assert "root_parent_coordinates" in result["predictions"].data
+ assert "root_parent_dimensions" in result["predictions"].data
+ assert "parent_id" in result["predictions"].data
+ assert "root_parent_id" in result["predictions"].data
+
+
+def test_formatter_for_florence2_phrase_grounded_detection() -> None:
+ # given
+ block = VLMAsDetectorBlockV1()
+ image = WorkflowImageData(
+ numpy_image=np.zeros((192, 168, 3), dtype=np.uint8),
+ parent_metadata=ImageParentMetadata(parent_id="parent"),
+ )
+ vlm_output = """
+{"bboxes": [[434.0, 30.848499298095703, 760.4000244140625, 530.4144897460938], [0.4000000059604645, 96.13949584960938, 528.4000244140625, 564.5574951171875]], "labels": ["cat", "dog"]}
+"""
+
+ # when
+ result = block.run(
+ image=image,
+ vlm_output=vlm_output,
+ classes=["cat", "dog"],
+ model_type="florence-2",
+ task_type="phrase-grounded-object-detection",
+ )
+
+ # then
+ assert result["error_status"] is False
+ assert isinstance(result["predictions"], sv.Detections)
+ assert len(result["inference_id"]) > 0
+ assert np.allclose(
+ result["predictions"].xyxy,
+ np.array([[434, 30.848, 760.4, 530.41], [0.4, 96.139, 528.4, 564.56]]),
+ atol=1e-1,
+ ), "Expected coordinates to be the same as given in raw input"
+ assert result["predictions"].class_id.tolist() == [7725, 5324]
+ assert np.allclose(result["predictions"].confidence, np.array([1.0, 1.0]))
+ assert result["predictions"].data["class_name"].tolist() == ["cat", "dog"]
+ assert "class_name" in result["predictions"].data
+ assert "image_dimensions" in result["predictions"].data
+ assert "prediction_type" in result["predictions"].data
+ assert "parent_coordinates" in result["predictions"].data
+ assert "parent_dimensions" in result["predictions"].data
+ assert "root_parent_coordinates" in result["predictions"].data
+ assert "root_parent_dimensions" in result["predictions"].data
+ assert "parent_id" in result["predictions"].data
+ assert "root_parent_id" in result["predictions"].data
+
+
+def test_formatter_for_florence2_region_proposal() -> None:
+ # given
+ block = VLMAsDetectorBlockV1()
+ image = WorkflowImageData(
+ numpy_image=np.zeros((192, 168, 3), dtype=np.uint8),
+ parent_metadata=ImageParentMetadata(parent_id="parent"),
+ )
+ vlm_output = """
+{"bboxes": [[434.0, 30.848499298095703, 760.4000244140625, 530.4144897460938], [0.4000000059604645, 96.13949584960938, 528.4000244140625, 564.5574951171875]], "labels": ["", ""]}
+"""
+
+ # when
+ result = block.run(
+ image=image,
+ vlm_output=vlm_output,
+ classes=[],
+ model_type="florence-2",
+ task_type="region-proposal",
+ )
+
+ # then
+ assert result["error_status"] is False
+ assert isinstance(result["predictions"], sv.Detections)
+ assert len(result["inference_id"]) > 0
+ assert np.allclose(
+ result["predictions"].xyxy,
+ np.array([[434, 30.848, 760.4, 530.41], [0.4, 96.139, 528.4, 564.56]]),
+ atol=1e-1,
+ ), "Expected coordinates to be the same as given in raw input"
+ assert result["predictions"].class_id.tolist() == [0, 0]
+ assert np.allclose(result["predictions"].confidence, np.array([1.0, 1.0]))
+ assert result["predictions"].data["class_name"].tolist() == ["roi", "roi"]
+ assert "class_name" in result["predictions"].data
+ assert "image_dimensions" in result["predictions"].data
+ assert "prediction_type" in result["predictions"].data
+ assert "parent_coordinates" in result["predictions"].data
+ assert "parent_dimensions" in result["predictions"].data
+ assert "root_parent_coordinates" in result["predictions"].data
+ assert "root_parent_dimensions" in result["predictions"].data
+ assert "parent_id" in result["predictions"].data
+ assert "root_parent_id" in result["predictions"].data
+
+
+def test_formatter_for_florence2_ocr() -> None:
+ # given
+ block = VLMAsDetectorBlockV1()
+ image = WorkflowImageData(
+ numpy_image=np.zeros((192, 168, 3), dtype=np.uint8),
+ parent_metadata=ImageParentMetadata(parent_id="parent"),
+ )
+ vlm_output = """
+{"quad_boxes": [[336.9599914550781, 77.22000122070312, 770.8800048828125, 77.22000122070312, 770.8800048828125, 144.1800079345703, 336.9599914550781, 144.1800079345703], [1273.919921875, 77.22000122070312, 1473.5999755859375, 77.22000122070312, 1473.5999755859375, 109.62000274658203, 1273.919921875, 109.62000274658203], [1652.159912109375, 72.9000015258789, 1828.7999267578125, 70.74000549316406, 1828.7999267578125, 129.05999755859375, 1652.159912109375, 131.22000122070312], [1273.919921875, 126.9000015258789, 1467.8399658203125, 126.9000015258789, 1467.8399658203125, 160.3800048828125, 1273.919921875, 160.3800048828125], [340.79998779296875, 173.3400115966797, 964.7999877929688, 173.3400115966797, 964.7999877929688, 250.02000427246094, 340.79998779296875, 251.10000610351562], [1273.919921875, 177.66000366210938, 1473.5999755859375, 177.66000366210938, 1473.5999755859375, 208.98001098632812, 1273.919921875, 208.98001098632812], [1272.0, 226.260009765625, 1467.8399658203125, 226.260009765625, 1467.8399658203125, 259.7400207519531, 1272.0, 259.7400207519531], [340.79998779296875, 264.05999755859375, 801.5999755859375, 264.05999755859375, 801.5999755859375, 345.0600280761719, 340.79998779296875, 345.0600280761719], [1273.919921875, 277.02001953125, 1471.679931640625, 277.02001953125, 1471.679931640625, 309.4200134277344, 1273.919921875, 309.4200134277344], [1273.919921875, 326.70001220703125, 1467.8399658203125, 326.70001220703125, 1467.8399658203125, 359.1000061035156, 1273.919921875, 359.1000061035156], [336.9599914550781, 376.3800048828125, 980.1599731445312, 376.3800048828125, 980.1599731445312, 417.4200134277344, 336.9599914550781, 417.4200134277344]], "labels": ["What is OCR", "01010110", "veryfi", "010100101", "(Optical Character", "01010010", "011100101", "Recognition?", "0101010", "01010001", "A Friendly Introduction to OCR Software"]}
+"""
+
+ # when
+ result = block.run(
+ image=image,
+ vlm_output=vlm_output,
+ classes=[],
+ model_type="florence-2",
+ task_type="ocr-with-text-detection",
+ )
+
+ # then
+ assert result["error_status"] is False
+ assert isinstance(result["predictions"], sv.Detections)
+ assert len(result["inference_id"]) > 0
+ assert np.allclose(
+ result["predictions"].xyxy,
+ np.array(
+ [
+ [336.96, 77.22, 770.88, 144.18],
+ [1273.9, 77.22, 1473.6, 109.62],
+ [1652.2, 70.74, 1828.8, 131.22],
+ [1273.9, 126.9, 1467.8, 160.38],
+ [340.8, 173.34, 964.8, 251.1],
+ [1273.9, 177.66, 1473.6, 208.98],
+ [1272, 226.26, 1467.8, 259.74],
+ [340.8, 264.06, 801.6, 345.06],
+ [1273.9, 277.02, 1471.7, 309.42],
+ [1273.9, 326.7, 1467.8, 359.1],
+ [336.96, 376.38, 980.16, 417.42],
+ ]
+ ),
+ atol=1e-1,
+ ), "Expected coordinates to be the same as given in raw input"
+ assert result["predictions"].class_id.tolist() == [0] * 11
+ assert np.allclose(result["predictions"].confidence, np.array([1.0] * 11))
+ assert result["predictions"].data["class_name"].tolist() == [
+ "What is OCR",
+ "01010110",
+ "veryfi",
+ "010100101",
+ "(Optical Character",
+ "01010010",
+ "011100101",
+ "Recognition?",
+ "0101010",
+ "01010001",
+ "A Friendly Introduction to OCR Software",
+ ]
+ assert "class_name" in result["predictions"].data
+ assert "image_dimensions" in result["predictions"].data
+ assert "prediction_type" in result["predictions"].data
+ assert "parent_coordinates" in result["predictions"].data
+ assert "parent_dimensions" in result["predictions"].data
+ assert "root_parent_coordinates" in result["predictions"].data
+ assert "root_parent_dimensions" in result["predictions"].data
+ assert "parent_id" in result["predictions"].data
+ assert "root_parent_id" in result["predictions"].data
diff --git a/tests/workflows/unit_tests/core_steps/models/foundation/test_florence2.py b/tests/workflows/unit_tests/core_steps/models/foundation/test_florence2.py
new file mode 100644
index 000000000..354133809
--- /dev/null
+++ b/tests/workflows/unit_tests/core_steps/models/foundation/test_florence2.py
@@ -0,0 +1,273 @@
+from typing import List, Union
+
+import numpy as np
+import pytest
+import supervision as sv
+from pydantic import ValidationError
+
+from inference.core.workflows.core_steps.models.foundation.florence2.v1 import (
+ BlockManifest,
+ prepare_detection_grounding_prompts,
+)
+from inference.core.workflows.execution_engine.entities.base import (
+ Batch,
+ ImageParentMetadata,
+ WorkflowImageData,
+)
+
+
+@pytest.mark.parametrize(
+ "task",
+ ["phrase-grounded-object-detection", "phrase-grounded-instance-segmentation"],
+)
+@pytest.mark.parametrize("image_field", ["image", "images"])
+def test_florence2_manifest_for_prompt_requiring_tasks(
+ task: str,
+ image_field: str,
+) -> None:
+ # given
+ manifest = {
+ "type": "roboflow_core/florence_2@v1",
+ "name": "model",
+ image_field: "$inputs.image",
+ "model_version": "florence-2-base",
+ "task_type": task,
+ "prompt": "my_prompt",
+ }
+
+ # when
+ result = BlockManifest.model_validate(manifest)
+
+ # then
+ assert result == BlockManifest(
+ type="roboflow_core/florence_2@v1",
+ name="model",
+ images="$inputs.image",
+ model_version="florence-2-base",
+ task_type=task,
+ prompt="my_prompt",
+ )
+
+
+@pytest.mark.parametrize("task", ["open-vocabulary-object-detection"])
+@pytest.mark.parametrize("image_field", ["image", "images"])
+def test_florence2_manifest_for_classes_requiring_tasks(
+ task: str,
+ image_field: str,
+) -> None:
+ # given
+ manifest = {
+ "type": "roboflow_core/florence_2@v1",
+ "name": "model",
+ image_field: "$inputs.image",
+ "model_version": "florence-2-base",
+ "task_type": task,
+ "classes": ["a", "b"],
+ }
+
+ # when
+ result = BlockManifest.model_validate(manifest)
+
+ # then
+ assert result == BlockManifest(
+ type="roboflow_core/florence_2@v1",
+ name="model",
+ images="$inputs.image",
+ model_version="florence-2-base",
+ task_type=task,
+ classes=["a", "b"],
+ )
+
+
+@pytest.mark.parametrize(
+ "task",
+ [
+ "detection-grounded-instance-segmentation",
+ "detection-grounded-classification",
+ "detection-grounded-caption",
+ "detection-grounded-ocr",
+ ],
+)
+@pytest.mark.parametrize("image_field", ["image", "images"])
+@pytest.mark.parametrize(
+ "grounding_detection",
+ ["$inputs.bbox", "$steps.model.predictions", [0, 1, 2, 3], [0.0, 1.0, 0.0, 1.0]],
+)
+def test_florence2_manifest_for_classes_requiring_detection_grounding(
+ task: str,
+ image_field: str,
+ grounding_detection: Union[List[int], List[float], str],
+) -> None:
+ # given
+ manifest = {
+ "type": "roboflow_core/florence_2@v1",
+ "name": "model",
+ image_field: "$inputs.image",
+ "model_version": "florence-2-base",
+ "task_type": task,
+ "grounding_detection": grounding_detection,
+ }
+
+ # when
+ result = BlockManifest.model_validate(manifest)
+
+ # then
+ assert result == BlockManifest(
+ type="roboflow_core/florence_2@v1",
+ name="model",
+ images="$inputs.image",
+ model_version="florence-2-base",
+ task_type=task,
+ grounding_detection=grounding_detection,
+ )
+
+
+def test_manifest_parsing_when_classes_not_given_but_should_have() -> None:
+ # given
+ manifest = {
+ "type": "roboflow_core/florence_2@v1",
+ "name": "model",
+ "images": "$inputs.image",
+ "model_version": "florence-2-base",
+ "task_type": "open-vocabulary-object-detection",
+ }
+
+ # when
+ with pytest.raises(ValidationError):
+ _ = BlockManifest.model_validate(manifest)
+
+
+def test_manifest_parsing_when_prompt_not_given_but_should_have() -> None:
+ # given
+ manifest = {
+ "type": "roboflow_core/florence_2@v1",
+ "name": "model",
+ "images": "$inputs.image",
+ "model_version": "florence-2-base",
+ "task_type": "phrase-grounded-object-detection",
+ }
+
+ # when
+ with pytest.raises(ValidationError):
+ _ = BlockManifest.model_validate(manifest)
+
+
+def test_manifest_parsing_when_detection_grounding_not_given_but_should_have() -> None:
+ # given
+ manifest = {
+ "type": "roboflow_core/florence_2@v1",
+ "name": "model",
+ "images": "$inputs.image",
+ "model_version": "florence-2-base",
+ "task_type": "detection-grounded-instance-segmentation",
+ }
+
+ # when
+ with pytest.raises(ValidationError):
+ _ = BlockManifest.model_validate(manifest)
+
+
+def test_prepare_detection_grounding_prompts_when_empty_sv_detections_given() -> None:
+ # given
+ image = WorkflowImageData(
+ numpy_image=np.zeros((100, 200, 3), dtype=np.uint8),
+ parent_metadata=ImageParentMetadata(parent_id="parent"),
+ )
+ detections = sv.Detections.empty()
+
+ # when
+ result = prepare_detection_grounding_prompts(
+ images=Batch(content=[image], indices=[(0,)]),
+ grounding_detection=Batch(content=[detections], indices=[(0,)]),
+ grounding_selection_mode="most-confident",
+ )
+
+ # then
+ assert result == [None]
+
+
+def test_prepare_detection_grounding_prompts_when_non_empty_sv_detections_given() -> (
+ None
+):
+ # given
+ image = WorkflowImageData(
+ numpy_image=np.zeros((100, 200, 3), dtype=np.uint8),
+ parent_metadata=ImageParentMetadata(parent_id="parent"),
+ )
+ detections = sv.Detections(
+ xyxy=np.array([[60, 30, 100, 50], [10, 20, 30, 40]]),
+ confidence=np.array([0.7, 0.6]),
+ )
+
+ # when
+ result = prepare_detection_grounding_prompts(
+ images=Batch(content=[image], indices=[(0,)]),
+ grounding_detection=Batch(content=[detections], indices=[(0,)]),
+ grounding_selection_mode="most-confident",
+ )
+
+ # then
+ assert result == [""]
+
+
+def test_prepare_detection_grounding_prompts_when_batch_of_sv_detections_given() -> (
+ None
+):
+ # given
+ image = WorkflowImageData(
+ numpy_image=np.zeros((100, 200, 3), dtype=np.uint8),
+ parent_metadata=ImageParentMetadata(parent_id="parent"),
+ )
+ detections = sv.Detections(
+ xyxy=np.array([[60, 30, 100, 50], [50, 10, 100, 40]]),
+ confidence=np.array([0.7, 0.6]),
+ )
+
+ # when
+ result = prepare_detection_grounding_prompts(
+ images=Batch(content=[image, image], indices=[(0,), (1,)]),
+ grounding_detection=Batch(
+ content=[sv.Detections.empty(), detections],
+ indices=[(0,), (1,)],
+ ),
+ grounding_selection_mode="least-confident",
+ )
+
+ # then
+ assert result == [None, ""]
+
+
+def test_prepare_detection_grounding_prompts_list_of_int_given() -> None:
+ # given
+ image = WorkflowImageData(
+ numpy_image=np.zeros((100, 200, 3), dtype=np.uint8),
+ parent_metadata=ImageParentMetadata(parent_id="parent"),
+ )
+
+ # when
+ result = prepare_detection_grounding_prompts(
+ images=Batch(content=[image], indices=[(0,)]),
+ grounding_detection=[60, 30, 100, 50],
+ grounding_selection_mode="most-confident",
+ )
+
+ # then
+ assert result == [""]
+
+
+def test_prepare_detection_grounding_prompts_list_of_float_given() -> None:
+ # given
+ image = WorkflowImageData(
+ numpy_image=np.zeros((100, 200, 3), dtype=np.uint8),
+ parent_metadata=ImageParentMetadata(parent_id="parent"),
+ )
+
+ # when
+ result = prepare_detection_grounding_prompts(
+ images=Batch(content=[image], indices=[(0,)]),
+ grounding_detection=[0.3, 0.3, 0.5, 0.5],
+ grounding_selection_mode="most-confident",
+ )
+
+ # then
+ assert result == [""]
diff --git a/tests/workflows/unit_tests/core_steps/transformations/test_bounding_rect.py b/tests/workflows/unit_tests/core_steps/transformations/test_bounding_rect.py
index b16e252f5..da5f88b13 100644
--- a/tests/workflows/unit_tests/core_steps/transformations/test_bounding_rect.py
+++ b/tests/workflows/unit_tests/core_steps/transformations/test_bounding_rect.py
@@ -11,15 +11,7 @@
def test_calculate_minimum_bounding_rectangle():
# given
- polygon = np.array(
- [
- [10, 10],
- [10, 1],
- [20, 1],
- [20, 10],
- [15, 5]
- ]
- )
+ polygon = np.array([[10, 10], [10, 1], [20, 1], [20, 10], [15, 5]])
mask = sv.polygon_to_mask(
polygon=polygon, resolution_wh=(np.max(polygon, axis=0) + 10)
)
@@ -29,12 +21,14 @@ def test_calculate_minimum_bounding_rectangle():
# then
expected_box = np.array([[10, 1], [20, 1], [20, 10], [10, 10]])
- assert np.allclose(box, expected_box), (
- f"Expected bounding box to be {expected_box}, but got {box}"
- )
+ assert np.allclose(
+ box, expected_box
+ ), f"Expected bounding box to be {expected_box}, but got {box}"
assert np.isclose(width, 9), f"Expected width to be 9, but got {width}"
assert np.isclose(height, 10), f"Expected height to be 10, but got {height}"
- assert angle == 90 or angle == -90, f"Expected angle to be 90 or -90, but got {angle}"
+ assert (
+ angle == 90 or angle == -90
+ ), f"Expected angle to be 90 or -90, but got {angle}"
@pytest.mark.parametrize("type_alias", ["roboflow_core/bounding_rect@v1"])
@@ -53,9 +47,7 @@ def test_bounding_box_validation_when_valid_manifest_is_given(
# then
assert result == BoundingRectManifest(
- type=type_alias,
- name="bounding_box",
- predictions="$steps.od_model.predictions"
+ type=type_alias, name="bounding_box", predictions="$steps.od_model.predictions"
)
@@ -68,7 +60,7 @@ def test_bounding_box_block() -> None:
[
sv.polygon_to_mask(
polygon=np.array([[10, 10], [10, 100], [100, 100], [100, 10]]),
- resolution_wh=(1000, 1000)
+ resolution_wh=(1000, 1000),
)
]
),
@@ -83,6 +75,9 @@ def test_bounding_box_block() -> None:
assert output["detections_with_rect"].data["height"][0] == 90
assert output["detections_with_rect"].data["width"][0] == 90
assert output["detections_with_rect"].data["angle"][0] == 90
- np.allclose(np.array([[10, 10], [10, 100], [100, 100], [100, 10]]), output["detections_with_rect"].data["rect"][0])
+    assert np.allclose(
+        np.array([[10, 10], [10, 100], [100, 100], [100, 10]]),
+        output["detections_with_rect"].data["rect"][0],
+    )
# check if the image is modified
assert detections != output["detections_with_rect"]