Florence2 workflows block #661

Open · wants to merge 16 commits into base: main

Changes from 7 commits
4 changes: 4 additions & 0 deletions inference/core/workflows/core_steps/loader.py
@@ -215,6 +215,9 @@
from inference.core.workflows.core_steps.visualizations.triangle.v1 import (
TriangleVisualizationBlockV1,
)
from inference.core.workflows.core_steps.models.foundation.florence2.v1 import (
Florence2BlockV1,
)
from inference.core.workflows.execution_engine.entities.types import (
BAR_CODE_DETECTION_KIND,
BOOLEAN_KIND,
@@ -335,6 +338,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]:
AntropicClaudeBlockV1,
LineCounterBlockV1,
PolygonZoneVisualizationBlockV1,
Florence2BlockV1,
]


222 changes: 222 additions & 0 deletions inference/core/workflows/core_steps/models/foundation/florence2/v1.py
@@ -0,0 +1,222 @@
from typing import List, Literal, Optional, Type, TypeVar, Union

import numpy as np
import supervision as sv
from pydantic import ConfigDict, Field, model_validator

from inference.core.entities.requests.inference import LMMInferenceRequest
from inference.core.entities.responses.inference import LMMInferenceResponse
from inference.core.managers.base import ModelManager
from inference.core.workflows.core_steps.common.entities import StepExecutionMode
from inference.core.workflows.core_steps.common.utils import (
attach_parents_coordinates_to_batch_of_sv_detections,
attach_prediction_type_info_to_sv_detections_batch,
convert_inference_detections_batch_to_sv_detections,
load_core_model,
)
from inference.core.workflows.execution_engine.entities.base import (
Batch,
OutputDefinition,
WorkflowImageData,
)
from inference.core.workflows.execution_engine.entities.types import (
BOOLEAN_KIND,
FLOAT_KIND,
INSTANCE_SEGMENTATION_PREDICTION_KIND,
KEYPOINT_DETECTION_PREDICTION_KIND,
OBJECT_DETECTION_PREDICTION_KIND,
LANGUAGE_MODEL_OUTPUT_KIND,
STRING_KIND,
ImageInputField,
StepOutputImageSelector,
StepOutputSelector,
WorkflowImageSelector,
WorkflowParameterSelector,
)
from inference.core.workflows.prototypes.block import (
BlockResult,
WorkflowBlock,
WorkflowBlockManifest,
)

T = TypeVar("T")
K = TypeVar("K")

DETECTIONS_CLASS_NAME_FIELD = "class_name"
DETECTION_ID_FIELD = "detection_id"

LONG_DESCRIPTION = """
Run Florence-2, a large multimodal model, on an image.
**Dedicated inference server required (GPU recommended)**
"""

TaskType = Literal[
"<OCR>",
Review comment (Collaborator):
This is a change in favour of clients using the UI. It would probably decrease code readability, but I would still choose that tradeoff -- let's change the literal entries into this style: ocr-with-region instead of <OCR_WITH_REGION>, and convert them in prompting via f"<{value.upper().replace('-', '_')}>".

Reply (author):
Sorry, this was still a draft -- I added legible English descriptions and mapped them to tasks... is that OK?

(A sketch of the suggested mapping appears after the task lists below.)

"<OCR_WITH_REGION>",
"<CAPTION>",
"<DETAILED_CAPTION>",
"<MORE_DETAILED_CAPTION>",
"<OD>",
"<DENSE_REGION_CAPTION>",
"<CAPTION_TO_PHRASE_GROUNDING>",
"<REFERRING_EXPRESSION_SEGMENTATION>",
"<REGION_TO_SEGMENTATION>",
"<OPEN_VOCABULARY_DETECTION>",
"<REGION_TO_CATEGORY>",
"<REGION_TO_DESCRIPTION>",
"<REGION_TO_OCR>",
"<REGION_PROPOSAL>",
]

TASKS_REQUIRING_PROMPT = [
"<CAPTION_TO_PHRASE_GROUNDING>",
"<REFERRING_EXPRESSION_SEGMENTATION>",
"<REGION_TO_SEGMENTATION>",
"<OPEN_VOCABULARY_DETECTION>",
"<REGION_TO_CATEGORY>",
"<REGION_TO_DESCRIPTION>",
"<REGION_TO_OCR>",
]
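
A minimal sketch of the mapping the reviewer suggests above; the helper name is illustrative and not part of this PR:

def to_florence_task_token(task_type: str) -> str:
    # Convert a human-readable task name into a Florence-2 task token,
    # e.g. "ocr-with-region" -> "<OCR_WITH_REGION>".
    return f"<{task_type.upper().replace('-', '_')}>"

assert to_florence_task_token("ocr-with-region") == "<OCR_WITH_REGION>"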


class BlockManifest(WorkflowBlockManifest):
model_config = ConfigDict(
json_schema_extra={
"name": "Florence-2 Model",
"version": "v1",
"short_description": "Run Florence-2 on an image",
"long_description": LONG_DESCRIPTION,
"license": "Apache-2.0",
"block_type": "model",
"search_keywords": ["Florence", "Florence-2", "Microsoft"],
},
protected_namespaces=(),
)

model_version: Union[
WorkflowParameterSelector(kind=[STRING_KIND]),
Literal["florence-2-base", "florence-2-large"],
] = Field(
default="florence-2-base",
description="Model to be used",
examples=["florence-2-base"],
)
type: Literal["roboflow_core/florence_2@v1"]
images: Union[WorkflowImageSelector, StepOutputImageSelector] = ImageInputField
task_type: TaskType = Field(
Review comment (Collaborator):
Let's change this into the style of #659, following what @EmilyGavrilenko did.

description="Task type to be performed by model. Value of parameter determine set of fields "
"that are required. For `unconstrained`, `visual-question-answering`, "
" - `prompt` parameter must be provided."
"For `structured-answering` - `output-structure` must be provided. For "
"`classification`, `multi-label-classification` and `object-detection` - "
"`classes` must be filled. `ocr`, `caption`, `detailed-caption` do not"
"require any additional parameter.",
)
prompt: Optional[Union[WorkflowParameterSelector(kind=[STRING_KIND]), str]] = Field(
default=None,
description="Text prompt to the Claude model",
examples=["my prompt", "$inputs.prompt"],
json_schema_extra={
"relevant_for": {
"task_type": {"values": TASKS_REQUIRING_PROMPT, "required": True},
},
},
)

@classmethod
def accepts_batch_input(cls) -> bool:
return True
@model_validator(mode="after")
def validate(self) -> "BlockManifest":
if self.task_type in TASKS_REQUIRING_PROMPT and self.prompt is None:
raise ValueError(
f"`prompt` parameter required to be set for task `{self.task_type}`"
)
return self

@classmethod
def describe_outputs(cls) -> List[OutputDefinition]:
return [
OutputDefinition(
name="output", kind=[STRING_KIND, LANGUAGE_MODEL_OUTPUT_KIND]
)
]

@classmethod
def get_execution_engine_compatibility(cls) -> Optional[str]:
return ">=1.0.0,<2.0.0"


class Florence2BlockV1(WorkflowBlock):

def __init__(
self,
model_manager: ModelManager,
api_key: Optional[str],
step_execution_mode: StepExecutionMode,
):
self._model_manager = model_manager
self._api_key = api_key
self._step_execution_mode = step_execution_mode

@classmethod
def get_init_parameters(cls) -> List[str]:
return ["model_manager", "api_key", "step_execution_mode"]

@classmethod
def get_manifest(cls) -> Type[WorkflowBlockManifest]:
return BlockManifest

def run(
self,
images: Batch[WorkflowImageData],
task_type: TaskType,
prompt: Optional[str],
model_version: str,
) -> BlockResult:
if self._step_execution_mode is StepExecutionMode.LOCAL:
return self.run_locally(
images=images,
task_type=task_type,
model_version=model_version,
prompt=prompt,
)
elif self._step_execution_mode is StepExecutionMode.REMOTE:
raise NotImplementedError(
Review comment (Collaborator):
Just an exploratory question -- am I correct that for the base model we could run it in the core models lambda, given a special endpoint is created?

Reply (author):
Yes, I believe so. We'd also need to enable it with an env var. It might be pretty dang slow depending on the task.

"Remote execution is not supported for florence2. Run a local or dedicated inference server to use this block (GPU recommended)."
)
else:
raise ValueError(
f"Unknown step execution mode: {self._step_execution_mode}"
)

def run_locally(
self,
images: Batch[WorkflowImageData],
task_type: TaskType,
prompt: Optional[str],
model_version: str,
) -> BlockResult:
inference_images = [i.to_inference_format(numpy_preferred=False) for i in images]
self._model_manager.add_model(
model_id=model_version,
api_key=self._api_key,
)
predictions = []
for image in inference_images:
request = LMMInferenceRequest(
api_key=self._api_key,
model_id=model_version,
image=image,
source="workflow-execution",
prompt=task_type + (prompt or ""),
)
prediction = self._model_manager.infer_from_request_sync(
model_id=model_version, request=request
)
predictions.append(prediction)
return [{"output": prediction.response} for prediction in predictions]
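
Note the prompt composition in run_locally: the Florence-2 task token is prepended to the optional free-text prompt before the request is sent. A quick illustration (values are examples):

task_type = "<CAPTION_TO_PHRASE_GROUNDING>"
prompt = "a red car"
assert task_type + (prompt or "") == "<CAPTION_TO_PHRASE_GROUNDING>a red car"

# Tasks outside TASKS_REQUIRING_PROMPT send just the task token:
assert "<OCR>" + (None or "") == "<OCR>"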
17 changes: 17 additions & 0 deletions inference/models/florence2/florence2.py
@@ -2,15 +2,22 @@
from typing import Any, Dict

import torch
from PIL.Image import Image
from transformers import AutoModelForCausalLM

from inference.core.entities.responses.inference import LMMInferenceResponse
from inference.core.models.types import PreprocessReturnMetadata
from inference.models.florence2.utils import import_class_from_file
from inference.models.transformers import LoRATransformerModel, TransformerModel

BOS_TOKEN = "<s>"
EOS_TOKEN = "</s>"


class Florence2(TransformerModel):
transformers_class = AutoModelForCausalLM
default_dtype = torch.float32
skip_special_tokens = False

def initialize_model(self):
self.transformers_class = import_class_from_file(
@@ -32,6 +39,11 @@ def prepare_generation_params(
"pixel_values": preprocessed_inputs["pixel_values"],
}

def predict(self, image_in: Image, prompt="", history=None, **kwargs):
(preds,) = super().predict(image_in, prompt, history, **kwargs)
preds = preds.replace(BOS_TOKEN, "").replace(EOS_TOKEN, "")
return (preds,)


class LoRAFlorence2(LoRATransformerModel):
load_base_from_roboflow = True
@@ -59,3 +71,8 @@ def prepare_generation_params(
"input_ids": preprocessed_inputs["input_ids"],
"pixel_values": preprocessed_inputs["pixel_values"],
}

def predict(self, image_in: Image, prompt="", history=None, **kwargs):
(preds,) = super().predict(image_in, prompt, history, **kwargs)
preds = preds.replace(BOS_TOKEN, "").replace(EOS_TOKEN, "")
return (preds,)
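
Both predict overrides strip only the BOS/EOS markers while keeping other special tokens in the decoded text (Florence2 sets skip_special_tokens = False, presumably because task outputs such as region tokens must survive decoding). A quick illustration with a made-up decoded string:

raw = "</s><s>A green bus parked on the street.</s>"
cleaned = raw.replace("<s>", "").replace("</s>", "")
assert cleaned == "A green bus parked on the street."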
7 changes: 5 additions & 2 deletions inference/models/transformers/transformers.py
@@ -52,6 +52,7 @@ class TransformerModel(RoboflowInferenceModel):
default_dtype = torch.float16
generation_includes_input = False
needs_hf_token = False
skip_special_tokens = True

def __init__(
self, model_id, *args, dtype=None, huggingface_token=HUGGINGFACE_TOKEN, **kwargs
@@ -119,12 +120,14 @@ def predict(self, image_in: Image.Image, prompt="", history=None, **kwargs):
preprocessed_inputs=model_inputs
)
generation = self.model.generate(
-    **prepared_inputs, max_new_tokens=100, do_sample=False
+    **prepared_inputs, max_new_tokens=1000, do_sample=False
)
generation = generation[0]
if self.generation_includes_input:
generation = generation[input_len:]
-decoded = self.processor.decode(generation, skip_special_tokens=True)
+decoded = self.processor.decode(
+    generation, skip_special_tokens=self.skip_special_tokens
+)

return (decoded,)
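
The new skip_special_tokens class attribute preserves the previous default (strip special tokens) while letting subclasses opt out, as Florence2 does above. A sketch of the override pattern; the class name is illustrative:

class RegionAwareModel(TransformerModel):
    # Keep special tokens in the decoded output, e.g. when task results
    # are encoded via special tokens and parsed downstream.
    skip_special_tokens = False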
