From e4ffb1ac4d65bef952be668fd35de9a7d3723cf5 Mon Sep 17 00:00:00 2001
From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com>
Date: Tue, 12 Nov 2024 17:08:46 +0800
Subject: [PATCH 1/9] feat(minicpm-v): Support MiniCPM-V inference pipeline
---
examples/minicpm_v/inference/inference.py | 25 +
mindone/transformers/__init__.py | 2 +
.../transformers/feature_extraction_utils.py | 607 +++++++
.../transformers/image_processing_utils.py | 864 ++++++++++
mindone/transformers/image_transforms.py | 322 ++++
mindone/transformers/image_utils.py | 95 ++
mindone/transformers/models/__init__.py | 2 +-
.../transformers/models/minicpm_v/__init__.py | 2 +
.../models/minicpm_v/configuration_minicpm.py | 102 ++
.../minicpm_v/image_processing_minicpmv.py | 429 +++++
.../models/minicpm_v/modeling_minicpmv.py | 421 +++++
.../models/minicpm_v/modeling_navit_siglip.py | 1072 ++++++++++++
.../models/minicpm_v/processing_minicpmv.py | 254 +++
.../models/minicpm_v/resampler.py | 834 ++++++++++
.../minicpm_v/tokenization_minicpmv_fast.py | 70 +
mindone/transformers/models/qwen2/__init__.py | 52 +
.../models/qwen2/configuration_qwen2.py | 139 ++
.../models/qwen2/modeling_qwen2.py | 1432 +++++++++++++++++
.../models/qwen2/tokenization_qwen2.py | 337 ++++
.../models/qwen2/tokenization_qwen2_fast.py | 134 ++
mindone/transformers/processing_utils.py | 295 ++++
mindone/transformers/utils/generic.py | 58 +
22 files changed, 7547 insertions(+), 1 deletion(-)
create mode 100644 examples/minicpm_v/inference/inference.py
create mode 100644 mindone/transformers/feature_extraction_utils.py
create mode 100644 mindone/transformers/image_processing_utils.py
create mode 100644 mindone/transformers/image_transforms.py
create mode 100644 mindone/transformers/image_utils.py
create mode 100644 mindone/transformers/models/minicpm_v/__init__.py
create mode 100644 mindone/transformers/models/minicpm_v/configuration_minicpm.py
create mode 100644 mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
create mode 100644 mindone/transformers/models/minicpm_v/modeling_minicpmv.py
create mode 100644 mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
create mode 100644 mindone/transformers/models/minicpm_v/processing_minicpmv.py
create mode 100644 mindone/transformers/models/minicpm_v/resampler.py
create mode 100644 mindone/transformers/models/minicpm_v/tokenization_minicpmv_fast.py
create mode 100644 mindone/transformers/models/qwen2/__init__.py
create mode 100644 mindone/transformers/models/qwen2/configuration_qwen2.py
create mode 100644 mindone/transformers/models/qwen2/modeling_qwen2.py
create mode 100644 mindone/transformers/models/qwen2/tokenization_qwen2.py
create mode 100644 mindone/transformers/models/qwen2/tokenization_qwen2_fast.py
create mode 100644 mindone/transformers/processing_utils.py
create mode 100644 mindone/transformers/utils/generic.py
diff --git a/examples/minicpm_v/inference/inference.py b/examples/minicpm_v/inference/inference.py
new file mode 100644
index 0000000000..30e8373800
--- /dev/null
+++ b/examples/minicpm_v/inference/inference.py
@@ -0,0 +1,25 @@
+import mindspore as ms
+
+from PIL import Image
+from transformers import AutoTokenizer
+from mindone.transformers import MiniCPMV_v2_6
+
+model = MiniCPMV_v2_6.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True, attn_implementation='eager', mindspore_dtype=ms.float32)
+
+tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True)
+
+image = Image.open('airplane.jepg').convert('RGB')
+
+# First Round Chat
+question = "Tell me the model of this aircraft"
+msgs = [{"role": 'user', 'content': [image, question]}]
+answer = model.chat(image=image, msgs=msgs, tokenizer=tokenizer)
+print(answer)
+
+# Second round chat
+# pass history context of multi-turn conversation
+msgs.append({"role": "assistant", "content": [answer]})
+msgs.append({"role": "user", "content": ["Introduce something about Airbus A380."]})
+
+answer = model.chat(image=None, msgs=msgs, tokenizer=tokenizer)
+print(answer)
\ No newline at end of file
diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py
index ce2f862876..ed39ee68b3 100644
--- a/mindone/transformers/__init__.py
+++ b/mindone/transformers/__init__.py
@@ -52,3 +52,5 @@
T5PreTrainedModel,
)
from .models.xlm_roberta import XLMRobertaModel, XLMRobertaPreTrainedModel
+
+from .models.minicpm_v import MiniCPMV_v2_6, MiniCPMVImageProcessor
\ No newline at end of file
diff --git a/mindone/transformers/feature_extraction_utils.py b/mindone/transformers/feature_extraction_utils.py
new file mode 100644
index 0000000000..deb703ce75
--- /dev/null
+++ b/mindone/transformers/feature_extraction_utils.py
@@ -0,0 +1,607 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+ Feature extraction saving/loading class for common feature extractors.
+"""
+
+import copy
+import json
+import os
+import warnings
+from collections import UserDict
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+
+FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
+from .utils.generic import TensorType
+from transformers.utils import (
+ cached_file,
+ download_url,
+ is_numpy_array,
+ is_offline_mode,
+ is_remote_url,
+ logging,
+ requires_backends,
+)
+
+import mindspore
+from mindspore import ops
+
+# if is_mindspore_available():
+# import mindspore
+# from mindspore import ops
+
+
+logger = logging.get_logger(__name__)
+
+PreTrainedFeatureExtractor = Union["SequenceFeatureExtractor"] # noqa: F821
+
+
+class BatchFeature(UserDict):
+ r"""
+ Holds the output of the [`~SequenceFeatureExtractor.pad`] and feature extractor specific `__call__` methods.
+
+ This class is derived from a python dictionary and can be used as a dictionary.
+
+ Args:
+ data (`dict`, *optional*):
+ Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values', 'attention_mask',
+ etc.).
+ tensor_type (`Union[None, str, TensorType]`, *optional*):
+ You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
+ initialization.
+ """
+
+ def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
+ super().__init__(data)
+ self.convert_to_tensors(tensor_type=tensor_type)
+
+ def __getitem__(self, item: str) -> Union[Any]:
+ """
+ If the key is a string, returns the value of the dict associated to `key` ('input_values', 'attention_mask',
+ etc.).
+ """
+ if isinstance(item, str):
+ return self.data[item]
+ raise KeyError("Indexing with integers is not available when using Python based feature extractors")
+
+ def __getattr__(self, item: str):
+ try:
+ return self.data[item]
+ except KeyError as exc:
+ raise AttributeError from exc
+
+ def __getstate__(self):
+ return {"data": self.data}
+
+ def __setstate__(self, state):
+ if "data" in state:
+ self.data = state["data"]
+
+ # Copied from transformers.tokenization_utils_base.BatchEncoding.keys
+ def keys(self):
+ return self.data.keys()
+
+ # Copied from transformers.tokenization_utils_base.BatchEncoding.values
+ def values(self):
+ return self.data.values()
+
+ # Copied from transformers.tokenization_utils_base.BatchEncoding.items
+ def items(self):
+ return self.data.items()
+
+ def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = None):
+ if tensor_type is None:
+ return None, None
+
+ # Convert to TensorType
+ if not isinstance(tensor_type, TensorType):
+ tensor_type = TensorType(tensor_type)
+
+ # Get a function reference for the correct framework
+ if tensor_type == TensorType.MINDSPORE:
+ # if not is_mindspore_available():
+ # raise ImportError("Unable to convert output to MindSpore tensors format, MindSpore is not installed.")
+
+ def as_tensor(value):
+ if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
+ value = np.array(value)
+ return mindspore.tensor(value)
+
+ is_tensor = ops.is_tensor
+ else:
+ def as_tensor(value, dtype=None):
+ if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
+ value_lens = [len(val) for val in value]
+ if len(set(value_lens)) > 1 and dtype is None:
+ # we have a ragged list so handle explicitly
+ value = as_tensor([np.asarray(val) for val in value], dtype=object)
+ elif isinstance(value, mindspore.Tensor):
+ return value.asnumpy()
+ return np.asarray(value, dtype=dtype)
+
+ is_tensor = is_numpy_array
+ return is_tensor, as_tensor
+
+ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
+ """
+ Convert the inner content to tensors.
+
+ Args:
+ tensor_type (`str` or [`~utils.TensorType`], *optional*):
+ The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
+ `None`, no modification is done.
+ """
+ if tensor_type is None:
+ return self
+
+ is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
+
+ # Do the tensor conversion in batch
+ for key, value in self.items():
+ try:
+ if not is_tensor(value):
+ tensor = as_tensor(value)
+ self[key] = tensor
+ except Exception as exc: # noqa E722
+ if key == "overflowing_values":
+ raise ValueError("Unable to create tensor returning overflowing values of different lengths. ") from exc
+ raise ValueError(
+ "Unable to create tensor, you should probably activate padding "
+ "with 'padding=True' to have batched tensors with the same length."
+ ) from exc
+ return self
+
+ def to(self, *args, **kwargs) -> "BatchFeature":
+ """
+ Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
+ different `dtypes` and sending the `BatchFeature` to a different `device`.
+
+ Args:
+ args (`Tuple`):
+ Will be passed to the `to(...)` function of the tensors.
+ kwargs (`Dict`, *optional*):
+ Will be passed to the `to(...)` function of the tensors.
+
+ Returns:
+ [`BatchFeature`]: The same instance after modification.
+ """
+ requires_backends(self, ["mindspore"])
+
+ new_data = {}
+ # Check if the args are a device or a dtype
+ if len(args) > 0:
+ # device should be always the first argument
+ arg = args[0]
+ if isinstance(arg, mindspore._c_expression.typing.Type):
+ # The first argument is a dtype
+ pass
+ else:
+ # it's something else
+ raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
+ # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+ for k, v in self.items():
+ # check if v is a floating point
+ if ops.is_floating_point(v):
+ # cast and send to device
+ new_data[k] = v.to(*args, **kwargs)
+ else:
+ new_data[k] = v
+ self.data = new_data
+ return self
+
+
+class FeatureExtractionMixin():
+ """
+ This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature
+ extractors.
+ """
+
+ _auto_class = None
+
+ def __init__(self, **kwargs):
+ """Set elements of `kwargs` as attributes."""
+ # Pop "processor_class" as it should be saved as private attribute
+ self._processor_class = kwargs.pop("processor_class", None)
+ # Additional attributes without default values
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ def _set_processor_class(self, processor_class: str):
+ """Sets processor class as an attribute."""
+ self._processor_class = processor_class
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ **kwargs,
+ ):
+ r"""
+ Instantiate a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a feature extractor, *e.g.* a
+ derived class of [`SequenceFeatureExtractor`].
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
+ hf-mirror.com. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
+ namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ - a path to a *directory* containing a feature extractor file saved using the
+ [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g.,
+ `./my_model_directory/`.
+ - a path or url to a saved feature extractor JSON *file*, e.g.,
+ `./my_model_directory/preprocessor_config.json`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the feature extractor files and override the cached versions
+ if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received file. Attempts to resume the download if such a file
+ exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+ the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on hf-mirror.com, so `revision` can be any
+ identifier allowed by git.
+
+
+
+
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/".
+
+
+
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ If `False`, then this function returns just the final feature extractor object. If `True`, then this
+ functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
+ consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
+ `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
+ kwargs (`Dict[str, Any]`, *optional*):
+ The values in kwargs of any keys which are feature extractor attributes will be used to override the
+ loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
+ controlled by the `return_unused_kwargs` keyword parameter.
+
+ Returns:
+ A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`].
+
+ Examples:
+
+ ```python
+ # We can't instantiate directly the base class *FeatureExtractionMixin* nor *SequenceFeatureExtractor* so let's show the examples on a
+ # derived class: *Wav2Vec2FeatureExtractor*
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+ "facebook/wav2vec2-base-960h"
+ ) # Download feature_extraction_config from hf-mirror.com and cache.
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+ "./test/saved_model/"
+ ) # E.g. feature_extractor (or model) was saved using *save_pretrained('./test/saved_model/')*
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./test/saved_model/preprocessor_config.json")
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+ "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False
+ )
+ assert feature_extractor.return_attention_mask is False
+ feature_extractor, unused_kwargs = Wav2Vec2FeatureExtractor.from_pretrained(
+ "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False, return_unused_kwargs=True
+ )
+ assert feature_extractor.return_attention_mask is False
+ assert unused_kwargs == {"foo": False}
+ ```"""
+ kwargs["cache_dir"] = cache_dir
+ kwargs["force_download"] = force_download
+ kwargs["local_files_only"] = local_files_only
+ kwargs["revision"] = revision
+
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if token is not None:
+ kwargs["token"] = token
+
+ feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
+
+ return cls.from_dict(feature_extractor_dict, **kwargs)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
+ """
+ Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ use_auth_token = kwargs.pop("use_auth_token", None)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if kwargs.get("token", None) is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ kwargs["token"] = use_auth_token
+
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+ # loaded from the Hub.
+ # If we save using the predefined names, we can load using `from_pretrained`
+ output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
+
+ self.to_json_file(output_feature_extractor_file)
+ logger.info(f"Feature extractor saved in {output_feature_extractor_file}")
+
+
+ return [output_feature_extractor_file]
+
+ @classmethod
+ def get_feature_extractor_dict(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ """
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
+ feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] using `from_dict`.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+
+ Returns:
+ `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the feature extractor object.
+ """
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ from_pipeline = kwargs.pop("_from_pipeline", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+ revision = kwargs.pop('revision', 'main')
+
+ user_agent = {"file_type": "feature extractor", "from_auto_class": from_auto_class}
+ if from_pipeline is not None:
+ user_agent["using_pipeline"] = from_pipeline
+
+ if is_offline_mode() and not local_files_only:
+ logger.info("Offline mode: forcing local_files_only=True")
+ local_files_only = True
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if os.path.isdir(pretrained_model_name_or_path):
+ feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
+ if os.path.isfile(pretrained_model_name_or_path):
+ resolved_feature_extractor_file = pretrained_model_name_or_path
+ is_local = True
+ elif is_remote_url(pretrained_model_name_or_path):
+ feature_extractor_file = pretrained_model_name_or_path
+ resolved_feature_extractor_file = download_url(pretrained_model_name_or_path)
+ else:
+ feature_extractor_file = FEATURE_EXTRACTOR_NAME
+ try:
+ # Load from local folder or from cache or download from model Hub and cache
+ resolved_feature_extractor_file = cached_file(
+ pretrained_model_name_or_path,
+ feature_extractor_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ user_agent=user_agent,
+ revision=revision,
+ )
+ except EnvironmentError:
+ # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
+ # the original exception.
+ raise
+ except Exception as exc:
+ # For any other exception, we throw a generic error.
+ raise EnvironmentError(
+ f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
+ " it from 'https://hf-mirror.com/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
+ ) from exc
+
+ try:
+ # Load feature_extractor dict
+ with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ feature_extractor_dict = json.loads(text)
+
+ except json.JSONDecodeError as exc:
+ raise EnvironmentError(
+ f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
+ ) from exc
+
+ if is_local:
+ logger.info(f"loading configuration file {resolved_feature_extractor_file}")
+ else:
+ logger.info(
+ f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
+ )
+
+
+ return feature_extractor_dict, kwargs
+
+ @classmethod
+ def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor:
+ """
+ Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of
+ parameters.
+
+ Args:
+ feature_extractor_dict (`Dict[str, Any]`):
+ Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be
+ retrieved from a pretrained checkpoint by leveraging the
+ [`~feature_extraction_utils.FeatureExtractionMixin.to_dict`] method.
+ kwargs (`Dict[str, Any]`):
+ Additional parameters from which to initialize the feature extractor object.
+
+ Returns:
+ [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature extractor object instantiated from those
+ parameters.
+ """
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+ feature_extractor = cls(**feature_extractor_dict)
+
+ # Update feature_extractor with kwargs if needed
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(feature_extractor, key):
+ setattr(feature_extractor, key, value)
+ to_remove.append(key)
+ for key in to_remove:
+ kwargs.pop(key, None)
+
+ logger.info(f"Feature extractor {feature_extractor}")
+ if return_unused_kwargs:
+ return feature_extractor, kwargs
+ return feature_extractor
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary. Returns:
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["feature_extractor_type"] = self.__class__.__name__
+ if "mel_filters" in output:
+ del output["mel_filters"]
+ if "window" in output:
+ del output["window"]
+ return output
+
+ @classmethod
+ def from_json_file(cls, json_file: Union[str, os.PathLike]) -> PreTrainedFeatureExtractor:
+ """
+ Instantiates a feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] from the path to
+ a JSON file of parameters.
+
+ Args:
+ json_file (`str` or `os.PathLike`):
+ Path to the JSON file containing the parameters.
+
+ Returns:
+ A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature_extractor
+ object instantiated from that JSON file.
+ """
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ feature_extractor_dict = json.loads(text)
+ return cls(**feature_extractor_dict)
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
+ """
+ dictionary = self.to_dict()
+
+ for key, value in dictionary.items():
+ if isinstance(value, np.ndarray):
+ dictionary[key] = value.tolist()
+
+ # make sure private name "_processor_class" is correctly
+ # saved as "processor_class"
+ _processor_class = dictionary.pop("_processor_class", None)
+ if _processor_class is not None:
+ dictionary["processor_class"] = _processor_class
+
+ return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this feature_extractor instance's parameters will be saved.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @classmethod
+ def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"):
+ """
+ Register this class with a given auto class. This should only be used for custom feature extractors as the ones
+ in the library are already mapped with `AutoFeatureExtractor`.
+
+
+
+ This API is experimental and may have some slight breaking changes in the next releases.
+
+
+
+ Args:
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`):
+ The auto class to register this new feature extractor with.
+ """
+ if not isinstance(auto_class, str):
+ auto_class = auto_class.__name__
+
+ import mindnlp.transformers.models.auto as auto_module
+
+ if not hasattr(auto_module, auto_class):
+ raise ValueError(f"{auto_class} is not a valid auto class.")
+
+ cls._auto_class = auto_class
diff --git a/mindone/transformers/image_processing_utils.py b/mindone/transformers/image_processing_utils.py
new file mode 100644
index 0000000000..4f2232db16
--- /dev/null
+++ b/mindone/transformers/image_processing_utils.py
@@ -0,0 +1,864 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""image processing utils"""
+import copy
+import json
+import os
+import warnings
+from io import BytesIO
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+import requests
+
+FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
+IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
+
+from .feature_extraction_utils import BatchFeature as BaseBatchFeature
+from .image_transforms import center_crop, normalize, rescale
+from .image_utils import ChannelDimension
+from transformers.utils import cached_file, download_url, is_offline_mode, is_remote_url, is_vision_available, logging
+
+if is_vision_available():
+ from PIL import Image
+
+logger = logging.get_logger(__name__)
+
+
+# TODO: Move BatchFeature to be imported by both image_processing_utils and image_processing_utils
+# We override the class string here, but logic is the same.
+class BatchFeature(BaseBatchFeature):
+ r"""
+ Holds the output of the image processor specific `__call__` methods.
+
+ This class is derived from a python dictionary and can be used as a dictionary.
+
+ Args:
+ data (`dict`):
+ Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).
+ tensor_type (`Union[None, str, TensorType]`, *optional*):
+ You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
+ initialization.
+ """
+# TODO: (Amy) - factor out the common parts of this and the feature extractor
+class ImageProcessingMixin:
+ """
+ This is an image processor mixin used to provide saving/loading functionality for sequential and image feature
+ extractors.
+ """
+ _auto_class = None
+
+ def __init__(self, **kwargs):
+ """Set elements of `kwargs` as attributes."""
+ # This key was saved while we still used `XXXFeatureExtractor` for image processing. Now we use
+ # `XXXImageProcessor`, this attribute and its value are misleading.
+ kwargs.pop("feature_extractor_type", None)
+ # Pop "processor_class" as it should be saved as private attribute
+ self._processor_class = kwargs.pop("processor_class", None)
+ # Additional attributes without default values
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ def _set_processor_class(self, processor_class: str):
+ """Sets processor class as an attribute."""
+ self._processor_class = processor_class
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ **kwargs,
+ ):
+ r"""
+ Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained image_processor hosted inside a model repo on
+ hf-mirror.com.
+ - a path to a *directory* containing a image processor file saved using the
+ [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g.,
+ `./my_model_directory/`.
+ - a path or url to a saved image processor JSON *file*, e.g.,
+ `./my_model_directory/preprocessor_config.json`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model image processor should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the image processor files and override the cached versions if
+ they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received file. Attempts to resume the download if such a file
+ exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+ the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on hf-mirror.com, so `revision` can be any
+ identifier allowed by git.
+
+
+
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/".
+
+
+
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ If `False`, then this function returns just the final image processor object. If `True`, then this
+ functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
+ consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
+ `kwargs` which has not been used to update `image_processor` and is otherwise ignored.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on hf-mirror.com, you can
+ specify the folder name here.
+ kwargs (`Dict[str, Any]`, *optional*):
+ The values in kwargs of any keys which are image processor attributes will be used to override the
+ loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
+ controlled by the `return_unused_kwargs` keyword parameter.
+
+ Returns:
+ A image processor of type [`~image_processing_utils.ImageProcessingMixin`].
+
+ Example:
+ ```python
+ >>> # We can't instantiate directly the base class *ImageProcessingMixin* so let's show the examples on a
+ >>> # derived class: *CLIPImageProcessor*
+ >>> image_processor = CLIPImageProcessor.from_pretrained(
+ >>> "openai/clip-vit-base-patch32"
+ >>> ) # Download image_processing_config from hf-mirror.com and cache.
+ >>> image_processor = CLIPImageProcessor.from_pretrained(
+ >>> "./test/saved_model/"
+ >>> ) # E.g. image processor (or model) was saved using *save_pretrained('./test/saved_model/')*
+ >>> image_processor = CLIPImageProcessor.from_pretrained("./test/saved_model/preprocessor_config.json")
+ >>> image_processor = CLIPImageProcessor.from_pretrained(
+ >>> "openai/clip-vit-base-patch32", do_normalize=False, foo=False
+ >>> )
+ >>> assert image_processor.do_normalize is False
+ >>> image_processor, unused_kwargs = CLIPImageProcessor.from_pretrained(
+ >>> "openai/clip-vit-base-patch32", do_normalize=False, foo=False, return_unused_kwargs=True
+ >>> )
+ >>> assert image_processor.do_normalize is False
+ >>> assert unused_kwargs == {"foo": False}
+ ```
+ """
+ kwargs["cache_dir"] = cache_dir
+ kwargs["force_download"] = force_download
+ kwargs["local_files_only"] = local_files_only
+ kwargs["revision"] = revision
+
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if token is not None:
+ kwargs["token"] = token
+
+ image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
+
+ return cls.from_dict(image_processor_dict, **kwargs)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save an image processor object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~image_processing_utils.ImageProcessingMixin.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the image processor JSON file will be saved (will be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ use_auth_token = kwargs.pop("use_auth_token", None)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if kwargs.get("token", None) is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ kwargs["token"] = use_auth_token
+
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = self._create_repo(repo_id, **kwargs)
+ files_timestamps = self._get_files_timestamps(save_directory)
+
+ # If we save using the predefined names, we can load using `from_pretrained`
+ output_image_processor_file = os.path.join(save_directory, IMAGE_PROCESSOR_NAME)
+
+ self.to_json_file(output_image_processor_file)
+ logger.info(f"Image processor saved in {output_image_processor_file}")
+
+ if push_to_hub:
+ self._upload_modified_files(
+ save_directory,
+ repo_id,
+ files_timestamps,
+ commit_message=commit_message,
+ token=kwargs.get("token"),
+ )
+
+ return [output_image_processor_file]
+
+ @classmethod
+ def get_image_processor_dict(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ """
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
+ image processor of type [`~image_processor_utils.ImageProcessingMixin`] using `from_dict`.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on hf-mirror.com, you can
+ specify the folder name here.
+
+ Returns:
+ `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object.
+ """
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ subfolder = kwargs.pop("subfolder", "")
+
+ from_pipeline = kwargs.pop("_from_pipeline", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+
+ user_agent = {"file_type": "image processor", "from_auto_class": from_auto_class}
+ if from_pipeline is not None:
+ user_agent["using_pipeline"] = from_pipeline
+
+ if is_offline_mode() and not local_files_only:
+ logger.info("Offline mode: forcing local_files_only=True")
+ local_files_only = True
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if os.path.isdir(pretrained_model_name_or_path):
+ image_processor_file = os.path.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME)
+ if os.path.isfile(pretrained_model_name_or_path):
+ resolved_image_processor_file = pretrained_model_name_or_path
+ is_local = True
+ elif is_remote_url(pretrained_model_name_or_path):
+ image_processor_file = pretrained_model_name_or_path
+ resolved_image_processor_file = download_url(pretrained_model_name_or_path)
+ else:
+ image_processor_file = IMAGE_PROCESSOR_NAME
+ try:
+ # Load from local folder or from cache or download from model Hub and cache
+ resolved_image_processor_file = cached_file(
+ pretrained_model_name_or_path,
+ image_processor_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ subfolder=subfolder,
+ )
+ except EnvironmentError:
+ # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
+ # the original exception.
+ raise
+ except Exception as e:
+ # For any other exception, we throw a generic error.
+ raise EnvironmentError(
+ f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
+ " it from 'https://hf-mirror.com/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a {IMAGE_PROCESSOR_NAME} file"
+ ) from e
+
+ try:
+ # Load image_processor dict
+ with open(resolved_image_processor_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ image_processor_dict = json.loads(text)
+
+ except json.JSONDecodeError as e:
+ raise EnvironmentError(
+ f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file."
+ ) from e
+
+ if is_local:
+ logger.info(f"loading configuration file {resolved_image_processor_file}")
+ else:
+ logger.info(
+ f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}"
+ )
+
+ return image_processor_dict, kwargs
+
+ @classmethod
+ def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
+ """
+ Instantiates a type of [`~image_processing_utils.ImageProcessingMixin`] from a Python dictionary of parameters.
+
+ Args:
+ image_processor_dict (`Dict[str, Any]`):
+ Dictionary that will be used to instantiate the image processor object. Such a dictionary can be
+ retrieved from a pretrained checkpoint by leveraging the
+ [`~image_processing_utils.ImageProcessingMixin.to_dict`] method.
+ kwargs (`Dict[str, Any]`):
+ Additional parameters from which to initialize the image processor object.
+
+ Returns:
+ [`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those
+ parameters.
+ """
+ image_processor_dict = image_processor_dict.copy()
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+ # The `size` parameter is a dict and was previously an int or tuple in feature extractors.
+ # We set `size` here directly to the `image_processor_dict` so that it is converted to the appropriate
+ # dict within the image processor and isn't overwritten if `size` is passed in as a kwarg.
+ if "size" in kwargs and "size" in image_processor_dict:
+ image_processor_dict["size"] = kwargs.pop("size")
+ if "crop_size" in kwargs and "crop_size" in image_processor_dict:
+ image_processor_dict["crop_size"] = kwargs.pop("crop_size")
+
+ image_processor = cls(**image_processor_dict)
+
+ # Update image_processor with kwargs if needed
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(image_processor, key):
+ setattr(image_processor, key, value)
+ to_remove.append(key)
+ for key in to_remove:
+ kwargs.pop(key, None)
+
+ logger.info(f"Image processor {image_processor}")
+ if return_unused_kwargs:
+ return image_processor, kwargs
+ else:
+ return image_processor
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary.
+
+ Returns:
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this image processor instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["image_processor_type"] = self.__class__.__name__
+
+ return output
+
+ @classmethod
+ def from_json_file(cls, json_file: Union[str, os.PathLike]):
+ """
+ Instantiates a image processor of type [`~image_processing_utils.ImageProcessingMixin`] from the path to a JSON
+ file of parameters.
+
+ Args:
+ json_file (`str` or `os.PathLike`):
+ Path to the JSON file containing the parameters.
+
+ Returns:
+ A image processor of type [`~image_processing_utils.ImageProcessingMixin`]: The image_processor object
+ instantiated from that JSON file.
+ """
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ image_processor_dict = json.loads(text)
+ return cls(**image_processor_dict)
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
+ """
+ dictionary = self.to_dict()
+
+ for key, value in dictionary.items():
+ if isinstance(value, np.ndarray):
+ dictionary[key] = value.tolist()
+
+ # make sure private name "_processor_class" is correctly
+ # saved as "processor_class"
+ _processor_class = dictionary.pop("_processor_class", None)
+ if _processor_class is not None:
+ dictionary["processor_class"] = _processor_class
+
+ return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this image_processor instance's parameters will be saved.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+ def __repr__(self):
+ """
+ __repr__
+
+ This method returns a string representation of the ImageProcessingMixin object.
+
+ Args:
+ self (ImageProcessingMixin): The instance of the ImageProcessingMixin class.
+ This parameter is used to reference the current instance of the ImageProcessingMixin class.
+
+ Returns:
+ None: This method does not return any value explicitly, as it returns a string representation of the object.
+
+ Raises:
+ None.
+ """
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @classmethod
+ def register_for_auto_class(cls, auto_class="AutoImageProcessor"):
+ """
+ Register this class with a given auto class. This should only be used for custom image processors as the ones
+ in the library are already mapped with `AutoImageProcessor `.
+
+
+
+ This API is experimental and may have some slight breaking changes in the next releases.
+
+
+
+ Args:
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor "`):
+ The auto class to register this new image processor with.
+ """
+ if not isinstance(auto_class, str):
+ auto_class = auto_class.__name__
+
+ import mindnlp.transformers.models.auto as auto_module
+ if not hasattr(auto_module, auto_class):
+ raise ValueError(f"{auto_class} is not a valid auto class.")
+
+ cls._auto_class = auto_class
+
+ def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
+ """
+ Convert a single or a list of urls into the corresponding `PIL.Image` objects.
+
+ If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
+ returned.
+ """
+ headers = {
+ "User-Agent": (
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
+ " Safari/537.36"
+ )
+ }
+ if isinstance(image_url_or_urls, list):
+ return [self.fetch_images(x) for x in image_url_or_urls]
+ elif isinstance(image_url_or_urls, str):
+ response = requests.get(image_url_or_urls, stream=True, headers=headers, timeout=10)
+ response.raise_for_status()
+ return Image.open(BytesIO(response.content))
+ else:
+ raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
+
+
+class BaseImageProcessor(ImageProcessingMixin):
+
+ """
+ Represents a base image processor that provides methods for image preprocessing operations such as rescaling,
+ normalization, and center cropping.
+
+ This class inherits from ImageProcessingMixin and serves as a template for concrete image processor implementations.
+ Concrete image processors must implement their own preprocess method.
+
+ Attributes:
+ Inherits all attributes from ImageProcessingMixin.
+
+ Methods:
+ __call__(self, images, **kwargs) -> BatchFeature: Preprocess an image or a batch of images.
+ preprocess(self, images, **kwargs) -> BatchFeature: Abstract method to be implemented by concrete image processors.
+ rescale(self, image, scale, data_format=None, input_data_format=None, **kwargs) -> np.ndarray: Rescale an image by a scale factor.
+ normalize(self, image, mean, std, data_format=None, input_data_format=None, **kwargs) -> np.ndarray: Normalize an image using mean and standard deviation.
+ center_crop(self, image, size, data_format=None, input_data_format=None, **kwargs) -> np.ndarray: Center crop an image to a specified size.
+ """
+ def __call__(self, images, **kwargs) -> BatchFeature:
+ """Preprocess an image or a batch of images."""
+ return self.preprocess(images, **kwargs)
+
+ def preprocess(self, images, **kwargs) -> BatchFeature:
+ """
+ Preprocess the given images using the implemented image processor.
+
+ Args:
+ self (BaseImageProcessor): An instance of the BaseImageProcessor class.
+ images (list): A list of images to be preprocessed.
+
+ Returns:
+ BatchFeature: The preprocessed images as a BatchFeature object.
+
+ Raises:
+ NotImplementedError: If the preprocess method is not implemented in the specific image processor.
+
+ """
+ raise NotImplementedError("Each image processor must implement its own preprocess method")
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: float,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`float`):
+ The scaling factor to rescale pixel values by.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+ Returns:
+ `np.ndarray`: The rescaled image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, Iterable[float]],
+ std: Union[float, Iterable[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ mean (`float` or `Iterable[float]`):
+ Image mean to use for normalization.
+ std (`float` or `Iterable[float]`):
+ Image standard deviation to use for normalization.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+ Returns:
+ `np.ndarray`: The normalized image.
+ """
+ return normalize(
+ image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
+ )
+
+ def center_crop(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
+ any edge, the image is padded with 0's and then center cropped.
+
+ Args:
+ image (`np.ndarray`):
+ Image to center crop.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
+ return center_crop(
+ image,
+ size=(size["height"], size["width"]),
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+
+VALID_SIZE_DICT_KEYS = ({"height", "width"}, {"shortest_edge"}, {"shortest_edge", "longest_edge"}, {"longest_edge"})
+
+
+def is_valid_size_dict(size_dict):
+ """
+ Args:
+ size_dict (dict): A dictionary containing size information.
+ The keys in the dictionary should match a predefined set of valid keys.
+
+ Returns:
+ None: Returns None if the size_dict is not a valid size dictionary.
+
+ Raises:
+ None
+ """
+ if not isinstance(size_dict, dict):
+ return False
+
+ size_dict_keys = set(size_dict.keys())
+ for allowed_keys in VALID_SIZE_DICT_KEYS:
+ if size_dict_keys == allowed_keys:
+ return True
+ return False
+
+
+def convert_to_size_dict(
+ size, max_size: Optional[int] = None, default_to_square: bool = True, height_width_order: bool = True
+):
+ """
+ Converts a size input into a dictionary representation.
+
+ Args:
+ size (int or tuple/list): The size input to be converted.
+
+ - If an integer is provided and `default_to_square` is True,
+ it creates a square size dictionary with both height and width values set to the given size.
+ - If an integer is provided and `default_to_square` is False,
+ it creates a size dictionary with the shortest edge set to the given size.
+ Optionally, the longest edge can be specified using the `max_size` parameter.
+ - If a tuple or list is provided and `height_width_order` is True,
+ it creates a size dictionary with the first element representing the height and the second element
+ representing the width.
+ - If a tuple or list is provided and `height_width_order` is False,
+ it creates a size dictionary with the first element representing the width and the second element
+ representing the height.
+ - If `size` is None and `max_size` is not None,
+ it creates a size dictionary with the longest edge set to the `max_size`. Note that `default_to_square`
+ must be False in this case.
+
+ max_size (int, optional):
+ The maximum size for the longest edge. Defaults to None.
+
+ - This parameter is only used when `size` is an integer and `default_to_square` is False.
+
+ default_to_square (bool):
+ A flag indicating whether the size dictionary should default to a square shape when `size` is an integer.
+ Defaults to True.
+
+ - If True, the size dictionary will have both height and width values set to the provided size.
+ - If False, the size dictionary will have the shortest edge set to the provided size. Optionally, the longest
+ edge can be specified using the `max_size` parameter.
+
+ height_width_order (bool):
+ A flag indicating whether the height and width order should follow the order of elements in the `size` tuple/list.
+ Defaults to True.
+
+ - If True, the first element of the `size` tuple/list will be considered as the height and the second element
+ as the width.
+ - If False, the first element of the `size` tuple/list will be considered as the width and the second element
+ as the height.
+
+ Returns:
+ dict or None: A dictionary representation of the converted size input.
+ The dictionary will have the following keys:
+
+ - 'height' and 'width' (int): Representing the height and width of the size, respectively.
+ - 'shortest_edge' (int): Representing the size of the shortest edge,
+ when `size` is an integer and `default_to_square` is False.
+ - 'longest_edge' (int): Representing the size of the longest edge,
+ when `size` is an integer and `default_to_square` is False and `max_size` is provided.
+
+ Raises:
+ ValueError: If the input combination is invalid and cannot be converted to a size dictionary.
+
+ """
+ # By default, if size is an int we assume it represents a tuple of (size, size).
+ if isinstance(size, int) and default_to_square:
+ if max_size is not None:
+ raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
+ return {"height": size, "width": size}
+ # In other configs, if size is an int and default_to_square is False, size represents the length of
+ # the shortest edge after resizing.
+ elif isinstance(size, int) and not default_to_square:
+ size_dict = {"shortest_edge": size}
+ if max_size is not None:
+ size_dict["longest_edge"] = max_size
+ return size_dict
+ # Otherwise, if size is a tuple it's either (height, width) or (width, height)
+ elif isinstance(size, (tuple, list)) and height_width_order:
+ return {"height": size[0], "width": size[1]}
+ elif isinstance(size, (tuple, list)) and not height_width_order:
+ return {"height": size[1], "width": size[0]}
+ elif size is None and max_size is not None:
+ if default_to_square:
+ raise ValueError("Cannot specify both default_to_square=True and max_size")
+ return {"longest_edge": max_size}
+
+ raise ValueError(f"Could not convert size input to size dict: {size}")
+
+
+def get_size_dict(
+ size: Union[int, Iterable[int], Dict[str, int]] = None,
+ max_size: Optional[int] = None,
+ height_width_order: bool = True,
+ default_to_square: bool = True,
+ param_name="size",
+) -> dict:
+ """
+ Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
+ compatibility with the old image processor configs and removes ambiguity over whether the tuple is in (height,
+ width) or (width, height) format.
+
+ - If `size` is tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width":
+ size[0]}` if `height_width_order` is `False`.
+ - If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`.
+ - If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size`
+ is set, it is added to the dict as `{"longest_edge": max_size}`.
+
+ Args:
+ size (`Union[int, Iterable[int], Dict[str, int]]`, *optional*):
+ The `size` parameter to be cast into a size dictionary.
+ max_size (`Optional[int]`, *optional*):
+ The `max_size` parameter to be cast into a size dictionary.
+ height_width_order (`bool`, *optional*, defaults to `True`):
+ If `size` is a tuple, whether it's in (height, width) or (width, height) order.
+ default_to_square (`bool`, *optional*, defaults to `True`):
+ If `size` is an int, whether to default to a square image or not.
+ """
+ if not isinstance(size, dict):
+ size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
+ logger.info(
+ f"{param_name} should be a dictionary on of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size}."
+ f" Converted to {size_dict}.",
+ )
+ else:
+ size_dict = size
+
+ if not is_valid_size_dict(size_dict):
+ raise ValueError(
+ f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
+ )
+ return size_dict
+
+def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
+ """
+ Selects the best resolution from a list of possible resolutions based on the original size.
+
+ This is done by calculating the effective and wasted resolution for each possible resolution.
+
+ The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.
+
+ Args:
+ original_size (tuple):
+ The original size of the image in the format (height, width).
+ possible_resolutions (list):
+ A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].
+
+ Returns:
+ tuple: The best fit resolution in the format (height, width).
+ """
+ original_height, original_width = original_size
+ best_fit = None
+ max_effective_resolution = 0
+ min_wasted_resolution = float("inf")
+
+ for height, width in possible_resolutions:
+ scale = min(width / original_width, height / original_height)
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+ wasted_resolution = (width * height) - effective_resolution
+
+ if effective_resolution > max_effective_resolution or (
+ effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
+ ):
+ max_effective_resolution = effective_resolution
+ min_wasted_resolution = wasted_resolution
+ best_fit = (height, width)
+
+ return best_fit
diff --git a/mindone/transformers/image_transforms.py b/mindone/transformers/image_transforms.py
new file mode 100644
index 0000000000..1b68f1aad5
--- /dev/null
+++ b/mindone/transformers/image_transforms.py
@@ -0,0 +1,322 @@
+import warnings
+from typing import Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+
+import PIL
+
+import mindspore
+from mindspore import ops
+
+from .image_utils import (
+ ChannelDimension,
+ ImageInput,
+ get_channel_dimension_axis,
+ get_image_size,
+ infer_channel_dimension_format,
+)
+
+def to_channel_dimension_format(
+ image: np.ndarray,
+ channel_dim: Union[ChannelDimension, str],
+ input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
+) -> np.ndarray:
+ """
+ Converts `image` to the channel dimension format specified by `channel_dim`.
+
+ Args:
+ image (`numpy.ndarray`):
+ The image to have its channel dimension set.
+ channel_dim (`ChannelDimension`):
+ The channel dimension format to use.
+ input_channel_dim (`ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
+
+ Returns:
+ `np.ndarray`: The image with the channel dimension set to `channel_dim`.
+ """
+ if not isinstance(image, np.ndarray):
+ raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
+
+ if input_channel_dim is None:
+ input_channel_dim = infer_channel_dimension_format(image)
+
+ target_channel_dim = ChannelDimension(channel_dim)
+ if input_channel_dim == target_channel_dim:
+ return image
+
+ if target_channel_dim == ChannelDimension.FIRST:
+ image = image.transpose((2, 0, 1))
+ elif target_channel_dim == ChannelDimension.LAST:
+ image = image.transpose((1, 2, 0))
+ else:
+ raise ValueError("Unsupported channel dimension format: {}".format(channel_dim))
+
+ return image
+
+def _rescale_for_pil_conversion(image):
+ """
+ Detects whether or not the image needs to be rescaled before being converted to a PIL image.
+
+ The assumption is that if the image is of type `np.float` and all values are between 0 and 1, it needs to be
+ rescaled.
+ """
+ if image.dtype == np.uint8:
+ do_rescale = False
+ elif np.allclose(image, image.astype(int)):
+ if np.all(0 <= image) and np.all(image <= 255):
+ do_rescale = False
+ else:
+ raise ValueError(
+ "The image to be converted to a PIL image contains values outside the range [0, 255], "
+ f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
+ )
+ elif np.all(0 <= image) and np.all(image <= 1):
+ do_rescale = True
+ else:
+ raise ValueError(
+ "The image to be converted to a PIL image contains values outside the range [0, 1], "
+ f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
+ )
+ return do_rescale
+
+
+def to_pil_image(
+ image: Union[np.ndarray, "PIL.Image.Image", "mindspore.Tensor"],
+ do_rescale: Optional[bool] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> "PIL.Image.Image":
+ """
+ Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
+ needed.
+
+ Args:
+ image (`PIL.Image.Image` or `numpy.ndarray` or `mindspore.Tensor` or `tf.Tensor`):
+ The image to convert to the `PIL.Image` format.
+ do_rescale (`bool`, *optional*):
+ Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
+ to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
+ and `False` otherwise.
+ input_data_format (`ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If unset, will use the inferred format from the input.
+
+ Returns:
+ `PIL.Image.Image`: The converted image.
+ """
+
+ if isinstance(image, PIL.Image.Image):
+ return image
+
+ # Convert all tensors to numpy arrays before converting to PIL image
+ if isinstance(image, mindspore.Tensor):
+ image = image.asnumpy()
+ elif not isinstance(image, np.ndarray):
+ raise ValueError("Input image type not supported: {}".format(type(image)))
+
+ # If the channel has been moved to first dim, we put it back at the end.
+ image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)
+
+ # If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
+ image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
+
+ # PIL.Image can only store uint8 values so we rescale the image to be between 0 and 255 if needed.
+ do_rescale = _rescale_for_pil_conversion(image) if do_rescale is None else do_rescale
+
+ if do_rescale:
+ image = rescale(image, 255)
+
+ image = image.astype(np.uint8)
+ return PIL.Image.fromarray(image)
+
+def center_crop(
+ image: np.ndarray,
+ size: Tuple[int, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ return_numpy: Optional[bool] = None,
+) -> np.ndarray:
+ """
+ Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped to
+ the size given, it will be padded (so the returned result will always be of size `size`).
+
+ Args:
+ image (`np.ndarray`):
+ The image to crop.
+ size (`Tuple[int, int]`):
+ The target size for the cropped image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use the inferred format of the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use the inferred format of the input image.
+ return_numpy (`bool`, *optional*):
+ Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
+ previous ImageFeatureExtractionMixin method.
+ - Unset: will return the same type as the input image.
+ - `True`: will return a numpy array.
+ - `False`: will return a `PIL.Image.Image` object.
+ Returns:
+ `np.ndarray`: The cropped image.
+ """
+
+ if return_numpy is not None:
+ warnings.warn("return_numpy is deprecated and will be removed in v.4.33", FutureWarning)
+
+ return_numpy = True if return_numpy is None else return_numpy
+
+ if not isinstance(image, np.ndarray):
+ raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
+
+ if not isinstance(size, Iterable) or len(size) != 2:
+ raise ValueError("size must have 2 elements representing the height and width of the output image")
+
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+ output_data_format = data_format if data_format is not None else input_data_format
+
+ # We perform the crop in (C, H, W) format and then convert to the output format
+ image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
+
+ orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST)
+ crop_height, crop_width = size
+ crop_height, crop_width = int(crop_height), int(crop_width)
+
+ # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
+ top = (orig_height - crop_height) // 2
+ bottom = top + crop_height
+ # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
+ left = (orig_width - crop_width) // 2
+ right = left + crop_width
+
+ # Check if cropped area is within image boundaries
+ if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
+ image = image[..., top:bottom, left:right]
+ image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST)
+ return image
+
+ # Otherwise, we may need to pad if the image is too small. Oh joy...
+ new_height = max(crop_height, orig_height)
+ new_width = max(crop_width, orig_width)
+ new_shape = image.shape[:-2] + (new_height, new_width)
+ new_image = np.zeros_like(image, shape=new_shape)
+
+ # If the image is too small, pad it with zeros
+ top_pad = (new_height - orig_height) // 2
+ bottom_pad = top_pad + orig_height
+ left_pad = (new_width - orig_width) // 2
+ right_pad = left_pad + orig_width
+ new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
+
+ top += top_pad
+ bottom += top_pad
+ left += left_pad
+ right += left_pad
+
+ new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
+ new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)
+
+ if not return_numpy:
+ new_image = to_pil_image(new_image)
+
+ return new_image
+
+def normalize(
+ image: np.ndarray,
+ mean: Union[float, Iterable[float]],
+ std: Union[float, Iterable[float]],
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> np.ndarray:
+ """
+ Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.
+
+ image = (image - mean) / std
+
+ Args:
+ image (`np.ndarray`):
+ The image to normalize.
+ mean (`float` or `Iterable[float]`):
+ The mean to use for normalization.
+ std (`float` or `Iterable[float]`):
+ The standard deviation to use for normalization.
+ data_format (`ChannelDimension`, *optional*):
+ The channel dimension format of the output image. If unset, will use the inferred format from the input.
+ input_data_format (`ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If unset, will use the inferred format from the input.
+ """
+ if not isinstance(image, np.ndarray):
+ raise ValueError("image must be a numpy array")
+
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+ channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
+ num_channels = image.shape[channel_axis]
+
+ # We cast to float32 to avoid errors that can occur when subtracting uint8 values.
+ # We preserve the original dtype if it is a float type to prevent upcasting float16.
+ if not np.issubdtype(image.dtype, np.floating):
+ image = image.astype(np.float32)
+
+ if isinstance(mean, Iterable):
+ if len(mean) != num_channels:
+ raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
+ else:
+ mean = [mean] * num_channels
+ mean = np.array(mean, dtype=image.dtype)
+
+ if isinstance(std, Iterable):
+ if len(std) != num_channels:
+ raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
+ else:
+ std = [std] * num_channels
+ std = np.array(std, dtype=image.dtype)
+
+ if input_data_format == ChannelDimension.LAST:
+ image = (image - mean) / std
+ else:
+ image = ((image.T - mean) / std).T
+
+ image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
+ return image
+
+def rescale(
+ image: np.ndarray,
+ scale: float,
+ data_format: Optional[ChannelDimension] = None,
+ dtype: np.dtype = np.float32,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> np.ndarray:
+ """
+ Rescales `image` by `scale`.
+
+ Args:
+ image (`np.ndarray`):
+ The image to rescale.
+ scale (`float`):
+ The scale to use for rescaling the image.
+ data_format (`ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ dtype (`np.dtype`, *optional*, defaults to `np.float32`):
+ The dtype of the output image. Defaults to `np.float32`. Used for backwards compatibility with feature
+ extractors.
+ input_data_format (`ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
+
+ Returns:
+ `np.ndarray`: The rescaled image.
+ """
+ if not isinstance(image, np.ndarray):
+ raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
+
+ rescaled_image = image * scale
+ if data_format is not None:
+ rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format)
+
+ rescaled_image = rescaled_image.astype(dtype)
+
+ return rescaled_image
\ No newline at end of file
diff --git a/mindone/transformers/image_utils.py b/mindone/transformers/image_utils.py
new file mode 100644
index 0000000000..e4facfbd03
--- /dev/null
+++ b/mindone/transformers/image_utils.py
@@ -0,0 +1,95 @@
+import base64
+import os
+from io import BytesIO
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+import requests
+from packaging import version
+
+from .utils.generic import ExplicitEnum
+
+class ChannelDimension(ExplicitEnum):
+ FIRST = "channels_first"
+ LAST = "channels_last"
+
+ImageInput = Union[
+ "PIL.Image.Image", np.ndarray, "mindspore.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["mindspore.Tensor"]
+] # noqa
+
+def get_channel_dimension_axis(
+ image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
+) -> int:
+ """
+ Returns the channel dimension axis of the image.
+
+ Args:
+ image (`np.ndarray`):
+ The image to get the channel dimension axis of.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the image. If `None`, will infer the channel dimension from the image.
+
+ Returns:
+ The channel dimension axis of the image.
+ """
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+ if input_data_format == ChannelDimension.FIRST:
+ return image.ndim - 3
+ elif input_data_format == ChannelDimension.LAST:
+ return image.ndim - 1
+ raise ValueError(f"Unsupported data format: {input_data_format}")
+
+def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
+ """
+ Returns the (height, width) dimensions of the image.
+
+ Args:
+ image (`np.ndarray`):
+ The image to get the dimensions of.
+ channel_dim (`ChannelDimension`, *optional*):
+ Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.
+
+ Returns:
+ A tuple of the image's height and width.
+ """
+ if channel_dim is None:
+ channel_dim = infer_channel_dimension_format(image)
+
+ if channel_dim == ChannelDimension.FIRST:
+ return image.shape[-2], image.shape[-1]
+ elif channel_dim == ChannelDimension.LAST:
+ return image.shape[-3], image.shape[-2]
+ else:
+ raise ValueError(f"Unsupported data format: {channel_dim}")
+
+def infer_channel_dimension_format(
+ image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
+) -> ChannelDimension:
+ """
+ Infers the channel dimension format of `image`.
+
+ Args:
+ image (`np.ndarray`):
+ The image to infer the channel dimension of.
+ num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
+ The number of channels of the image.
+
+ Returns:
+ The channel dimension of the image.
+ """
+ num_channels = num_channels if num_channels is not None else (1, 3)
+ num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels
+
+ if image.ndim == 3:
+ first_dim, last_dim = 0, 2
+ elif image.ndim == 4:
+ first_dim, last_dim = 1, 3
+ else:
+ raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
+
+ if image.shape[first_dim] in num_channels:
+ return ChannelDimension.FIRST
+ elif image.shape[last_dim] in num_channels:
+ return ChannelDimension.LAST
+ raise ValueError("Unable to infer channel dimension format")
\ No newline at end of file
diff --git a/mindone/transformers/models/__init__.py b/mindone/transformers/models/__init__.py
index 089c3e47ad..bab2bf947a 100644
--- a/mindone/transformers/models/__init__.py
+++ b/mindone/transformers/models/__init__.py
@@ -1 +1 @@
-from . import bert, bit, blip_2, clip, dpt, gemma, t5, xlm_roberta
+from . import bit, blip_2, clip, dpt, minicpm_v, t5, xlm_roberta
diff --git a/mindone/transformers/models/minicpm_v/__init__.py b/mindone/transformers/models/minicpm_v/__init__.py
new file mode 100644
index 0000000000..c973fe341f
--- /dev/null
+++ b/mindone/transformers/models/minicpm_v/__init__.py
@@ -0,0 +1,2 @@
+from .modeling_minicpmv import MiniCPMV_v2_6
+from .image_processing_minicpmv import MiniCPMVImageProcessor
\ No newline at end of file
diff --git a/mindone/transformers/models/minicpm_v/configuration_minicpm.py b/mindone/transformers/models/minicpm_v/configuration_minicpm.py
new file mode 100644
index 0000000000..063cfee91b
--- /dev/null
+++ b/mindone/transformers/models/minicpm_v/configuration_minicpm.py
@@ -0,0 +1,102 @@
+# coding=utf-8
+""" MiniCPMV model configuration"""
+
+import os
+from typing import Union
+
+from transformers import PretrainedConfig, Qwen2Config
+from transformers.utils import logging
+
+from .modeling_navit_siglip import SiglipVisionConfig
+
+logger = logging.get_logger(__name__)
+
+
+class MiniCPMVSliceConfig(PretrainedConfig):
+ model_type = "minicpmv"
+
+ def __init__(
+ self,
+ patch_size=14,
+ max_slice_nums=9,
+ scale_resolution=448,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.patch_size = patch_size
+ self.max_slice_nums = max_slice_nums
+ self.scale_resolution = scale_resolution
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+ cls._set_token_in_kwargs(kwargs)
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ if config_dict.get("model_type") == "minicpmv":
+ config_dict = config_dict["slice_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+
+class MiniCPMVConfig(Qwen2Config):
+ model_type = "minicpmv"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ default_vision_config = {
+ "hidden_size": 1152,
+ "image_size": 980,
+ "intermediate_size": 4304,
+ "model_type": "siglip",
+ "num_attention_heads": 16,
+ "num_hidden_layers": 27,
+ "patch_size": 14,
+ "attn_implementation": "flash_attention"
+ }
+
+ def __init__(
+ self,
+ use_cache=True,
+ query_num=64,
+ image_size=448,
+ drop_vision_last_layer=True,
+ batch_vision_input=True,
+ slice_config=None,
+ vision_config=None,
+ use_image_id=True,
+ vision_batch_size=16,
+ **kwargs,
+ ):
+ self.use_cache = use_cache
+ self.query_num = query_num
+ self.image_size = image_size
+ self.drop_vision_last_layer = drop_vision_last_layer
+ self.batch_vision_input = batch_vision_input
+ self.use_image_id = use_image_id
+ self.vision_batch_size = vision_batch_size
+
+ if slice_config is None:
+ self.slice_config = MiniCPMVSliceConfig(max_slice_nums=1)
+ else:
+ self.slice_config = MiniCPMVSliceConfig(**slice_config)
+ self.slice_mode = True
+
+ # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes
+ if vision_config is None:
+ self.vision_config = SiglipVisionConfig(**self.default_vision_config)
+ logger.info("vision_config is None, using default vision config")
+ elif isinstance(vision_config, dict):
+ self.vision_config = SiglipVisionConfig(**vision_config)
+ elif isinstance(vision_config, SiglipVisionConfig):
+ self.vision_config = vision_config
+
+ self.patch_size = self.vision_config.patch_size
+
+ super().__init__(**kwargs)
diff --git a/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py b/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
new file mode 100644
index 0000000000..7626964f1b
--- /dev/null
+++ b/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
@@ -0,0 +1,429 @@
+import math
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import PIL.Image
+import PIL.ImageSequence
+from ...image_processing_utils import BaseImageProcessor, BatchFeature
+from PIL import Image
+from transformers.image_transforms import to_channel_dimension_format
+from transformers.image_utils import (
+ ChannelDimension,
+ ImageInput,
+ infer_channel_dimension_format,
+ is_batched,
+ is_torch_tensor,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+)
+from transformers.utils import TensorType, is_torch_device, is_torch_dtype, requires_backends
+
+import mindspore as ms
+from mindspore import Parameter, Tensor, nn, ops
+
+
+def recursive_converter(converter, value):
+ if isinstance(value, list):
+ new_value = []
+ for v in value:
+ new_value += [recursive_converter(converter, v)]
+ return new_value
+ else:
+ return converter(value)
+
+
+class MiniCPMVBatchFeature(BatchFeature):
+ r"""
+ Extend from BatchFeature for supporting various image size
+ """
+
+ def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
+ super().__init__(data)
+ self.convert_to_tensors(tensor_type=tensor_type)
+
+ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
+ if tensor_type is None:
+ return self
+
+ is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
+
+ def converter(value):
+ try:
+ if not is_tensor(value):
+ tensor = as_tensor(value)
+ return tensor
+ except: # noqa E722
+ if key == "overflowing_values":
+ raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
+ raise ValueError(
+ "Unable to create tensor, you should probably activate padding "
+ "with 'padding=True' to have batched tensors with the same length."
+ )
+
+ for key, value in self.items():
+ self[key] = recursive_converter(converter, value)
+ return self
+
+ def to(self, *args, **kwargs) -> "MiniCPMVBatchFeature":
+ requires_backends(self, ["torch"])
+ import torch
+
+ def cast_tensor(v):
+ # check if v is a floating point
+ if torch.is_floating_point(v):
+ # cast and send to device
+ return v.to(*args, **kwargs)
+ elif device is not None:
+ return v.to(device=device)
+ else:
+ return v
+
+ new_data = {}
+ device = kwargs.get("device")
+ # Check if the args are a device or a dtype
+ if device is None and len(args) > 0:
+ # device should be always the first argument
+ arg = args[0]
+ if is_torch_dtype(arg):
+ # The first argument is a dtype
+ pass
+ elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
+ device = arg
+ else:
+ # it's something else
+ raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
+ # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+ for k, v in self.items():
+ new_data[k] = recursive_converter(cast_tensor, v)
+ self.data = new_data
+ return self
+
+
+class MiniCPMVImageProcessor(BaseImageProcessor):
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ max_slice_nums=9,
+ scale_resolution=448,
+ patch_size=14,
+ **kwargs):
+ super().__init__(**kwargs)
+ self.max_slice_nums = max_slice_nums
+ self.scale_resolution = scale_resolution
+ self.patch_size = patch_size
+ self.use_image_id = kwargs.pop("use_image_id", False)
+ self.image_feature_size = kwargs.pop("image_feature_size", 64)
+ self.im_start_token = kwargs.pop("im_start", "")
+ self.im_end_token = kwargs.pop("im_end", "")
+ self.slice_start_token = kwargs.pop("slice_start", "")
+ self.slice_end_token = kwargs.pop("slice_end", "")
+ self.unk_token = kwargs.pop("unk", "")
+ self.im_id_start = kwargs.pop("im_id_start", "")
+ self.im_id_end = kwargs.pop("im_id_end", "")
+ self.slice_mode = kwargs.pop("slice_mode", True)
+ self.mean = np.array(kwargs.pop("norm_mean", [0.5, 0.5, 0.5]))
+ self.std = np.array(kwargs.pop("norm_std", [0.5, 0.5, 0.5]))
+ self.version = kwargs.pop("version", 2.0)
+
+ def ensure_divide(self, length, patch_size):
+ return max(round(length / patch_size) * patch_size, patch_size)
+
+ def find_best_resize(self,
+ original_size,
+ scale_resolution,
+ patch_size,
+ allow_upscale=False):
+ width, height = original_size
+ if (width * height >
+ scale_resolution * scale_resolution) or allow_upscale:
+ r = width / height
+ height = int(scale_resolution / math.sqrt(r))
+ width = int(height * r)
+ best_width = self.ensure_divide(width, patch_size)
+ best_height = self.ensure_divide(height, patch_size)
+ return (best_width, best_height)
+
+ def get_refine_size(self,
+ original_size,
+ grid,
+ scale_resolution,
+ patch_size,
+ allow_upscale=False):
+ width, height = original_size
+ grid_x, grid_y = grid
+
+ refine_width = self.ensure_divide(width, grid_x)
+ refine_height = self.ensure_divide(height, grid_y)
+
+ grid_width = refine_width / grid_x
+ grid_height = refine_height / grid_y
+
+ best_grid_size = self.find_best_resize((grid_width, grid_height),
+ scale_resolution,
+ patch_size,
+ allow_upscale=allow_upscale)
+ refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
+ return refine_size
+
+ def split_to_patches(self, image, grid):
+ patches = []
+ width, height = image.size
+ grid_x = int(width / grid[0])
+ grid_y = int(height / grid[1])
+ for i in range(0, height, grid_y):
+ images = []
+ for j in range(0, width, grid_x):
+ box = (j, i, j + grid_x, i + grid_y)
+ patch = image.crop(box)
+ images.append(patch)
+ patches.append(images)
+ return patches
+
+ def slice_image(
+ self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
+ ):
+ original_size = image.size
+ source_image = None
+ best_grid = self.get_sliced_grid(original_size, max_slice_nums, never_split)
+ patches = []
+
+ if best_grid is None:
+ # dont need to slice, upsample
+ best_size = self.find_best_resize(
+ original_size, scale_resolution, patch_size, allow_upscale=True
+ )
+ source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC)
+ else:
+ # source image, down-sampling and ensure divided by patch_size
+ best_resize = self.find_best_resize(original_size, scale_resolution, patch_size)
+ source_image = image.copy().resize(best_resize, resample=Image.Resampling.BICUBIC)
+ refine_size = self.get_refine_size(
+ original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
+ )
+ refine_image = image.resize(refine_size, resample=Image.Resampling.BICUBIC)
+ patches = self.split_to_patches(refine_image, best_grid)
+
+ return source_image, patches, best_grid
+
+ def get_grid_placeholder(self, grid):
+ if grid is None:
+ return ""
+ slice_image_placeholder = (
+ self.slice_start_token
+ + self.unk_token * self.image_feature_size
+ + self.slice_end_token
+ )
+
+ cols = grid[0]
+ rows = grid[1]
+ slices = []
+ for i in range(rows):
+ lines = []
+ for j in range(cols):
+ lines.append(slice_image_placeholder)
+ slices.append("".join(lines))
+
+ slice_placeholder = "\n".join(slices)
+ return slice_placeholder
+
+ def get_image_id_placeholder(self, idx=0):
+ return f"{self.im_id_start}{idx}{self.im_id_end}"
+
+ def get_sliced_images(self, image, max_slice_nums=None):
+ slice_images = []
+
+ if not self.slice_mode:
+ return [image]
+
+ max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
+ assert max_slice_nums > 0
+ source_image, patches, sliced_grid = self.slice_image(
+ image,
+ max_slice_nums, # default: 9
+ self.scale_resolution, # default: 448
+ self.patch_size # default: 14
+ )
+
+ slice_images.append(source_image)
+ if len(patches) > 0:
+ for i in range(len(patches)):
+ for j in range(len(patches[0])):
+ slice_images.append(patches[i][j])
+ return slice_images
+
+ def get_sliced_grid(self, image_size, max_slice_nums, nerver_split=False):
+ original_width, original_height = image_size
+ log_ratio = math.log(original_width / original_height)
+ ratio = original_width * original_height / (self.scale_resolution * self.scale_resolution)
+ multiple = min(math.ceil(ratio), max_slice_nums)
+ if multiple <= 1 or nerver_split:
+ return None
+ candidate_split_grids_nums = []
+ for i in [multiple - 1, multiple, multiple + 1]:
+ if i == 1 or i > max_slice_nums:
+ continue
+ candidate_split_grids_nums.append(i)
+
+ candidate_grids = []
+ for split_grids_nums in candidate_split_grids_nums:
+ m = 1
+ while m <= split_grids_nums:
+ if split_grids_nums % m == 0:
+ candidate_grids.append([m, split_grids_nums // m])
+ m += 1
+
+ best_grid = [1, 1]
+ min_error = float("inf")
+ for grid in candidate_grids:
+ error = abs(log_ratio - math.log(grid[0] / grid[1]))
+ if error < min_error:
+ best_grid = grid
+ min_error = error
+
+ return best_grid
+
+ def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=None, use_image_id=None):
+ max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
+ assert max_slice_nums > 0
+ grid = self.get_sliced_grid(image_size=image_size, max_slice_nums=max_slice_nums)
+
+ image_placeholder = (
+ self.im_start_token
+ + self.unk_token * self.image_feature_size
+ + self.im_end_token
+ )
+ use_image_id = self.use_image_id if use_image_id is None else bool(use_image_id)
+ if use_image_id:
+ final_placeholder = self.get_image_id_placeholder(image_idx) + image_placeholder
+ else:
+ final_placeholder = image_placeholder
+
+ if self.slice_mode:
+ final_placeholder = final_placeholder + self.get_grid_placeholder(grid=grid)
+ return final_placeholder
+
+ def to_pil_image(self, image, rescale=None) -> PIL.Image.Image:
+ """
+ Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
+ needed.
+
+ Args:
+ image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
+ The image to convert to the PIL Image format.
+ rescale (`bool`, *optional*):
+ Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
+ default to `True` if the image type is a floating type, `False` otherwise.
+ """
+ if isinstance(image, PIL.Image.Image):
+ return image
+ if is_torch_tensor(image):
+ image = image.numpy()
+
+ if isinstance(image, np.ndarray):
+ if rescale is None:
+ # rescale default to the array being of floating type.
+ rescale = isinstance(image.flat[0], np.floating)
+ # If the channel as been moved to first dim, we put it back at the end.
+ if image.ndim == 3 and image.shape[0] in [1, 3]:
+ image = image.transpose(1, 2, 0)
+ if rescale:
+ image = image * 255
+ image = image.astype(np.uint8)
+ return PIL.Image.fromarray(image)
+ return image
+
+ def reshape_by_patch(self, image):
+ """
+ :param image: shape [3, H, W]
+ :param patch_size:
+ :return: [3, patch_size, HW/patch_size]
+ """
+ image = ms.Tensor(image)
+ patch_size = self.patch_size
+
+ c = image.shape[0]
+ h = image.shape[1]
+ w = image.shape[2]
+ image = image.reshape(1, c, h, w)
+
+ patches = ops.unfold(
+ image,
+ (patch_size, patch_size),
+ stride=(patch_size, patch_size)
+ )
+
+ image = image.squeeze(axis=0)
+
+ patches = patches.reshape(image.shape[0], patch_size, patch_size, -1)
+ patches = patches.permute(0, 1, 3, 2).reshape(image.shape[0], patch_size, -1)
+ return patches.numpy()
+
+ def preprocess(
+ self,
+ images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
+ do_pad: Optional[bool] = True, # TODO: add pad for MiniCPM-Llama3-V-2_5
+ max_slice_nums: int = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs
+ ) -> MiniCPMVBatchFeature:
+ if isinstance(images, Image.Image):
+ images_list = [[images]]
+ elif isinstance(images[0], Image.Image):
+ images_list = [images]
+ else:
+ images_list = images
+
+ new_images_list = []
+ image_sizes_list = []
+ tgt_sizes_list = []
+
+ for _images in images_list:
+ if _images is None or len(_images) == 0:
+ new_images_list.append([])
+ image_sizes_list.append([])
+ tgt_sizes_list.append([])
+ continue
+ if not valid_images(_images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ _images = [self.to_pil_image(image).convert("RGB") for image in _images]
+ input_data_format = infer_channel_dimension_format(np.array(_images[0]))
+
+ new_images = []
+ image_sizes = [image.size for image in _images]
+ tgt_sizes = []
+ for image in _images:
+ image_patches = self.get_sliced_images(image, max_slice_nums)
+ image_patches = [to_numpy_array(image).astype(np.float32) / 255 for image in image_patches]
+ image_patches = [
+ self.normalize(image=image, mean=self.mean, std=self.std, input_data_format=input_data_format)
+ for image in image_patches
+ ]
+ image_patches = [
+ to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
+ for image in image_patches
+ ]
+ for slice_image in image_patches:
+ new_images.append(self.reshape_by_patch(slice_image))
+ tgt_sizes.append(
+ np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size)))
+
+ if tgt_sizes:
+ tgt_sizes = np.vstack(tgt_sizes)
+
+ new_images_list.append(new_images)
+ image_sizes_list.append(image_sizes)
+ tgt_sizes_list.append(tgt_sizes)
+ return MiniCPMVBatchFeature(
+ data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list},
+ tensor_type=return_tensors
+ )
+
+
+# AutoImageProcessor.register("MiniCPMVImageProcessor", MiniCPMVImageProcessor)
diff --git a/mindone/transformers/models/minicpm_v/modeling_minicpmv.py b/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
new file mode 100644
index 0000000000..43d35d0990
--- /dev/null
+++ b/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
@@ -0,0 +1,421 @@
+import json
+import math
+from copy import deepcopy
+from threading import Thread
+from typing import List, Optional
+
+from transformers import TextIteratorStreamer
+from PIL import Image
+
+import mindspore as ms
+from mindspore import Parameter, Tensor, nn, ops
+
+from ..qwen2 import Qwen2ForCausalLM, Qwen2PreTrainedModel
+from .configuration_minicpm import MiniCPMVConfig
+from .modeling_navit_siglip import SiglipVisionTransformer
+from .processing_minicpmv import MiniCPMVProcessor
+from .image_processing_minicpmv import MiniCPMVImageProcessor
+from .resampler import Resampler
+from .tokenization_minicpmv_fast import MiniCPMVTokenizerFast
+
+from mindspore import _no_grad
+
+class MiniCPMVPreTrainedModel(Qwen2PreTrainedModel):
+ config_class = MiniCPMVConfig
+
+
+class MiniCPMV_v2_6(MiniCPMVPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.llm = Qwen2ForCausalLM(config)
+ self.vpm = self.init_vision_module()
+ self.vision_dim = self.vpm.embed_dim
+ self.embed_dim = self.llm.config.hidden_size
+ self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
+ self.processor = None
+
+ self.terminators = ['<|im_end|>', '<|endoftext|>']
+
+ def init_vision_module(self):
+ # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes
+ if self.config._attn_implementation == 'flash_attention_2':
+ self.config.vision_config._attn_implementation = 'flash_attention_2'
+ else:
+ # not suport sdpa
+ self.config.vision_config._attn_implementation = 'eager'
+ model = SiglipVisionTransformer(self.config.vision_config)
+ if self.config.drop_vision_last_layer:
+ model.encoder.layers = model.encoder.layers[:-1]
+
+ setattr(model, 'embed_dim', model.embeddings.embed_dim)
+ setattr(model, 'patch_size', model.embeddings.patch_size)
+
+ return model
+
+ def init_resampler(self, embed_dim, vision_dim):
+ return Resampler(
+ num_queries=self.config.query_num,
+ embed_dim=embed_dim,
+ num_heads=embed_dim // 128,
+ kv_dim=vision_dim,
+ adaptive=True
+ )
+
+ def get_input_embeddings(self):
+ return self.llm.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.llm.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.llm.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.llm.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.llm = decoder
+
+ def get_decoder(self):
+ return self.llm
+
+ def get_vllm_embedding(self, data):
+ if 'vision_hidden_states' not in data:
+ dtype = self.llm.model.embed_tokens.embedding_table.dtype
+ device = None
+ tgt_sizes = data['tgt_sizes']
+ pixel_values_list = data['pixel_values']
+ vision_hidden_states = []
+ all_pixel_values = []
+ img_cnt = []
+ for pixel_values in pixel_values_list:
+ img_cnt.append(len(pixel_values))
+ all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])
+
+ # exist image
+ if all_pixel_values:
+ tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, ms.Tensor)]
+ tgt_sizes = ops.vstack(tgt_sizes).astype(ms.int32)
+
+ max_patches = ops.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])[0].asnumpy()
+
+ # FIXME all_pixel_values
+ # # all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True,
+ # padding_value=0.0)
+
+ max_length_h = max([i.shape[0] for i in all_pixel_values])
+ max_length_w = max([i.shape[1] for i in all_pixel_values])
+ for i in range(len(all_pixel_values)):
+ if all_pixel_values[i].shape[0] < max_length_h or all_pixel_values[i].shape[1] < max_length_w:
+ all_pixel_values[i] = ops.pad(all_pixel_values[i], (0, max_length_w - all_pixel_values[i].shape[1], 0, max_length_h - all_pixel_values[i].shape[0]), value=0.0)
+ all_pixel_values = ops.stack(all_pixel_values)
+
+ B, L, _ = all_pixel_values.shape
+ all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+
+ patch_attn_mask = ops.zeros(Tensor((B, 1, int(max_patches))), dtype=ms.bool_)
+ for i in range(B):
+ patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+
+ vision_batch_size = self.config.vision_batch_size
+ all_pixel_values = all_pixel_values.astype(dtype)
+ if B > vision_batch_size:
+ hs = []
+ for i in range(0, B, vision_batch_size):
+ start_idx = i
+ end_idx = i + vision_batch_size
+ tmp_hs = self.vpm(all_pixel_values[start_idx:end_idx], patch_attention_mask=patch_attn_mask[start_idx:end_idx], tgt_sizes=tgt_sizes[start_idx:end_idx]).last_hidden_state
+ hs.append(tmp_hs)
+ vision_embedding = ops.cat(hs, axis=0)
+ else:
+ vision_embedding = self.vpm(all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes).last_hidden_state
+ vision_embedding = self.resampler(vision_embedding, tgt_sizes)
+
+ start = 0
+ for pixel_values in pixel_values_list:
+ img_cnt = len(pixel_values)
+ if img_cnt > 0:
+ vision_hidden_states.append(vision_embedding[start: start + img_cnt])
+ start += img_cnt
+ else:
+ vision_hidden_states.append([])
+ else: # no image
+ if self.training:
+ dummy_image = ops.zeros(
+ (1, 3, 224, 224),
+ dtype=dtype
+ )
+ tgt_sizes = ms.Tensor([[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]).astype(ms.int32)
+ dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
+ else:
+ dummy_feature = []
+ for _ in range(len(pixel_values_list)):
+ vision_hidden_states.append(dummy_feature)
+
+ else:
+ vision_hidden_states = data['vision_hidden_states']
+
+ if hasattr(self.llm.config, 'scale_emb'):
+ vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
+ else:
+ vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])
+
+ vision_hidden_states = [i.astype(vllm_embedding.dtype) if isinstance(
+ i, ms.Tensor) else i for i in vision_hidden_states]
+
+ # bs = len(data['input_ids'])
+ # for i in range(bs):
+ # cur_vs_hs = vision_hidden_states[i]
+ # if len(cur_vs_hs) > 0:
+ # cur_vllm_emb = vllm_embedding[i]
+ # cur_image_bound = data['image_bound'][i]
+ # if len(cur_image_bound) > 0:
+ # image_indices = ops.stack(
+ # [ops.arange(r[0], r[1], dtype=ms.int64) for r in cur_image_bound]
+ # )
+ #
+ # cur_vllm_emb.scatter(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
+ # cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
+ # elif self.training:
+ # cur_vllm_emb += cur_vs_hs[0].mean() * 0
+
+ return vllm_embedding, vision_hidden_states
+
+ def construct(self, data, **kwargs):
+ vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
+ position_ids = data["position_ids"]
+ if position_ids.dtype != ms.int64:
+ position_ids = position_ids.long()
+
+ with _no_grad():
+ return self.llm(
+ input_ids=None,
+ position_ids=position_ids,
+ inputs_embeds=vllm_embedding,
+ **kwargs
+ )
+
+ def _decode(self, inputs_embeds, tokenizer, attention_mask, decode_text=False, **kwargs):
+ terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
+ output = self.llm.generate(
+ inputs_embeds=inputs_embeds,
+ pad_token_id=0,
+ eos_token_id=terminators,
+ attention_mask=attention_mask,
+ **kwargs
+ )
+ if decode_text:
+ return self._decode_text(output, tokenizer)
+ return output
+
+ def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
+ terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
+ streamer = TextIteratorStreamer(tokenizer=tokenizer)
+ generation_kwargs = {
+ 'inputs_embeds': inputs_embeds,
+ 'pad_token_id': 0,
+ 'eos_token_id': terminators,
+ 'streamer': streamer
+ }
+ generation_kwargs.update(kwargs)
+
+ thread = Thread(target=self.llm.generate, kwargs=generation_kwargs)
+ thread.start()
+
+ return streamer
+
+ def _decode_text(self, result_ids, tokenizer):
+ terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
+ result_text = []
+ for result in result_ids:
+ result = result[result != 0]
+ if result[0] == tokenizer.bos_id:
+ result = result[1:]
+ if result[-1] in terminators:
+ result = result[:-1]
+ result_text.append(tokenizer.decode(result).strip())
+ return result_text
+
+ def generate(
+ self,
+ input_ids=None,
+ pixel_values=None,
+ tgt_sizes=None,
+ image_bound=None,
+ attention_mask=None,
+ tokenizer=None,
+ vision_hidden_states=None,
+ return_vision_hidden_states=False,
+ stream=False,
+ decode_text=False,
+ **kwargs
+ ):
+ assert input_ids is not None
+ assert len(input_ids) == len(pixel_values)
+
+ model_inputs = {
+ "input_ids": input_ids,
+ "image_bound": image_bound,
+ }
+
+ if vision_hidden_states is None:
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs['tgt_sizes'] = tgt_sizes
+ else:
+ model_inputs["vision_hidden_states"] = vision_hidden_states
+
+ with ms._no_grad():
+ (
+ model_inputs["inputs_embeds"],
+ vision_hidden_states,
+ ) = self.get_vllm_embedding(model_inputs)
+
+ if stream:
+ result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
+ else:
+ result = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, decode_text=decode_text, **kwargs)
+
+ if return_vision_hidden_states:
+ return result, vision_hidden_states
+
+ return result
+
+ def chat(
+ self,
+ image,
+ msgs,
+ tokenizer,
+ processor=None,
+ vision_hidden_states=None,
+ max_new_tokens=2048,
+ min_new_tokens=0,
+ sampling=True,
+ max_inp_length=8192,
+ system_prompt='',
+ stream=False,
+ max_slice_nums=None,
+ use_image_id=None,
+ **kwargs
+ ):
+ if isinstance(msgs[0], list):
+ batched = True
+ else:
+ batched = False
+ msgs_list = msgs
+ images_list = image
+
+ if batched is False:
+ images_list, msgs_list = [images_list], [msgs_list]
+ else:
+ assert images_list is None, "Please integrate image to msgs when using batch inference."
+ images_list = [None] * len(msgs_list)
+ assert len(images_list) == len(msgs_list), "The batch dim of images_list and msgs_list should be the same."
+
+ image_processor = MiniCPMVImageProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
+
+ if processor is None:
+ if self.processor is None:
+ self.processor = MiniCPMVProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
+ processor = self.processor
+
+ assert self.config.query_num == processor.image_processor.image_feature_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert self.config.patch_size == processor.image_processor.patch_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert self.config.use_image_id == processor.image_processor.use_image_id, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert self.config.slice_mode == processor.image_processor.slice_mode, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+
+ prompts_lists = []
+ input_images_lists = []
+ for image, msgs in zip(images_list, msgs_list):
+ if isinstance(msgs, str):
+ msgs = json.loads(msgs)
+ copy_msgs = deepcopy(msgs)
+
+ assert len(msgs) > 0, "msgs is empty"
+ assert sampling or not stream, "if use stream mode, make sure sampling=True"
+
+ if image is not None and isinstance(copy_msgs[0]["content"], str):
+ copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]
+
+ images = []
+ for i, msg in enumerate(copy_msgs):
+ role = msg["role"]
+ content = msg["content"]
+ assert role in ["user", "assistant"]
+ if i == 0:
+ assert role == "user", "The role of first msg should be user"
+ if isinstance(content, str):
+ content = [content]
+ cur_msgs = []
+ for c in content:
+ if isinstance(c, Image.Image):
+ images.append(c)
+ cur_msgs.append("(./)")
+ elif isinstance(c, str):
+ cur_msgs.append(c)
+ msg["content"] = "\n".join(cur_msgs)
+
+ if system_prompt:
+ sys_msg = {'role': 'system', 'content': system_prompt}
+ copy_msgs = [sys_msg] + copy_msgs
+
+ prompts_lists.append(processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True))
+ input_images_lists.append(images)
+
+ inputs = processor(
+ prompts_lists,
+ input_images_lists,
+ max_slice_nums=max_slice_nums,
+ use_image_id=use_image_id,
+ return_tensors="ms",
+ max_length=max_inp_length,
+ image_processor=image_processor
+ )
+
+ if sampling:
+ generation_config = {
+ "top_p": 0.8,
+ "top_k": 100,
+ "temperature": 0.7,
+ "do_sample": True,
+ "repetition_penalty": 1.05
+ }
+ else:
+ generation_config = {
+ "num_beams": 3,
+ "repetition_penalty": 1.2,
+ }
+
+ if min_new_tokens > 0:
+ generation_config['min_new_tokens'] = min_new_tokens
+
+ generation_config.update(
+ (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
+ )
+
+ inputs.pop("image_sizes")
+ # with torch.inference_mode():
+ res = self.generate(
+ **inputs,
+ tokenizer=tokenizer,
+ max_new_tokens=max_new_tokens,
+ vision_hidden_states=vision_hidden_states,
+ stream=stream,
+ decode_text=True,
+ **generation_config
+ )
+
+ if stream:
+ def stream_gen():
+ for text in res:
+ for term in self.terminators:
+ text = text.replace(term, '')
+ yield text
+ return stream_gen()
+
+ else:
+ if batched:
+ answer = res
+ else:
+ answer = res[0]
+ return answer
diff --git a/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py b/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
new file mode 100644
index 0000000000..9b946ad76d
--- /dev/null
+++ b/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
@@ -0,0 +1,1072 @@
+# coding=utf-8
+# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Siglip model. """
+# Copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes
+
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple, Union
+
+import numpy as np
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+import mindspore as ms
+
+# import torch.utils.checkpoint
+from mindspore import Parameter, Tensor, nn, ops
+from mindspore.ops.operations.nn_ops import FlashAttentionScore as FlashAttention
+
+from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
+from ...modeling_utils import MSPreTrainedModel
+
+# from torch.nn.init import _calculate_fan_in_and_fan_out
+
+
+logger = logging.get_logger(__name__)
+
+class SiglipVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
+ Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input images.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ Example:
+ ```python
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
+ >>> configuration = SiglipVisionConfig()
+ >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
+ >>> model = SiglipVisionModel(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "siglip_vision_model"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ intermediate_size=3072,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ num_channels=3,
+ image_size=224,
+ patch_size=16,
+ hidden_act="gelu_pytorch_tanh",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+ cls._set_token_in_kwargs(kwargs)
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # get the vision config dict if we are loading from SiglipConfig
+ if config_dict.get("model_type") == "siglip":
+ config_dict = config_dict["vision_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
+
+SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "google/siglip-base-patch16-224",
+ # See all SigLIP models at https://huggingface.co/models?filter=siglip
+]
+
+# if is_flash_attn_2_available():
+# from flash_attn import flash_attn_func, flash_attn_varlen_func
+# from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=ms.int32)
+ indices = ops.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = ops.pad(ops.cumsum(seqlens_in_batch, axis=0, dtype=ms.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+def _trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2,
+ )
+
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ if tensor.dtype in [ms.float16, ms.bfloat16]:
+ # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
+ og_dtype = tensor.dtype
+ tensor = tensor.to(ms.float32)
+ tensor.erfinv_()
+ tensor = tensor.to(og_dtype)
+ else:
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.0))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ if tensor.dtype == ms.float16:
+ # The `clamp_` op is not (yet?) defined in float16+cpu
+ tensor = tensor.to(ms.float32)
+ tensor.clamp_(min=a, max=b)
+ tensor = tensor.to(ms.float16)
+ else:
+ tensor.clamp_(min=a, max=b)
+
+
+def trunc_normal_tf_(
+ tensor: ms.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
+) -> ms.Tensor:
+ """Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \\leq \text{mean} \\leq b`.
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
+ and the result is subsquently scaled and shifted by the mean and std args.
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ """
+ # FIXME no grad
+ # with ms.no_grad():
+ _trunc_normal_(tensor, 0, 1.0, a, b)
+ tensor.mul_(std).add_(mean)
+
+def _calculate_fan_in_and_fan_out(arr):
+ # 计算fan_in和fan_out。fan_in是 `arr` 中输入单元的数量,fan_out是 `arr` 中输出单元的数量。
+ shape = arr.shape
+ dimensions = len(shape)
+ if dimensions < 2:
+ raise ValueError("'fan_in' and 'fan_out' can not be computed for arr with fewer than"
+ " 2 dimensions, but got dimensions {}.".format(dimensions))
+ if dimensions == 2: # Linear
+ fan_in = shape[1]
+ fan_out = shape[0]
+ else:
+ num_input_fmaps = shape[1]
+ num_output_fmaps = shape[0]
+ receptive_field_size = 1
+ for i in range(2, dimensions):
+ receptive_field_size *= shape[i]
+ fan_in = num_input_fmaps * receptive_field_size
+ fan_out = num_output_fmaps * receptive_field_size
+ return fan_in, fan_out
+
+def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+ if mode == "fan_in":
+ denom = fan_in
+ elif mode == "fan_out":
+ denom = fan_out
+ elif mode == "fan_avg":
+ denom = (fan_in + fan_out) / 2
+
+ variance = scale / denom
+
+ if distribution == "truncated_normal":
+ # constant is stddev of standard normal truncated to (-2, 2)
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+ elif distribution == "normal":
+ # FIXME no grad
+ # with torch.no_grad():
+ tensor.normal_(std=math.sqrt(variance))
+ elif distribution == "uniform":
+ bound = math.sqrt(3 * variance)
+ # FIXME no grad
+ # with torch.no_grad():
+ tensor.uniform_(-bound, bound)
+ else:
+ raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+def default_flax_embed_init(tensor):
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
+
+
+@dataclass
+# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
+class SiglipVisionModelOutput(ModelOutput):
+ """
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
+ Args:
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
+ The image embeddings obtained by applying the projection layer to the pooler_output.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ image_embeds: Optional[ms.Tensor] = None
+ last_hidden_state: ms.Tensor = None
+ hidden_states: Optional[Tuple[ms.Tensor]] = None
+ attentions: Optional[Tuple[ms.Tensor]] = None
+
+
+class SiglipVisionEmbeddings(nn.Cell):
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ pad_mode="valid",
+ has_bias=True,
+ )
+
+ self.num_patches_per_side = self.image_size // self.patch_size
+ self.num_patches = self.num_patches_per_side**2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+ def construct(self, pixel_values: ms.Tensor, patch_attention_mask: ms.Tensor, tgt_sizes: Optional[ms.Tensor]=None) -> ms.Tensor:
+ batch_size = pixel_values.shape[0]
+
+ patch_embeds = self.patch_embedding(pixel_values)
+ embeddings = patch_embeds.flatten(start_dim=2).swapaxes(1, 2)
+
+ max_im_h, max_im_w = pixel_values.shape[2], pixel_values.shape[3]
+ max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
+ boundaries = ops.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
+ position_ids = ops.full(
+ size=(
+ batch_size,
+ max_nb_patches_h * max_nb_patches_w,
+ ),
+ fill_value=0,
+ )
+
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+ if tgt_sizes is not None:
+ nb_patches_h = tgt_sizes[batch_idx][0]
+ nb_patches_w = tgt_sizes[batch_idx][1]
+ else:
+ nb_patches_h = p_attn_mask[:, 0].sum()
+ nb_patches_w = p_attn_mask[0].sum()
+
+ fractional_coords_h = ops.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+ fractional_coords_w = ops.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+ bucket_coords_h = ops.bucketize(fractional_coords_h, boundaries.tolist(), right=True)
+ bucket_coords_w = ops.bucketize(fractional_coords_w, boundaries.tolist(), right=True)
+
+ pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
+ position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
+
+ position_ids = position_ids
+
+ embeddings = embeddings + self.position_embedding(position_ids)
+ return embeddings
+
+
+class SiglipAttention(nn.Cell):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim ** -0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = nn.Dense(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Dense(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Dense(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Dense(self.embed_dim, self.embed_dim)
+
+ def construct(
+ self,
+ hidden_states: ms.Tensor,
+ attention_mask: Optional[ms.Tensor] = None,
+ output_attentions: Optional[ms.Tensor] = False,
+ ) -> Tuple[ms.Tensor, Optional[ms.Tensor], Optional[Tuple[ms.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, q_len, _ = hidden_states.shape
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).swapaxes(1, 2)
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).swapaxes(1, 2)
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).swapaxes(1, 2)
+
+ k_v_seq_len = key_states.shape[-2]
+
+ query_states = ops.mul(query_states, self.scale ** 0.5)
+ key_states = ops.mul(key_states, self.scale ** 0.5)
+
+ attn_weights = ops.matmul(query_states, key_states.swapaxes(2, 3))
+
+ if attn_weights.shape != (batch_size, self.num_heads, q_len, k_v_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
+ f" {attn_weights.shape}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.shape != (batch_size, 1, q_len, k_v_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.shape}"
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = ops.softmax(attn_weights, axis=-1, dtype=ms.float32).to(query_states.dtype)
+ attn_weights = ops.dropout(attn_weights, p=self.dropout, training=self.training)
+ attn_output = ops.matmul(attn_weights, value_states)
+
+ if attn_output.shape != (batch_size, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.shape}"
+ )
+
+ attn_output = attn_output.swapaxes(1, 2)
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+# FIXME TBD
+class SiglipFlashAttention2(SiglipAttention):
+ """
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.is_causal = False # Hack to make sure we don't use a causal mask
+
+ def construct(
+ self,
+ hidden_states: ms.Tensor,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_value: Optional[Tuple[ms.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[ms.Tensor, Optional[ms.Tensor], Optional[Tuple[ms.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.shape
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).swapaxes(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).swapaxes(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).swapaxes(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ # if past_key_value is not None:
+ # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.swapaxes(1, 2)
+ key_states = key_states.swapaxes(1, 2)
+ value_states = value_states.swapaxes(1, 2)
+
+ dropout_rate = self.dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == ms.float32:
+ if ms.is_autocast_enabled():
+ target_dtype = ms.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ "The input hidden states seems to be silently casted in float32, this might be related to the fact"
+ " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = self._flash_attention_forward(
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+ def _flash_attention_forward(
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+ ):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`int`, *optional*):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ """
+
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ attn_output = flash_attn_func(
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+ )
+
+ return attn_output
+
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = ops.arange(
+ batch_size + 1, dtype=ms.int32
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+class SiglipFlashAttention(SiglipAttention):
+ """
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.is_causal = False # Hack to make sure we don't use a causal mask
+
+ dropout_rate = self.dropout if self.training else 0.0
+ self.flash_attention = FlashAttention(
+ scale_value=self.head_dim**-0.5,
+ head_num=self.head_dim,
+ input_layout="BSH",
+ keep_prob=1-dropout_rate
+ )
+
+ def construct(
+ self,
+ hidden_states: ms.Tensor,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_value: Optional[Tuple[ms.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[ms.Tensor, Optional[ms.Tensor], Optional[Tuple[ms.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.shape
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ # query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).swapaxes(1, 2)
+ # key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).swapaxes(1, 2)
+ # value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).swapaxes(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ # if past_key_value is not None:
+ # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ # query_states = query_states.swapaxes(1, 2)
+ # key_states = key_states.swapaxes(1, 2)
+ # value_states = value_states.swapaxes(1, 2)
+
+ # dropout_rate = self.dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == ms.float32:
+ # if ms.is_autocast_enabled():
+ # target_dtype = ms.get_autocast_gpu_dtype()
+ # # Handle the case where the model is quantized
+ # elif hasattr(self.config, "_pre_quantization_dtype"):
+ # target_dtype = self.config._pre_quantization_dtype
+ # else:
+ # target_dtype = self.q_proj.weight.dtype
+ target_dtype = ms.float16
+
+ logger.warning_once(
+ "The input hidden states seems to be silently casted in float32, this might be related to the fact"
+ " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # implement flash attention
+ attn_output = self.flash_attention(
+ query_states, key_states, value_states, None, None, None, attention_mask
+ )[3]
+
+ # attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
+class SiglipMLP(nn.Cell):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Dense(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Dense(config.intermediate_size, config.hidden_size)
+
+ def construct(self, hidden_states: ms.Tensor) -> ms.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
+class SiglipEncoderLayer(nn.Cell):
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention = config._attn_implementation == "flash_attention"
+ self.self_attn = (
+ SiglipAttention(config)
+ if not self._use_flash_attention
+ else SiglipFlashAttention(config)
+ )
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
+ self.mlp = SiglipMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
+
+ # add recompute
+ # self.self_attn.recompute()
+ # self.layer_norm1.recompute()
+ # self.mlp.recompute()
+ # self.layer_norm2.recompute()
+
+ def construct(
+ self,
+ hidden_states: ms.Tensor,
+ attention_mask: ms.Tensor,
+ output_attentions: Optional[ms.Tensor] = False,
+ ) -> Tuple[ms.Tensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
+ attention_mask (`torch.FloatTensor`):
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class SiglipPreTrainedModel(MSPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = SiglipVisionConfig
+ base_model_prefix = "siglip"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+
+ # if isinstance(module, SiglipVisionEmbeddings):
+ # width = self.config.hidden_size
+ # nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
+ # elif isinstance(module, nn.Embedding):
+ # default_flax_embed_init(module.weight)
+ # elif isinstance(module, SiglipAttention):
+ # nn.init.normal_(module.q_proj.weight)
+ # nn.init.normal_(module.k_proj.weight)
+ # nn.init.normal_(module.v_proj.weight)
+ # nn.init.normal_(module.out_proj.weight)
+ # nn.init.zeros_(module.q_proj.bias)
+ # nn.init.zeros_(module.k_proj.bias)
+ # nn.init.zeros_(module.v_proj.bias)
+ # nn.init.zeros_(module.out_proj.bias)
+ # elif isinstance(module, SiglipMLP):
+ # nn.init.normal_(module.fc1.weight)
+ # nn.init.normal_(module.fc2.weight)
+ # nn.init.normal_(module.fc1.bias, std=1e-6)
+ # nn.init.normal_(module.fc2.bias, std=1e-6)
+ # elif isinstance(module, (nn.Dense, nn.Conv2d)):
+ # lecun_normal_(module.weight)
+ # if module.bias is not None:
+ # nn.init.zeros_(module.bias)
+ # elif isinstance(module, nn.LayerNorm):
+ # module.bias.data.zero_()
+ # module.weight.data.fill_(1.0)
+ pass
+
+
+SIGLIP_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+ This model is also a PyTorch [torch.nn.Cell](https://pytorch.org/docs/stable/nn.html#torch.nn.Cell) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+ Parameters:
+ config ([`SiglipVisionConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+SIGLIP_VISION_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
+class SiglipEncoder(nn.Cell):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`SiglipEncoderLayer`].
+ Args:
+ config: SiglipConfig
+ """
+
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.CellList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ # recompute
+ for layer in self.layers:
+ layer.recompute()
+
+ # Ignore copy
+ def construct(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[ms.Tensor] = None,
+ output_attentions: Optional[ms.Tensor] = None,
+ output_hidden_states: Optional[ms.Tensor] = None,
+ return_dict: Optional[ms.Tensor] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ encoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ output_attentions,
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+class SiglipVisionTransformer(SiglipPreTrainedModel):
+ config_class = SiglipVisionConfig
+ main_input_name = "pixel_values"
+ _supports_flash_attn_2 = True
+
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__(config)
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = SiglipVisionEmbeddings(config)
+ self.encoder = SiglipEncoder(config)
+ self.post_layernorm = nn.LayerNorm(embed_dim, epsilon=config.layer_norm_eps)
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention = config._attn_implementation == "flash_attention"
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # recompute
+ # self.encoder.recompute()
+
+
+ def get_input_embeddings(self) -> nn.Cell:
+ return self.embeddings.patch_embedding
+
+
+ def construct(
+ self,
+ pixel_values,
+ patch_attention_mask: Optional[ms.Tensor] = None,
+ tgt_sizes: Optional[ms.Tensor] = None,
+ output_attentions: Optional[ms.Tensor] = None,
+ output_hidden_states: Optional[ms.Tensor] = None,
+ return_dict: Optional[ms.Tensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = pixel_values.shape[0]
+ if patch_attention_mask is None:
+ patch_attention_mask = ops.ones(
+ (
+ batch_size,
+ pixel_values.shape[2] // self.config.patch_size,
+ pixel_values.shape[3] // self.config.patch_size,
+ ),
+ dtype=ms.bool_,
+ )
+
+ hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes)
+
+ patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+ # The call to `_upad_input` in `_flash_attention_forward` is expensive
+ # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
+ # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
+ if not ops.any(~patch_attention_mask):
+ attention_mask=None
+ else:
+ attention_mask = (
+ _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
+ if not self._use_flash_attention_2
+ else patch_attention_mask
+ )
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ if not return_dict:
+ return (last_hidden_state, None) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=None,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
diff --git a/mindone/transformers/models/minicpm_v/processing_minicpmv.py b/mindone/transformers/models/minicpm_v/processing_minicpmv.py
new file mode 100644
index 0000000000..36438fd5f6
--- /dev/null
+++ b/mindone/transformers/models/minicpm_v/processing_minicpmv.py
@@ -0,0 +1,254 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for MiniCPMV.
+"""
+
+import re
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+from transformers.utils import TensorType
+from transformers.image_utils import ImageInput
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+
+import mindspore as ms
+from mindspore import Parameter, Tensor, nn, ops
+
+from ...processing_utils import ProcessorMixin
+from .image_processing_minicpmv import MiniCPMVBatchFeature, MiniCPMVImageProcessor
+
+class MiniCPMVProcessor(ProcessorMixin):
+ r"""
+ Constructs a MiniCPMV processor which wraps a MiniCPMV image processor and a MiniCPMV tokenizer into a single processor.
+
+ [`MiniCPMVProcessor`] offers all the functionalities of [`MiniCPMVImageProcessor`] and [`LlamaTokenizerWrapper`]. See the
+ [`~MiniCPMVProcessor.__call__`] and [`~MiniCPMVProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`MiniCPMVImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`LlamaTokenizerWrapper`], *optional*):
+ The tokenizer is a required input.
+ """
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "MiniCPMVImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(self, image_processor=None, tokenizer=None):
+ super().__init__(image_processor, tokenizer)
+ self.version = image_processor.version
+
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
+ images: ImageInput = None,
+ max_length: Optional[int] = None,
+ do_pad: Optional[bool] = True,
+ max_slice_nums: int = 9,
+ use_image_id: bool = None,
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+ image_processor=None,
+ **kwargs
+ ) -> MiniCPMVBatchFeature:
+
+ if images is not None:
+ image_inputs = image_processor.preprocess(images=images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors)
+ return self._convert_images_texts_to_inputs(image_inputs, text, max_slice_nums=max_slice_nums, use_image_id=use_image_id, max_length=max_length, image_processor=image_processor, **kwargs)
+
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ output_ids = args[0]
+ result_text = []
+ for result in output_ids:
+ result = result[result != 0]
+ if result[0] == self.tokenizer.bos_id:
+ result = result[1:]
+ if result[-1] == self.tokenizer.eos_id:
+ result = result[:-1]
+ result_text.append(self.tokenizer.decode(result, *args[1:], **kwargs).strip())
+ return result_text
+ # return self.tokenizer.batch_decode(*args, **kwargs)
+
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ result = args[0]
+ result = result[result != 0]
+ if result[0] == self.tokenizer.bos_id:
+ result = result[1:]
+ if result[-1] == self.tokenizer.eos_id or (hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id):
+ result = result[:-1]
+ return self.tokenizer.decode(result, *args[1:], **kwargs).strip()
+
+ def _convert(
+ self, input_str, max_inp_length: Optional[int] = None
+ ):
+ if self.version > 2.5 or not getattr(self.tokenizer, "add_bos_token", False):
+ input_ids = self.tokenizer.encode(input_str)
+ else:
+ input_ids = [self.tokenizer.bos_id] + self.tokenizer.encode(input_str)
+ if max_inp_length is not None:
+ input_ids = input_ids[:max_inp_length]
+ input_ids = ms.Tensor(input_ids, dtype=ms.int32)
+
+ # FIXME ops.where issue
+ start_cond = (input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id)
+ end_cond = (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id)
+ # image_start_tokens = ops.where(start_cond)[0]
+ # image_start_tokens += 1
+ # image_end_tokens = ops.where(end_cond)[0]
+
+ image_start_tokens = []
+ for i in range(len(input_ids)):
+ if input_ids[i] == self.tokenizer.im_start_id or input_ids[i] == self.tokenizer.slice_start_id:
+ image_start_tokens.append(i + 1)
+ image_start_tokens = Tensor(np.array(image_start_tokens))
+ image_end_tokens = []
+ for i in range(len(input_ids)):
+ if input_ids[i] == self.tokenizer.im_end_id or input_ids[i] == self.tokenizer.slice_end_id:
+ image_end_tokens.append(i)
+ image_end_tokens = Tensor(np.array(image_end_tokens))
+
+ valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
+
+ image_bounds = ops.hstack(
+ [
+ image_start_tokens[:valid_image_nums].unsqueeze(-1),
+ image_end_tokens[:valid_image_nums].unsqueeze(-1),
+ ]
+ )
+ return input_ids, image_bounds
+
+ def _convert_images_texts_to_inputs(
+ self,
+ images,
+ texts: Union[str, List[str]],
+ truncation=None,
+ max_length=None,
+ max_slice_nums=None,
+ use_image_id=None,
+ return_tensors=None,
+ image_processor=None,
+ **kwargs
+ ):
+ if images is None or not len(images):
+ model_inputs = self.tokenizer(texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs)
+ return MiniCPMVBatchFeature(data={**model_inputs})
+
+ pattern = "(./)"
+ images, image_sizes, tgt_sizes = images["pixel_values"], images["image_sizes"], images["tgt_sizes"]
+
+ if isinstance(texts, str):
+ texts = [texts]
+ input_ids_list = []
+ image_bounds_list = []
+ for index, text in enumerate(texts):
+ image_tags = re.findall(pattern, text)
+ assert len(image_tags) == len(image_sizes[index])
+ text_chunks = text.split(pattern)
+ final_text = ""
+ for i in range(len(image_tags)):
+ final_text = final_text + text_chunks[i] + \
+ image_processor.get_slice_image_placeholder(
+ image_sizes[index][i],
+ i,
+ max_slice_nums,
+ use_image_id
+ )
+ final_text += text_chunks[-1]
+ input_ids, image_bounds = self._convert(final_text, max_length)
+ input_ids_list.append(input_ids)
+ image_bounds_list.append(image_bounds)
+ padded_input_ids, padding_lengths = self.pad(
+ input_ids_list,
+ padding_side="left"
+ )
+ for i, length in enumerate(padding_lengths):
+ image_bounds_list[i] = image_bounds_list[i] + length
+ attention_mask = padded_input_ids.ne(0)
+
+ return MiniCPMVBatchFeature(data={
+ "input_ids": padded_input_ids,
+ "attention_mask": attention_mask,
+ "pixel_values": images,
+ "image_sizes": image_sizes,
+ "image_bound": image_bounds_list,
+ "tgt_sizes": tgt_sizes
+ })
+
+ @property
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = MiniCPMVImageProcessor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
+ def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
+ items = []
+ if isinstance(inputs[0], list):
+ assert isinstance(inputs[0][0], ms.Tensor)
+ for it in inputs:
+ for tr in it:
+ items.append(tr)
+ else:
+ assert isinstance(inputs[0], ms.Tensor)
+ items = inputs
+
+ batch_size = len(items)
+ shape = items[0].shape
+ dim = len(shape)
+ assert dim <= 2
+ if max_length is None:
+ max_length = 0
+ max_length = max(max_length, max(item.shape[-1] for item in items))
+ min_length = min(item.shape[-1] for item in items)
+ dtype = items[0].dtype
+
+ if dim == 0:
+ return ops.stack([item for item in items], axis=0), [0]
+ elif dim == 1:
+ if max_length == min_length:
+ return ops.stack([item for item in items], axis=0), [0] * batch_size
+ tensor = ops.zeros((batch_size, max_length), dtype=dtype) + padding_value
+ else:
+ tensor = (
+ ops.zeros((batch_size, max_length, shape[-1]), dtype=dtype)
+ + padding_value
+ )
+
+ padding_length = []
+ for i, item in enumerate(items):
+ if dim == 1:
+ if padding_side == "left":
+ tensor[i, -len(item) :] = item.clone()
+ else:
+ tensor[i, : len(item)] = item.clone()
+ elif dim == 2:
+ if padding_side == "left":
+ tensor[i, -len(item) :, :] = item.clone()
+ else:
+ tensor[i, : len(item), :] = item.clone()
+ padding_length.append(tensor.shape[-1] - len(item))
+
+ return tensor, padding_length
diff --git a/mindone/transformers/models/minicpm_v/resampler.py b/mindone/transformers/models/minicpm_v/resampler.py
new file mode 100644
index 0000000000..0cea2d7e6e
--- /dev/null
+++ b/mindone/transformers/models/minicpm_v/resampler.py
@@ -0,0 +1,834 @@
+import math
+import warnings
+from functools import partial
+from typing import List, Optional, Tuple
+
+import numpy as np
+
+import mindspore as ms
+import mindspore.mint.nn.functional as F
+from mindspore import Parameter, Tensor, nn, ops
+from mindspore.common.initializer import One
+from mindspore.common.initializer import TruncatedNormal as trunc_normal_
+from mindspore.common.initializer import XavierNormal as xavier_normal_
+from mindspore.common.initializer import XavierUniform as xavier_uniform_
+from mindspore.common.initializer import Zero, initializer
+from mindspore.mint.nn.functional import *
+from mindspore.nn.layer.activation import *
+
+
+def get_2d_sincos_pos_embed(embed_dim, image_size):
+ """
+ image_size: image_size or (image_height, image_width)
+ return:
+ pos_embed: [image_height, image_width, embed_dim]
+ """
+ if isinstance(image_size, int):
+ grid_h_size, grid_w_size = image_size, image_size
+ else:
+ grid_h_size, grid_w_size = image_size[0], image_size[1]
+
+ grid_h = np.arange(grid_h_size, dtype=np.float32)
+ grid_w = np.arange(grid_w_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim // 2, grid[0]) # (H, W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim // 2, grid[1]) # (H, W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid_new(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (H, W)
+ out: (H, W, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000 ** omega # (D/2,)
+
+ out = np.einsum('hw,d->hwd', pos, omega) # (H, W, D/2), outer product
+
+ emb_sin = np.sin(out) # (H, W, D/2)
+ emb_cos = np.cos(out) # (H, W, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
+ return emb
+
+
+class Resampler(nn.Cell):
+ """
+ A 2D perceiver-resampler network with one cross attention layers by
+ given learnable queries and 2d sincos pos_emb
+ Outputs:
+ A tensor with the shape of (batch_size, num_queries, embed_dim)
+ """
+
+ def __init__(
+ self,
+ num_queries,
+ embed_dim,
+ num_heads,
+ kv_dim=None,
+ norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
+ adaptive=False,
+ max_size=(70, 70),
+ ):
+ super().__init__()
+ self.num_queries = num_queries
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.adaptive = adaptive
+ self.max_size = max_size
+
+ self.query = Parameter(ops.zeros((self.num_queries, embed_dim)))
+
+ if kv_dim is not None and kv_dim != embed_dim:
+ self.kv_proj = nn.Dense(kv_dim, embed_dim, has_bias=False)
+ else:
+ self.kv_proj = nn.Identity()
+
+ self.attn = MultiheadAttention(embed_dim, num_heads)
+ self.ln_q = norm_layer((embed_dim,))
+ self.ln_kv = norm_layer((embed_dim,))
+
+ self.ln_post = norm_layer((embed_dim,))
+ self.proj = Parameter((embed_dim ** -0.5) * ops.randn(embed_dim, embed_dim))
+
+ self._set_2d_pos_cache(self.max_size)
+
+ def _set_2d_pos_cache(self, max_size):
+ # if is_deepspeed_zero3_enabled():
+ # device='cuda'
+ pos_embed = ms.Tensor(get_2d_sincos_pos_embed(self.embed_dim, max_size)).float()
+ # self.register_buffer("pos_embed", pos_embed, persistent=False)
+ self.pos_embed = pos_embed
+
+ def _adjust_pos_cache(self, tgt_sizes):
+ max_h = ops.max(tgt_sizes[:, 0])[0]
+ max_w = ops.max(tgt_sizes[:, 1])[0]
+ if max_h > self.max_size[0] or max_w > self.max_size[1]:
+ self.max_size = [max(max_h, self.max_size[0]), max(max_w, self.max_size[1])]
+ self._set_2d_pos_cache(self.max_size)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Dense):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Dense) and m.bias is not None:
+ Zero(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ Zero(m.bias, 0)
+ One(m.weight, 1.0)
+
+ def construct(self, x, tgt_sizes=None):
+ assert x.shape[0] == tgt_sizes.shape[0]
+ bs = x.shape[0]
+
+ dtype = x.dtype
+
+ patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
+
+ self._adjust_pos_cache(tgt_sizes)
+
+ max_patch_len = ops.max(patch_len)[0]
+ key_padding_mask = ops.zeros(Tensor((bs, int(max_patch_len.asnumpy()))), dtype=ms.bool_)
+
+ pos_embed = []
+ for i in range(bs):
+ tgt_h, tgt_w = tgt_sizes[i]
+ shape_0 = tgt_h * tgt_w
+ pos_embed.append(
+ self.pos_embed[:tgt_h, :tgt_w, :].reshape((int(shape_0.asnumpy()), -1)).to(dtype)) # patches * D
+ key_padding_mask[i, patch_len[i]:] = True
+
+ # FIXME how to replace torch.nn.utils.rnn.pad_sequence
+ # pos_embed = torch.nn.utils.rnn.pad_sequence(
+ # pos_embed, batch_first=True, padding_value=0.0).permute(1, 0, 2) # BLD => L * B * D
+ max_length_h = max([i.shape[0] for i in pos_embed])
+ max_length_w = max([i.shape[1] for i in pos_embed])
+ for i in range(len(pos_embed)):
+ if pos_embed[i].shape[0] < max_length_h or pos_embed[i].shape[1] < max_length_w:
+ pos_embed[i] = ops.pad(pos_embed[i], (
+ 0, max_length_w - pos_embed[i].shape[1], 0, max_length_h - pos_embed[i].shape[0]),
+ value=0.0)
+ pos_embed = ops.stack(pos_embed)
+ pos_embed = pos_embed.permute(1, 0, 2)
+
+ x = self.kv_proj(x) # B * L * D
+ x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
+
+ q = self.ln_q(self.query) # Q * D
+
+ out = self.attn(
+ self._repeat(q, bs), # Q * B * D
+ x + pos_embed, # L * B * D + L * B * D
+ x,
+ key_padding_mask=key_padding_mask)[0]
+ # out: Q * B * D
+ x = out.permute(1, 0, 2) # B * Q * D
+
+ x = self.ln_post(x)
+ x = x @ self.proj.astype(x.dtype)
+ return x
+
+ def _repeat(self, query, N: int):
+ return query.unsqueeze(1).repeat(1, N, 1)
+
+
+class MultiheadAttention(nn.MultiheadAttention):
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False,
+ add_zero_attn=False, kdim=None, vdim=None, batch_first=False, dtype=None):
+ super().__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first,
+ dtype)
+
+ # rewrite out_proj layer,with nn.Linear
+ self.out_proj = nn.Dense(embed_dim, embed_dim, has_bias=bias)
+
+ def construct(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ average_attn_weights: bool = True,
+ is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
+ why_not_fast_path = ''
+ # if ((attn_mask is not None and torch.is_floating_point(attn_mask))
+ # or (key_padding_mask is not None) and torch.is_floating_point(key_padding_mask)):
+ # why_not_fast_path = "floating-point masks are not supported for fast path."
+
+ is_batched = query.dim() == 3
+
+ key_padding_mask = _canonical_mask(
+ mask=key_padding_mask,
+ mask_name="key_padding_mask",
+ other_type=_none_or_dtype(attn_mask),
+ other_name="attn_mask",
+ target_type=query.dtype
+ )
+
+ attn_mask = _canonical_mask(
+ mask=attn_mask,
+ mask_name="attn_mask",
+ other_type=None,
+ other_name="",
+ target_type=query.dtype,
+ check_other=False,
+ )
+
+ if not is_batched:
+ why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
+ elif query is not key or key is not value:
+ # When lifting this restriction, don't forget to either
+ # enforce that the dtypes all match or test cases where
+ # they don't!
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
+ elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
+ elif self.in_proj_weight is None:
+ why_not_fast_path = "in_proj_weight was None"
+ elif query.dtype != self.in_proj_weight.dtype:
+ # this case will fail anyway, but at least they'll get a useful error message.
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
+ elif self.training:
+ why_not_fast_path = "training is enabled"
+ elif (self.num_heads % 2) != 0:
+ why_not_fast_path = "self.num_heads is not even"
+ elif not self.batch_first:
+ why_not_fast_path = "batch_first was not True"
+ elif self.bias_k is not None:
+ why_not_fast_path = "self.bias_k was not None"
+ elif self.bias_v is not None:
+ why_not_fast_path = "self.bias_v was not None"
+ elif self.add_zero_attn:
+ why_not_fast_path = "add_zero_attn was enabled"
+ elif not self._qkv_same_embed_dim:
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
+ elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
+ why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
+ is not supported with NestedTensor input"
+ # elif torch.is_autocast_enabled():
+ # why_not_fast_path = "autocast is enabled"
+
+ if not why_not_fast_path:
+ tensor_args = (
+ query,
+ key,
+ value,
+ self.in_proj_weight,
+ self.in_proj_bias,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ )
+ # We have to use list comprehensions below because TorchScript does not support
+ # generator expressions.
+ # FIXME logic is passed.
+ # if torch.overrides.has_torch_function(tensor_args):
+ # why_not_fast_path = "some Tensor argument has_torch_function"
+ # elif _is_make_fx_tracing():
+ # why_not_fast_path = "we are running make_fx tracing"
+ # elif not all(_check_arg_device(x) for x in tensor_args):
+ # why_not_fast_path = ("some Tensor argument's device is neither one of "
+ # f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}")
+ # elif torch.is_grad_enabled() and any(_arg_requires_grad(x) for x in tensor_args):
+ # why_not_fast_path = ("grad is enabled and at least one of query or the "
+ # "input/output projection weights or biases requires_grad")
+ if not why_not_fast_path:
+ merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)
+
+ if self.in_proj_bias is not None and self.in_proj_weight is not None:
+ return ms.nn.MultiheadAttention(
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ merged_mask,
+ need_weights,
+ average_attn_weights,
+ mask_type)
+
+ # any_nested = query.is_nested or key.is_nested or value.is_nested
+ # assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
+ # f"The fast path was not hit because {why_not_fast_path}")
+
+ if self.batch_first and is_batched:
+ # make sure that the transpose op does not affect the "is" property
+ if key is value:
+ if query is key:
+ query = key = value = query.transpose(1, 0)
+ else:
+ query, key = (x.transpose(1, 0) for x in (query, key))
+ value = key
+ else:
+ query, key, value = (x.transpose(1, 0) for x in (query, key, value))
+
+ if not self._qkv_same_embed_dim:
+ attn_output, attn_output_weights = self.multi_head_attention_forward(
+ query, key, value, self.embed_dim, self.num_heads,
+ self.in_proj_weight, self.in_proj_bias,
+ self.bias_k, self.bias_v, self.add_zero_attn,
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
+ training=self.training,
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
+ attn_mask=attn_mask,
+ use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
+ v_proj_weight=self.v_proj_weight,
+ average_attn_weights=average_attn_weights,
+ is_causal=is_causal)
+ else:
+ attn_output, attn_output_weights = self.multi_head_attention_forward(
+ query, key, value, self.embed_dim, self.num_heads,
+ self.in_proj_weight, self.in_proj_bias,
+ self.bias_k, self.bias_v, self.add_zero_attn,
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
+ training=self.training,
+ key_padding_mask=key_padding_mask,
+ need_weights=need_weights,
+ attn_mask=attn_mask,
+ average_attn_weights=average_attn_weights,
+ is_causal=is_causal)
+ if self.batch_first and is_batched:
+ return attn_output.transpose(1, 0), attn_output_weights
+ else:
+ return attn_output, attn_output_weights
+
+ def multi_head_attention_forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ embed_dim_to_check: int,
+ num_heads: int,
+ in_proj_weight: Optional[Tensor],
+ in_proj_bias: Optional[Tensor],
+ bias_k: Optional[Tensor],
+ bias_v: Optional[Tensor],
+ add_zero_attn: bool,
+ dropout_p: float,
+ out_proj_weight: Tensor,
+ out_proj_bias: Optional[Tensor],
+ training: bool = True,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ use_separate_proj_weight: bool = False,
+ q_proj_weight: Optional[Tensor] = None,
+ k_proj_weight: Optional[Tensor] = None,
+ v_proj_weight: Optional[Tensor] = None,
+ static_k: Optional[Tensor] = None,
+ static_v: Optional[Tensor] = None,
+ average_attn_weights: bool = True,
+ is_causal: bool = False,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
+ # FIXME: logic passed
+ # if has_torch_function(tens_ops):
+ # return handle_torch_function(
+ # multi_head_attention_forward,
+ # tens_ops,
+ # query,
+ # key,
+ # value,
+ # embed_dim_to_check,
+ # num_heads,
+ # in_proj_weight,
+ # in_proj_bias,
+ # bias_k,
+ # bias_v,
+ # add_zero_attn,
+ # dropout_p,
+ # out_proj_weight,
+ # out_proj_bias,
+ # training=training,
+ # key_padding_mask=key_padding_mask,
+ # need_weights=need_weights,
+ # attn_mask=attn_mask,
+ # is_causal=is_causal,
+ # use_separate_proj_weight=use_separate_proj_weight,
+ # q_proj_weight=q_proj_weight,
+ # k_proj_weight=k_proj_weight,
+ # v_proj_weight=v_proj_weight,
+ # static_k=static_k,
+ # static_v=static_v,
+ # average_attn_weights=average_attn_weights,
+ # )
+
+ is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
+
+ # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
+ # is batched, run the computation and before returning squeeze the
+ # batch dimension so that the output doesn't carry this temporary batch dimension.
+ if not is_batched:
+ # unsqueeze if the input is unbatched
+ query = query.unsqueeze(1)
+ key = key.unsqueeze(1)
+ value = value.unsqueeze(1)
+ if key_padding_mask is not None:
+ key_padding_mask = key_padding_mask.unsqueeze(0)
+
+ # set up shape vars
+ tgt_len, bsz, embed_dim = query.shape
+ src_len, _, _ = key.shape
+
+ key_padding_mask = _canonical_mask(
+ mask=key_padding_mask,
+ mask_name="key_padding_mask",
+ other_type=_none_or_dtype(attn_mask),
+ other_name="attn_mask",
+ target_type=query.dtype
+ )
+
+ if is_causal and attn_mask is None:
+ raise RuntimeError(
+ "Need attn_mask if specifying the is_causal hint. "
+ "You may use the Transformer module method "
+ "`generate_square_subsequent_mask` to create this mask."
+ )
+
+ if is_causal and key_padding_mask is None and not need_weights:
+ # when we have a kpm or need weights, we need attn_mask
+ # Otherwise, we use the is_causal hint go as is_causal
+ # indicator to SDPA.
+ attn_mask = None
+ else:
+ attn_mask = _canonical_mask(
+ mask=attn_mask,
+ mask_name="attn_mask",
+ other_type=None,
+ other_name="",
+ target_type=query.dtype,
+ check_other=False,
+ )
+
+ if key_padding_mask is not None:
+ # We have the attn_mask, and use that to merge kpm into it.
+ # Turn off use of is_causal hint, as the merged mask is no
+ # longer causal.
+ is_causal = False
+
+ assert embed_dim == embed_dim_to_check, \
+ f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
+ if isinstance(embed_dim, ms.Tensor):
+ # embed_dim can be a tensor when JIT tracing
+ head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
+ else:
+ head_dim = embed_dim // num_heads
+ assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
+ if use_separate_proj_weight:
+ # allow MHA to have different embedding dimensions when separate projection weights are used
+ assert key.shape[:2] == value.shape[:2], \
+ f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
+ else:
+ assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
+
+ #
+ # compute in-projection
+ #
+ if not use_separate_proj_weight:
+ assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
+ q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
+ else:
+ assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
+ assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
+ assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
+ if in_proj_bias is None:
+ b_q = b_k = b_v = None
+ else:
+ b_q, b_k, b_v = in_proj_bias.chunk(3)
+ q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
+
+ # prep attention mask
+
+ if attn_mask is not None:
+ # ensure attn_mask's dim is 3
+ if attn_mask.dim() == 2:
+ correct_2d_size = (tgt_len, src_len)
+ if attn_mask.shape != correct_2d_size:
+ raise RuntimeError(
+ f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
+ attn_mask = attn_mask.unsqueeze(0)
+ elif attn_mask.dim() == 3:
+ correct_3d_size = (bsz * num_heads, tgt_len, src_len)
+ if attn_mask.shape != correct_3d_size:
+ raise RuntimeError(
+ f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
+ else:
+ raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
+
+ # add bias along batch dimension (currently second)
+ if bias_k is not None and bias_v is not None:
+ assert static_k is None, "bias cannot be added to static key."
+ assert static_v is None, "bias cannot be added to static value."
+ k = ops.cat([k, bias_k.repeat(1, bsz, 1)])
+ v = ops.cat([v, bias_v.repeat(1, bsz, 1)])
+
+ # FIXME where is pad?????????
+ if attn_mask is not None:
+ attn_mask = ops.pad(attn_mask, (0, 1))
+ if key_padding_mask is not None:
+ key_padding_mask = ops.pad(key_padding_mask, (0, 1))
+ else:
+ assert bias_k is None
+ assert bias_v is None
+
+ #
+ # reshape q, k, v for multihead attention and make em batch first
+ #
+ q = q.view(tgt_len, bsz * num_heads, head_dim).permute(1, 0, 2)
+ if static_k is None:
+ k = k.view(k.shape[0], bsz * num_heads, head_dim).permute(1, 0, 2)
+ else:
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
+ assert static_k.shape[0] == bsz * num_heads, \
+ f"expecting static_k.shape[0] of {bsz * num_heads}, but got {static_k.shape[0]}"
+ assert static_k.shape[2] == head_dim, \
+ f"expecting static_k.shape[2] of {head_dim}, but got {static_k.shape[2]}"
+ k = static_k
+ if static_v is None:
+ v = v.view(v.shape[0], bsz * num_heads, head_dim).permute(1, 0, 2)
+ else:
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
+ assert static_v.shape[0] == bsz * num_heads, \
+ f"expecting static_v.shape[0] of {bsz * num_heads}, but got {static_v.shape[0]}"
+ assert static_v.shape[2] == head_dim, \
+ f"expecting static_v.shape[2] of {head_dim}, but got {static_v.shape[2]}"
+ v = static_v
+
+ # add zero attention along batch dimension (now first)
+ if add_zero_attn:
+ zero_attn_shape = (bsz * num_heads, 1, head_dim)
+ k = ops.cat([k, ops.zeros(zero_attn_shape, dtype=k.dtype)], axis=1)
+ v = ops.cat([v, ops.zeros(zero_attn_shape, dtype=v.dtype)], axis=1)
+ if attn_mask is not None:
+ attn_mask = ops.pad(attn_mask, (0, 1))
+ if key_padding_mask is not None:
+ key_padding_mask = ops.pad(key_padding_mask, (0, 1))
+
+ # update source sequence length after adjustments
+ src_len = k.shape[1]
+
+ # merge key padding and attention masks
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (bsz, src_len), \
+ f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
+ key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
+ expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
+ if attn_mask is None:
+ attn_mask = key_padding_mask
+ else:
+ attn_mask = attn_mask + key_padding_mask
+
+ # adjust dropout probability
+ if not training:
+ dropout_p = 0.0
+
+ #
+ # (deep breath) calculate attention and out projection
+ #
+
+ if need_weights:
+ B, Nt, E = q.shape
+ q_scaled = q / math.sqrt(E)
+
+ assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
+
+ if attn_mask is not None:
+ attn_output_weights = ops.baddbmm(attn_mask, q_scaled, k.permute(0, 2, 1))
+ else:
+ attn_output_weights = ops.bmm(q_scaled, k.permute(0, 2, 1))
+ attn_output_weights = ops.softmax(attn_output_weights, axis=-1)
+ if dropout_p > 0.0:
+ attn_output_weights = ops.dropout(attn_output_weights, p=dropout_p)
+
+ attn_output = ops.bmm(attn_output_weights, v)
+
+ attn_output = attn_output.permute(1, 0, 2).view(tgt_len * bsz, embed_dim)
+ attn_output = self.out_proj(attn_output)
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.shape[1])
+
+ # optionally average attention weights over heads
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+ if average_attn_weights:
+ attn_output_weights = attn_output_weights.mean(axis=1)
+
+ if not is_batched:
+ # squeeze the output if input was unbatched
+ attn_output = attn_output.squeeze(1)
+ attn_output_weights = attn_output_weights.squeeze(0)
+ return attn_output, attn_output_weights
+ else:
+ # attn_mask can be either (L,S) or (N*num_heads, L, S)
+ # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
+ # in order to match the input for SDPA of (N, num_heads, L, S)
+ if attn_mask is not None:
+ if attn_mask.shape[0] == 1 and attn_mask.dim() == 3:
+ attn_mask = attn_mask.unsqueeze(0)
+ else:
+ attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
+
+ q = q.view(bsz, num_heads, tgt_len, head_dim)
+ k = k.view(bsz, num_heads, src_len, head_dim)
+ v = v.view(bsz, num_heads, src_len, head_dim)
+
+ attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
+ attn_output = attn_output.permute(2, 0, 1, 3).view(bsz * tgt_len, embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.shape[1])
+ if not is_batched:
+ # squeeze the output if input was unbatched
+ attn_output = attn_output.squeeze(1)
+ return attn_output, None
+
+
+def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
+ key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], num_heads: int):
+ # Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
+ # and returns if the input is batched or not.
+ # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.
+
+ # Shape check.
+ if query.dim() == 3:
+ # Batched Inputs
+ is_batched = True
+ assert key.dim() == 3 and value.dim() == 3, \
+ ("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
+ if key_padding_mask is not None:
+ assert key_padding_mask.dim() == 2, \
+ ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
+ f" but found {key_padding_mask.dim()}-D tensor instead")
+ if attn_mask is not None:
+ assert attn_mask.dim() in (2, 3), \
+ ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
+ f" but found {attn_mask.dim()}-D tensor instead")
+ elif query.dim() == 2:
+ # Unbatched Inputs
+ is_batched = False
+ assert key.dim() == 2 and value.dim() == 2, \
+ ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.dim() == 1, \
+ ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
+ f" but found {key_padding_mask.dim()}-D tensor instead")
+
+ if attn_mask is not None:
+ assert attn_mask.dim() in (2, 3), \
+ ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
+ f" but found {attn_mask.dim()}-D tensor instead")
+ if attn_mask.dim() == 3:
+ expected_shape = (num_heads, query.shape[0], key.shape[0])
+ assert attn_mask.shape == expected_shape, \
+ (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
+ else:
+ raise AssertionError(
+ f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
+
+ return is_batched
+
+
+def _canonical_mask(
+ mask: Optional[Tensor],
+ mask_name: str,
+ other_type: Optional,
+ other_name: str,
+ target_type: None,
+ check_other: bool = True,
+) -> Optional[Tensor]:
+ if mask is not None:
+ _mask_dtype = mask.dtype
+ _mask_is_float = ops.is_floating_point(mask)
+ if _mask_dtype != ms.bool_ and not _mask_is_float:
+ raise AssertionError(
+ f"only bool and floating types of {mask_name} are supported")
+ if check_other and other_type is not None:
+ if _mask_dtype != other_type:
+ warnings.warn(
+ f"Support for mismatched {mask_name} and {other_name} "
+ "is deprecated. Use same type for both instead."
+ )
+ if not _mask_is_float:
+ mask = (
+ ops.zeros_like(mask, dtype=target_type)
+ .masked_fill(mask, float("-inf"))
+ )
+ return mask
+
+
+def _none_or_dtype(input: Optional[Tensor]) -> Optional:
+ if input is None:
+ return None
+ elif isinstance(input, ms.Tensor):
+ return input.dtype
+ raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
+
+
+def _in_projection_packed(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ w: Tensor,
+ b: Optional[Tensor] = None,
+) -> List[Tensor]:
+ r"""
+ Performs the in-projection step of the attention operation, using packed weights.
+ Output is a triple containing projection tensors for query, key and value.
+ Args:
+ q, k, v: query, key and value tensors to be projected. For self-attention,
+ these are typically the same tensor; for encoder-decoder attention,
+ k and v are typically the same tensor. (We take advantage of these
+ identities for performance if they are present.) Regardless, q, k and v
+ must share a common embedding dimension; otherwise their shapes may vary.
+ w: projection weights for q, k and v, packed into a single tensor. Weights
+ are packed along dimension 0, in q, k, v order.
+ b: optional projection biases for q, k and v, packed into a single tensor
+ in q, k, v order.
+ Shape:
+ Inputs:
+ - q: :math:`(..., E)` where E is the embedding dimension
+ - k: :math:`(..., E)` where E is the embedding dimension
+ - v: :math:`(..., E)` where E is the embedding dimension
+ - w: :math:`(E * 3, E)` where E is the embedding dimension
+ - b: :math:`E * 3` where E is the embedding dimension
+ Output:
+ - in output list :math:`[q', k', v']`, each output tensor will have the
+ same shape as the corresponding input tensor.
+ """
+ E = q.shape[-1]
+ if k is v:
+ if q is k:
+ # self-attention
+ proj = ops.dense(q, w, b)
+ # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
+ proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+ return proj[0], proj[1], proj[2]
+ else:
+ # encoder-decoder attention
+ w_q, w_kv = w.split([E, E * 2])
+ if b is None:
+ b_q = b_kv = None
+ else:
+ b_q, b_kv = b.split([E, E * 2])
+ q_proj = ops.dense(q, w_q, b_q)
+ kv_proj = ops.dense(k, w_kv, b_kv)
+ # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
+ kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+ return (q_proj, kv_proj[0], kv_proj[1])
+ else:
+ w_q, w_k, w_v = w.chunk(3)
+ if b is None:
+ b_q = b_k = b_v = None
+ else:
+ b_q, b_k, b_v = b.chunk(3)
+ return ops.dense(q, w_q, b_q), ops.dense(k, w_k, b_k), ops.dense(v, w_v, b_v)
+
+
+def _in_projection(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ w_q: Tensor,
+ w_k: Tensor,
+ w_v: Tensor,
+ b_q: Optional[Tensor] = None,
+ b_k: Optional[Tensor] = None,
+ b_v: Optional[Tensor] = None,
+) -> Tuple[Tensor, Tensor, Tensor]:
+ r"""
+ Performs the in-projection step of the attention operation. This is simply
+ a triple of linear projections, with shape constraints on the weights which
+ ensure embedding dimension uniformity in the projected outputs.
+ Output is a triple containing projection tensors for query, key and value.
+ Args:
+ q, k, v: query, key and value tensors to be projected.
+ w_q, w_k, w_v: weights for q, k and v, respectively.
+ b_q, b_k, b_v: optional biases for q, k and v, respectively.
+ Shape:
+ Inputs:
+ - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any
+ number of leading dimensions.
+ - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any
+ number of leading dimensions.
+ - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any
+ number of leading dimensions.
+ - w_q: :math:`(Eq, Eq)`
+ - w_k: :math:`(Eq, Ek)`
+ - w_v: :math:`(Eq, Ev)`
+ - b_q: :math:`(Eq)`
+ - b_k: :math:`(Eq)`
+ - b_v: :math:`(Eq)`
+ Output: in output triple :math:`(q', k', v')`,
+ - q': :math:`[Qdims..., Eq]`
+ - k': :math:`[Kdims..., Eq]`
+ - v': :math:`[Vdims..., Eq]`
+ """
+ Eq, Ek, Ev = q.shape[-1], k.shape[-1], v.shape[-1]
+ assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
+ assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
+ assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
+ assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
+ assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
+ assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
+ return ops.dense(q, w_q, b_q), ops.dense(k, w_k, b_k), ops.dense(v, w_v, b_v)
diff --git a/mindone/transformers/models/minicpm_v/tokenization_minicpmv_fast.py b/mindone/transformers/models/minicpm_v/tokenization_minicpmv_fast.py
new file mode 100644
index 0000000000..1707f105fb
--- /dev/null
+++ b/mindone/transformers/models/minicpm_v/tokenization_minicpmv_fast.py
@@ -0,0 +1,70 @@
+from mindnlp.transformers import AutoTokenizer
+
+from ..qwen2 import Qwen2TokenizerFast
+
+
+class MiniCPMVTokenizerFast(Qwen2TokenizerFast):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.im_start = ""
+ self.im_end = ""
+ self.ref_start = "["
+ self.ref_end = "]"
+ self.box_start = ""
+ self.box_end = ""
+ self.quad_start = ""
+ self.quad_end = ""
+ self.slice_start = ""
+ self.slice_end = ""
+ self.im_id_start = ""
+ self.im_id_end = ""
+
+ @property
+ def eos_id(self):
+ return self.eos_token_id
+
+ @property
+ def bos_id(self):
+ return self.bos_token_id
+
+ @property
+ def unk_id(self):
+ return self.unk_token_id
+
+ @property
+ def im_start_id(self):
+ return self.convert_tokens_to_ids(self.im_start)
+
+ @property
+ def im_end_id(self):
+ return self.convert_tokens_to_ids(self.im_end)
+
+ @property
+ def slice_start_id(self):
+ return self.convert_tokens_to_ids(self.slice_start)
+
+ @property
+ def slice_end_id(self):
+ return self.convert_tokens_to_ids(self.slice_end)
+
+ @property
+ def im_id_start_id(self):
+ return self.convert_tokens_to_ids(self.im_id_start)
+
+ @property
+ def im_id_end_id(self):
+ return self.convert_tokens_to_ids(self.im_id_end)
+
+ @property
+ def newline_id(self):
+ return self.convert_tokens_to_ids('\n')
+
+ @staticmethod
+ def escape(text: str) -> str:
+ return text
+
+ @staticmethod
+ def unescape(text: str) -> str:
+ return text
+
+AutoTokenizer.register("MiniCPMVTokenizerFast", MiniCPMVTokenizerFast)
diff --git a/mindone/transformers/models/qwen2/__init__.py b/mindone/transformers/models/qwen2/__init__.py
new file mode 100644
index 0000000000..be2d5916fd
--- /dev/null
+++ b/mindone/transformers/models/qwen2/__init__.py
@@ -0,0 +1,52 @@
+# Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from transformers.utils import _LazyModule
+
+_import_structure = {
+ "configuration_qwen2": ["Qwen2Config"],
+ "tokenization_qwen2": ["Qwen2Tokenizer"],
+}
+
+
+
+_import_structure["tokenization_qwen2_fast"] = ["Qwen2TokenizerFast"]
+
+_import_structure["modeling_qwen2"] = [
+ "Qwen2ForCausalLM",
+ "Qwen2Model",
+ "Qwen2PreTrainedModel",
+ "Qwen2ForSequenceClassification",
+ "Qwen2ForTokenClassification",
+]
+
+
+if TYPE_CHECKING:
+ from .configuration_qwen2 import Qwen2Config
+ from .modeling_qwen2 import (
+ Qwen2ForCausalLM,
+ Qwen2ForSequenceClassification,
+ Qwen2ForTokenClassification,
+ Qwen2Model,
+ Qwen2PreTrainedModel,
+ )
+ from .tokenization_qwen2 import Qwen2Tokenizer
+ from .tokenization_qwen2_fast import Qwen2TokenizerFast
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/mindone/transformers/models/qwen2/configuration_qwen2.py b/mindone/transformers/models/qwen2/configuration_qwen2.py
new file mode 100644
index 0000000000..10464f6bfc
--- /dev/null
+++ b/mindone/transformers/models/qwen2/configuration_qwen2.py
@@ -0,0 +1,139 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Qwen2 model configuration"""
+
+from transformers import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
+ Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 151936):
+ Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Qwen2Model`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 22016):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 32):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
+ Whether to use sliding window attention.
+ sliding_window (`int`, *optional*, defaults to 4096):
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
+ max_window_layers (`int`, *optional*, defaults to 28):
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+
+ ```python
+ >>> from transformers import Qwen2Model, Qwen2Config
+
+ >>> # Initializing a Qwen2 style configuration
+ >>> configuration = Qwen2Config()
+
+ >>> # Initializing a model from the Qwen2-7B style configuration
+ >>> model = Qwen2Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen2"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=151936,
+ hidden_size=4096,
+ intermediate_size=22016,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ use_sliding_window=False,
+ sliding_window=4096,
+ max_window_layers=28,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.use_sliding_window = use_sliding_window
+ self.sliding_window = sliding_window if use_sliding_window else None
+ self.max_window_layers = max_window_layers
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_dropout = attention_dropout
+
+ super().__init__(
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
diff --git a/mindone/transformers/models/qwen2/modeling_qwen2.py b/mindone/transformers/models/qwen2/modeling_qwen2.py
new file mode 100644
index 0000000000..b813e208a4
--- /dev/null
+++ b/mindone/transformers/models/qwen2/modeling_qwen2.py
@@ -0,0 +1,1432 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Mindspore Qwen2 model."""
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+
+import mindspore as ms
+from mindspore import nn, ops, Tensor, Parameter
+from mindspore.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import MSPreTrainedModel
+from .configuration_qwen2 import Qwen2Config
+
+from transformers import logging
+
+from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
+logger = logging.get_logger(__name__)
+
+
+_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
+_CONFIG_FOR_DOC = "Qwen2Config"
+
+_MIN_FP16 = ms.tensor(np.finfo(np.float16).min, dtype=ms.float16)
+_MIN_FP32 = ms.tensor(np.finfo(np.float32).min, dtype=ms.float32)
+_MIN_FP64 = ms.tensor(np.finfo(np.float64).min, dtype=ms.float64)
+_MIN_BF16 = ms.tensor(float.fromhex("-0x1.fe00000000000p+127"), dtype=ms.bfloat16)
+
+
+def dtype_to_min(dtype):
+ if dtype == ms.float16:
+ return _MIN_FP16
+ if dtype == ms.float32:
+ return _MIN_FP32
+ if dtype == ms.float64:
+ return _MIN_FP64
+ if dtype == ms.bfloat16:
+ return _MIN_BF16
+ else:
+ raise ValueError(f"Only support get minimum value of (float16, ), but got {dtype}")
+
+
+
+# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
+def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: ms.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: ms.dtype,
+ device: None,
+ min_dtype: float,
+ cache_position: ms.Tensor,
+ batch_size: int,
+):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`ms.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ device (`torch.device`):
+ The device to plcae the 4D attention mask on.
+ min_dtype (`float`):
+ The minimum value representable with the dtype `dtype`.
+ cache_position (`ms.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`ms.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ causal_mask = ops.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
+ if sequence_length != 1:
+ causal_mask = ops.triu(causal_mask, diagonal=1)
+ causal_mask *= ops.arange(target_length) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ # causal_mask = causal_mask # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ # padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+ padding_mask = ops.narrow(causal_mask, -1, 0, mask_length) + attention_mask[:, None, None, :]
+ padding_mask = padding_mask == 0
+ # causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ # padding_mask, min_dtype
+ # )
+ if mask_length >= causal_mask.shape[-1]:
+ causal_mask = causal_mask.masked_fill(padding_mask, min_dtype)
+ else:
+ causal_mask = ops.cat(
+ [ops.narrow(causal_mask, -1, 0, mask_length).masked_fill(padding_mask, min_dtype),
+ ops.narrow(causal_mask, -1, mask_length, causal_mask.shape[-1] - mask_length)],
+ axis=-1
+ )
+
+ return causal_mask
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
+class Qwen2RMSNorm(nn.Cell):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Qwen2RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = Parameter(ops.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def construct(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(ms.float32)
+ variance = hidden_states.pow(2).mean(-1, keep_dims=True)
+ hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
+class Qwen2RotaryEmbedding(nn.Cell):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (ops.arange(0, self.dim, 2, dtype=ms.int64).float() / self.dim))
+ self.inv_freq = inv_freq
+
+ # Build here to make `torch.jit.trace` work.
+ self._set_cos_sin_cache(
+ seq_len=max_position_embeddings, device=None, dtype=ms.float32
+ )
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ self.max_seq_len_cached = seq_len
+ t = ops.arange(self.max_seq_len_cached, dtype=ms.int64).type_as(self.inv_freq)
+
+ freqs = ops.outer(t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = ops.cat((freqs, freqs), axis=-1)
+ self.cos_cached = emb.cos().to(dtype)
+ self.sin_cached = emb.sin().to(dtype)
+
+ def construct(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if seq_len > self.max_seq_len_cached:
+ self._set_cos_sin_cache(seq_len=seq_len, device=None, dtype=x.dtype)
+
+ return (
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
+ )
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return ops.cat((-x2, x1), axis=-1)
+
+
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`ms.Tensor`): The query tensor.
+ k (`ms.Tensor`): The key tensor.
+ cos (`ms.Tensor`): The cosine part of the rotary embedding.
+ sin (`ms.Tensor`): The sine part of the rotary embedding.
+ position_ids (`ms.Tensor`):
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+ used to pass offsetted position ids when working with a KV-cache.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(ms.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
+class Qwen2MLP(nn.Cell):
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Dense(self.hidden_size, self.intermediate_size, has_bias=False)
+ self.up_proj = nn.Dense(self.hidden_size, self.intermediate_size, has_bias=False)
+ self.down_proj = nn.Dense(self.intermediate_size, self.hidden_size, has_bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def construct(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class Qwen2Attention(nn.Cell):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+ """
+
+ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Dense(self.hidden_size, self.num_heads * self.head_dim, has_bias=True)
+ self.k_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=True)
+ self.v_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=True)
+ self.o_proj = nn.Dense(self.num_heads * self.head_dim, self.hidden_size, has_bias=False)
+
+ self.rotary_emb = Qwen2RotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+
+ self.scale = self.head_dim ** -0.5
+
+ def construct(
+ self,
+ hidden_states: ms.Tensor,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[ms.Tensor] = None,
+ ) -> Tuple[ms.Tensor, Optional[ms.Tensor], Optional[Tuple[ms.Tensor]]]:
+ bsz, q_len, _ = hidden_states.shape
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).swapaxes(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).swapaxes(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).swapaxes(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ query_states = ops.mul(query_states, self.scale ** 0.5)
+ key_states = ops.mul(key_states, self.scale ** 0.5)
+
+ attn_weights = ops.matmul(query_states, key_states.swapaxes(2, 3))
+
+ if attn_weights.shape != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.shape}"
+ )
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = ops.softmax(attn_weights, axis=-1, dtype=ms.float32).to(query_states.dtype)
+ attn_weights = ops.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = ops.matmul(attn_weights, value_states)
+
+ if attn_output.shape != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.shape}"
+ )
+
+ attn_output = attn_output.swapaxes(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Qwen2FlashAttention2(Qwen2Attention):
+ """
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
+ as the weights of the module stays untouched. The only required change would be on the forward pass
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
+ config.max_window_layers layers.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = False
+
+ def construct(
+ self,
+ hidden_states: ms.Tensor,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[ms.Tensor] = None,
+ ):
+ bsz, q_len, _ = hidden_states.shape
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).swapaxes(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).swapaxes(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).swapaxes(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
+ rotary_seq_len = (
+ max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
+ )
+
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a value `sliding_windows` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :]
+ past_value = past_value[:, :, slicing_tokens:, :]
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = ops.cat([attention_mask, ops.ones_like(attention_mask[:, -1:])], axis=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == ms.float32:
+ # if torch.is_autocast_enabled():
+ # target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ if hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reashape to the expected shape for Flash Attention
+ query_states = query_states.swapaxes(1, 2)
+ key_states = key_states.swapaxes(1, 2)
+ value_states = value_states.swapaxes(1, 2)
+
+ if (
+ self.config.use_sliding_window
+ and getattr(self.config, "sliding_window", None) is not None
+ and self.layer_idx >= self.config.max_window_layers
+ ):
+ sliding_window = self.config.sliding_window
+ else:
+ sliding_window = None
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=sliding_window,
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2
+class Qwen2SdpaAttention(Qwen2Attention):
+ """
+ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from Qwen2Attention.forward
+ def construct(
+ self,
+ hidden_states: ms.Tensor,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[ms.Tensor] = None,
+ ) -> Tuple[ms.Tensor, Optional[ms.Tensor], Optional[Tuple[ms.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).swapaxes(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).swapaxes(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).swapaxes(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ # if query_states.device.type == "cuda" and attention_mask is not None:
+ # query_states = query_states.contiguous()
+ # key_states = key_states.contiguous()
+ # value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = ops.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.swapaxes(1, 2)
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+QWEN2_ATTENTION_CLASSES = {
+ "eager": Qwen2Attention,
+ "flash_attention_2": Qwen2FlashAttention2,
+ "sdpa": Qwen2SdpaAttention,
+}
+
+
+class Qwen2DecoderLayer(nn.Cell):
+ def __init__(self, config: Qwen2Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
+ logger.warning_once(
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
+ "unexpected results may be encountered."
+ )
+ self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
+ self.mlp = Qwen2MLP(config)
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ # recompute
+ # self.self_attn.recompute()
+ # self.mlp.recompute()
+ # self.input_layernorm.recompute()
+ # self.post_attention_layernorm.recompute()
+
+ def construct(
+ self,
+ hidden_states: ms.Tensor,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_value: Optional[Tuple[ms.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[ms.Tensor] = None,
+ **kwargs,
+ ) -> Tuple[ms.Tensor, Optional[Tuple[ms.Tensor, ms.Tensor]]]:
+ """
+ Args:
+ hidden_states (`ms.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`ms.Tensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(ms.Tensor)`, *optional*): cached past key and value projection states
+ cache_position (`ms.Tensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+QWEN2_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Cell](https://pytorch.org/docs/stable/nn.html#torch.nn.Cell) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`Qwen2Config`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+class Qwen2PreTrainedModel(MSPreTrainedModel):
+ config_class = Qwen2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Qwen2DecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+
+ def _init_weights(self, module):
+ # std = self.config.initializer_range
+ # if isinstance(module, nn.Dense):
+ # module.weight.data.normal_(mean=0.0, std=std)
+ # if module.bias is not None:
+ # module.bias.data.zero_()
+ # elif isinstance(module, nn.Embedding):
+ # module.weight.data.normal_(mean=0.0, std=std)
+ # if module.padding_idx is not None:
+ # module.weight.data[module.padding_idx].zero_()
+ pass
+
+
+QWEN2_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache` or `tuple(tuple(ms.Tensor))`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ Two formats are allowed:
+ - a [`~cache_utils.Cache`] instance;
+ - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+ cache format.
+
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+ legacy cache format will be returned.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`ms.Tensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+
+class Qwen2Model(Qwen2PreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
+
+ Args:
+ config: Qwen2Config
+ """
+
+ def __init__(self, config: Qwen2Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
+ self.layers = nn.CellList(
+ [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self._attn_implementation = config._attn_implementation
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # recompute
+ # self.layers.recompute()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+
+ def construct(
+ self,
+ input_ids: ms.Tensor = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[List[ms.Tensor]] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[ms.Tensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ use_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache) and not self.training:
+ use_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = ops.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = None
+ if use_cache:
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: ms.Tensor,
+ input_tensor: ms.Tensor,
+ cache_position: ms.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool,
+ ):
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_static_cache = isinstance(past_key_values, StaticCache)
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype, device = input_tensor.dtype, None
+ min_dtype = dtype_to_min(dtype)
+ sequence_length = input_tensor.shape[1]
+ if using_static_cache:
+ target_length = past_key_values.get_max_length()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, ms.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ device=device,
+ min_dtype=min_dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+
+class Qwen2ForCausalLM(Qwen2PreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Qwen2Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Dense(config.hidden_size, config.vocab_size, has_bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+
+ def construct(
+ self,
+ input_ids: ms.Tensor = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[List[ms.Tensor]] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ labels: Optional[ms.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[ms.Tensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
+
+ >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :]
+ shift_labels = labels[..., 1:]
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ **kwargs,
+ ):
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+ # FIXME generation scling method
+ if past_key_values is not None:
+ # if inputs_embeds is not None and input_ids is None: # Exception 1
+ # # input_ids = input_ids[:, -cache_position.shape[0]:]
+ # input_ids = input_ids
+ # elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
+ # input_ids = input_ids[:, :cache_position.shape[0]]
+ if inputs_embeds is not None: # Exception 1
+ if 0 not in input_ids.shape:
+ input_ids = input_ids[:, -cache_position.shape[0]:]
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
+ input_ids = ops.index_select(input_ids, -1, cache_position)
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
+ position_ids = position_ids
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and cache_position[0] == 0:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+ if inputs_embeds is not None:
+ batch_size, sequence_length = inputs_embeds.shape
+ device = None
+ else:
+ batch_size, sequence_length = input_ids.shape
+ device = None
+
+ dtype = self.lm_head.weight.dtype
+ min_dtype = dtype_to_min(dtype)
+
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=past_key_values.get_max_length(),
+ dtype=dtype,
+ device=device,
+ min_dtype=min_dtype,
+ cache_position=cache_position,
+ batch_size=batch_size,
+ )
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+
+
+class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = Qwen2Model(config)
+ self.score = nn.Dense(config.hidden_size, self.num_labels, has_bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+
+ def construct(
+ self,
+ input_ids: ms.Tensor = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[List[ms.Tensor]] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ labels: Optional[ms.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`ms.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+ sequence_lengths = ops.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[ops.arange(batch_size), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ labels = labels
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == ms.float64 or labels.dtype == ms.int32):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2
+class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = Qwen2Model(config)
+ if getattr(config, "classifier_dropout", None) is not None:
+ classifier_dropout = config.classifier_dropout
+ elif getattr(config, "hidden_dropout", None) is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.score = nn.Dense(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+
+ def construct(
+ self,
+ input_ids: Optional[ms.Tensor] = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[List[ms.Tensor]] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ labels: Optional[ms.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
+ r"""
+ labels (`ms.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.score(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
\ No newline at end of file
diff --git a/mindone/transformers/models/qwen2/tokenization_qwen2.py b/mindone/transformers/models/qwen2/tokenization_qwen2.py
new file mode 100644
index 0000000000..c5cff300a2
--- /dev/null
+++ b/mindone/transformers/models/qwen2/tokenization_qwen2.py
@@ -0,0 +1,337 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for Qwen2."""
+
+import json
+import os
+import unicodedata
+from functools import lru_cache
+from typing import Optional, Tuple
+
+import regex as re
+from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+ "vocab_file": "vocab.json",
+ "merges_file": "merges.txt",
+}
+
+
+MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
+
+PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+
+@lru_cache()
+# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
+ characters the bpe code barfs on.
+
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+ tables between utf-8 bytes and unicode strings.
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
+def get_pairs(word):
+ """
+ Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+class Qwen2Tokenizer(PreTrainedTokenizer):
+ """
+ Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+ Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+ ```python
+ >>> from transformers import Qwen2Tokenizer
+
+ >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
+ >>> tokenizer("Hello world")["input_ids"]
+ [9707, 1879]
+
+ >>> tokenizer(" Hello world")["input_ids"]
+ [21927, 1879]
+ ```
+ This is expected.
+
+ You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*):
+ The beginning of sequence token. Not applicable for this tokenizer.
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The end of sequence token.
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+ Whether or not the model should cleanup the spaces that were added when splitting the input text during the
+ tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
+ split_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the special tokens should be split during the tokenization process. The default behavior is
+ to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>") =
+ ['<|endoftext|>`]. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will be give `['<',
+ '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ merges_file,
+ errors="replace",
+ unk_token="<|endoftext|>",
+ bos_token=None,
+ eos_token="<|endoftext|>",
+ pad_token="<|endoftext|>",
+ clean_up_tokenization_spaces=False,
+ split_special_tokens=False,
+ **kwargs,
+ ):
+ # Qwen vocab does not contain control tokens; added tokens need to be special
+ bos_token = (
+ AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
+ if isinstance(bos_token, str)
+ else bos_token
+ )
+ eos_token = (
+ AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
+ if isinstance(eos_token, str)
+ else eos_token
+ )
+ unk_token = (
+ AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
+ if isinstance(unk_token, str)
+ else unk_token
+ )
+ pad_token = (
+ AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
+ if isinstance(pad_token, str)
+ else pad_token
+ )
+
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
+ self.encoder = json.load(vocab_handle)
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.errors = errors # how to handle errors in decoding
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ bpe_merges = []
+ with open(merges_file, encoding="utf-8") as merges_handle:
+ for i, line in enumerate(merges_handle):
+ line = line.strip()
+ if (i == 0 and line.startswith("#version:")) or not line:
+ continue
+ bpe_merges.append(tuple(line.split()))
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+ # NOTE: the cache can grow without bound and will get really large for long running processes
+ # (esp. for texts of language that do not use space between word, e.g. Chinese); technically
+ # not a memory leak but appears as one.
+ # GPT2Tokenizer has the same problem, so let's be consistent.
+ self.cache = {}
+
+ self.pat = re.compile(PRETOKENIZE_REGEX)
+
+ if kwargs.get("add_prefix_space", False):
+ logger.warning_once(
+ f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect."
+ )
+
+ super().__init__(
+ errors=errors,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ pad_token=pad_token,
+ unk_token=unk_token,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ split_special_tokens=split_special_tokens,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self) -> int:
+ return len(self.encoder)
+
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
+ def get_vocab(self):
+ return dict(self.encoder, **self.added_tokens_encoder)
+
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+ else:
+ new_word.extend(word[i:j])
+ i = j
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
+ def _tokenize(self, text):
+ """Tokenize a string."""
+ bpe_tokens = []
+ for token in re.findall(self.pat, text):
+ token = "".join(
+ self.byte_encoder[b] for b in token.encode("utf-8")
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+ return bpe_tokens
+
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.decoder.get(index)
+
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ text = "".join(tokens)
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+ return text
+
+ def decode(
+ self,
+ token_ids,
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: Optional[bool] = False,
+ spaces_between_special_tokens: bool = False,
+ **kwargs,
+ ) -> str:
+ # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
+ # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer
+ return super().decode(
+ token_ids,
+ skip_special_tokens=skip_special_tokens,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ spaces_between_special_tokens=spaces_between_special_tokens,
+ **kwargs,
+ )
+
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ merge_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+ )
+
+ with open(vocab_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+ index = 0
+ with open(merge_file, "w", encoding="utf-8") as writer:
+ writer.write("#version: 0.2\n")
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!"
+ )
+ index = token_index
+ writer.write(" ".join(bpe_tokens) + "\n")
+ index += 1
+
+ return vocab_file, merge_file
+
+ def prepare_for_tokenization(self, text, **kwargs):
+ text = unicodedata.normalize("NFC", text)
+ return (text, kwargs)
diff --git a/mindone/transformers/models/qwen2/tokenization_qwen2_fast.py b/mindone/transformers/models/qwen2/tokenization_qwen2_fast.py
new file mode 100644
index 0000000000..5ed3c74bc4
--- /dev/null
+++ b/mindone/transformers/models/qwen2/tokenization_qwen2_fast.py
@@ -0,0 +1,134 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for Qwen2."""
+
+from typing import Optional, Tuple
+
+from mindnlp.transformers.tokenization_utils import AddedToken
+from mindnlp.transformers.tokenization_utils_fast import PreTrainedTokenizerFast
+from transformers.utils import logging
+
+from .tokenization_qwen2 import Qwen2Tokenizer
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+ "vocab_file": "vocab.json",
+ "merges_file": "merges.txt",
+ "tokenizer_file": "tokenizer.json",
+}
+
+
+MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
+
+
+class Qwen2TokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" Qwen2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
+ Byte-Pair-Encoding.
+
+ Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+ ```python
+ >>> from transformers import Qwen2TokenizerFast
+
+ >>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
+ >>> tokenizer("Hello world")["input_ids"]
+ [9707, 1879]
+
+ >>> tokenizer(" Hello world")["input_ids"]
+ [21927, 1879]
+ ```
+ This is expected.
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`, *optional*):
+ Path to the vocabulary file.
+ merges_file (`str`, *optional*):
+ Path to the merges file.
+ tokenizer_file (`str`, *optional*):
+ Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
+ contains everything needed to load the tokenizer.
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead. Not applicable to this tokenizer.
+ bos_token (`str`, *optional*):
+ The beginning of sequence token. Not applicable for this tokenizer.
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The end of sequence token.
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = Qwen2Tokenizer
+
+ def __init__(
+ self,
+ vocab_file=None,
+ merges_file=None,
+ tokenizer_file=None,
+ unk_token="<|endoftext|>",
+ bos_token=None,
+ eos_token="<|endoftext|>",
+ pad_token="<|endoftext|>",
+ **kwargs,
+ ):
+ # We need to at least pass vocab_file and merges_file to base class
+ # in case a slow tokenizer needs to be initialized; other can be
+ # configured through files.
+ # following GPT2TokenizerFast, also adding unk_token, bos_token, and eos_token
+
+ bos_token = (
+ AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
+ if isinstance(bos_token, str)
+ else bos_token
+ )
+ eos_token = (
+ AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
+ if isinstance(eos_token, str)
+ else eos_token
+ )
+ unk_token = (
+ AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
+ if isinstance(unk_token, str)
+ else unk_token
+ )
+ pad_token = (
+ AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
+ if isinstance(pad_token, str)
+ else pad_token
+ )
+
+ super().__init__(
+ vocab_file=vocab_file,
+ merges_file=merges_file,
+ tokenizer_file=tokenizer_file,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ pad_token=pad_token,
+ **kwargs,
+ )
+
+ # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
diff --git a/mindone/transformers/processing_utils.py b/mindone/transformers/processing_utils.py
new file mode 100644
index 0000000000..c258ae580d
--- /dev/null
+++ b/mindone/transformers/processing_utils.py
@@ -0,0 +1,295 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+ Processing saving/loading class for common processors.
+"""
+
+import os
+import warnings
+from typing import Optional, Union
+
+import transformers
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+# Dynamically import the Transformers module to grab the attribute classes of the processor form their names.
+transformers_module = transformers
+
+AUTO_TO_BASE_CLASS_MAPPING = {
+ "AutoTokenizer": "PreTrainedTokenizerBase",
+ "AutoFeatureExtractor": "FeatureExtractionMixin",
+ "AutoImageProcessor": "ImageProcessingMixin",
+}
+
+class ProcessorMixin:
+ """
+ This is a mixin used to provide saving/loading functionality for all processor classes.
+ """
+ attributes = ["feature_extractor", "tokenizer"]
+ # Names need to be attr_class for attr in attributes
+ feature_extractor_class = None
+ tokenizer_class = None
+ _auto_class = None
+
+ # args have to match the attributes class attribute
+ def __init__(self, *args, **kwargs):
+ """
+ This method initializes an instance of the ProcessorMixin class.
+
+ Args:
+ self (ProcessorMixin): The instance of the ProcessorMixin class.
+
+ Returns:
+ None.
+
+ Raises:
+ TypeError: Raised if an unexpected keyword argument is provided,
+ if multiple values are provided for a single argument,
+ or if the arguments provided do not match the required attributes of the processor.
+ ValueError: Raised if the number of arguments provided does not match
+ the required number of attributes for the processor,
+ or if the type of the argument does not match the expected class type.
+ """
+ # Sanitize args and kwargs
+ for key in kwargs:
+ if key not in self.attributes:
+ raise TypeError(f"Unexpected keyword argument {key}.")
+ for arg, attribute_name in zip(args, self.attributes):
+ if attribute_name in kwargs:
+ raise TypeError(f"Got multiple values for argument {attribute_name}.")
+ kwargs[attribute_name] = arg
+
+ if len(kwargs) != len(self.attributes):
+ raise ValueError(
+ f"This processor requires {len(self.attributes)} arguments: {', '.join(self.attributes)}. Got "
+ f"{len(args)} arguments instead."
+ )
+
+ # Check each arg is of the proper class (this will also catch a user initializing in the wrong order)
+ for attribute_name, arg in kwargs.items():
+ class_name = getattr(self, f"{attribute_name}_class")
+ # Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class.
+ class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name)
+ if isinstance(class_name, tuple):
+ proper_class = tuple(getattr(transformers_module, n) for n in class_name if n is not None)
+ elif class_name == "MiniCPMVImageProcessor":
+ from mindone.transformers import MiniCPMVImageProcessor
+ proper_class = MiniCPMVImageProcessor
+ else:
+ proper_class = getattr(transformers_module, class_name)
+
+ if not isinstance(arg, proper_class):
+ raise ValueError(
+ f"Received a {type(arg).__name__} for argument {attribute_name}, but a {class_name} was expected."
+ )
+
+ setattr(self, attribute_name, arg)
+
+ def __repr__(self):
+ """
+ Method '__repr__' in the class 'ProcessorMixin' generates a string representation of the object.
+
+ Args:
+ self: ProcessorMixin instance. Represents the object for which the string representation is being generated.
+
+ Returns:
+ str:
+ A formatted string representation of the object containing its class name and attributes.
+ Returns None if there are no attributes to represent.
+
+ Raises:
+ None.
+ """
+ attributes_repr = [f"- {name}: {repr(getattr(self, name))}" for name in self.attributes]
+ attributes_repr = "\n".join(attributes_repr)
+ return f"{self.__class__.__name__}:\n{attributes_repr}"
+
+ def save_pretrained(self, save_directory, **kwargs):
+ """
+ Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it
+ can be reloaded using the [`~ProcessorMixin.from_pretrained`] method.
+
+
+
+ This class method is simply calling [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] and
+ [`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`]. Please refer to the docstrings of the
+ methods above for more information.
+
+
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
+ be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ os.makedirs(save_directory, exist_ok=True)
+
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+ # loaded from the Hub.
+ for attribute_name in self.attributes:
+ attribute = getattr(self, attribute_name)
+ # Include the processor class in the attribute config so this processor can then be reloaded with the
+ # `AutoProcessor` API.
+ if hasattr(attribute, "_set_processor_class"):
+ attribute._set_processor_class(self.__class__.__name__)
+ attribute.save_pretrained(save_directory)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ **kwargs,
+ ):
+ r"""
+ Instantiate a processor associated with a pretrained model.
+
+
+
+ This class method is simply calling the feature extractor
+ [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], image processor
+ [`~image_processing_utils.ImageProcessingMixin`] and the tokenizer
+ [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`] methods. Please refer to the docstrings of the
+ methods above for more information.
+
+
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
+ hf-mirror.com. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
+ namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ - a path to a *directory* containing a feature extractor file saved using the
+ [`~SequenceFeatureExtractor.save_pretrained`] method, e.g., `./my_model_directory/`.
+ - a path or url to a saved feature extractor JSON *file*, e.g.,
+ `./my_model_directory/preprocessor_config.json`.
+ **kwargs
+ Additional keyword arguments passed along to both
+ [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] and
+ [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`].
+ """
+ kwargs["cache_dir"] = cache_dir
+ kwargs["force_download"] = force_download
+ kwargs["local_files_only"] = local_files_only
+ kwargs["revision"] = revision
+
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if token is not None:
+ kwargs["token"] = token
+
+ args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
+ return cls(*args)
+
+ @classmethod
+ def register_for_auto_class(cls, auto_class="AutoProcessor"):
+ """
+ Register this class with a given auto class. This should only be used for custom feature extractors as the ones
+ in the library are already mapped with `AutoProcessor`.
+
+
+
+ This API is experimental and may have some slight breaking changes in the next releases.
+
+
+
+ Args:
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoProcessor"`):
+ The auto class to register this new feature extractor with.
+ """
+ if not isinstance(auto_class, str):
+ auto_class = auto_class.__name__
+
+ if not hasattr(transformers.models.auto, auto_class):
+ raise ValueError(f"{auto_class} is not a valid auto class.")
+
+ cls._auto_class = auto_class
+
+ @classmethod
+ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+ """
+ A method to obtain arguments from a pre-trained model or path.
+
+ Args:
+ cls (class): The class object.
+ pretrained_model_name_or_path (str): The name or path of the pre-trained model.
+
+ Returns:
+ None.
+
+ Raises:
+ None.
+ """
+ args = []
+ for attribute_name in cls.attributes:
+ class_name = getattr(cls, f"{attribute_name}_class")
+ if isinstance(class_name, tuple):
+ classes = tuple(getattr(transformers_module, n) if n is not None else None for n in class_name)
+ use_fast = kwargs.get("use_fast", True)
+ if use_fast and classes[1] is not None:
+ attribute_class = classes[1]
+ else:
+ attribute_class = classes[0]
+ elif class_name == "MiniCPMVImageProcessor":
+ from mindone.transformers import MiniCPMVImageProcessor
+ attribute_class = MiniCPMVImageProcessor
+ else:
+ attribute_class = getattr(transformers_module, class_name)
+
+ args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
+ return args
+
+ @property
+ def model_input_names(self):
+ """
+ Retrieve the model input names from the first attribute of the ProcessorMixin instance.
+
+ Args:
+ self (ProcessorMixin): The instance of the ProcessorMixin class.
+
+ Returns:
+ None: Returns the model input names from the first attribute of the ProcessorMixin instance if available,
+ otherwise returns None.
+
+ Raises:
+ None.
+ """
+ first_attribute = getattr(self, self.attributes[0])
+ return getattr(first_attribute, "model_input_names", None)
diff --git a/mindone/transformers/utils/generic.py b/mindone/transformers/utils/generic.py
new file mode 100644
index 0000000000..dd5f2c6f78
--- /dev/null
+++ b/mindone/transformers/utils/generic.py
@@ -0,0 +1,58 @@
+import inspect
+
+from enum import Enum
+
+def can_return_loss(model_class):
+ """
+ Check if a given model can return loss.
+
+ Args:
+ model_class (`type`): The class of the model.
+ """
+ signature = inspect.signature(model_class.construct) # MindSpore models
+
+ for p in signature.parameters:
+ if p == "return_loss" and signature.parameters[p].default is True:
+ return True
+
+ return False
+
+
+def find_labels(model_class):
+ """
+ Find the labels used by a given model.
+
+ Args:
+ model_class (`type`): The class of the model.
+ """
+ model_name = model_class.__name__
+ signature = inspect.signature(model_class.construct) # MindSpore models
+
+ if "QuestionAnswering" in model_name:
+ return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
+ else:
+ return [p for p in signature.parameters if "label" in p]
+
+class ExplicitEnum(str, Enum):
+ """
+ Enum with more explicit error message for missing values.
+ """
+
+ @classmethod
+ def _missing_(cls, value):
+ raise ValueError(
+ f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
+ )
+
+class TensorType(ExplicitEnum):
+ """
+ Possible values for the `return_tensors` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for
+ tab-completion in an IDE.
+ """
+
+ PYTORCH = "pt"
+ TENSORFLOW = "tf"
+ NUMPY = "np"
+ JAX = "jax"
+ MLX = "mlx"
+ MINDSPORE = "ms"
\ No newline at end of file
From d1c90cb48ff7cb49e0a2b445fd9e82de7a69210f Mon Sep 17 00:00:00 2001
From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com>
Date: Tue, 19 Nov 2024 20:23:13 +0800
Subject: [PATCH 2/9] feat(minicpm-v): Support MiniCPM-V inference pipeline
---
examples/minicpm_v/inference/inference.py | 17 +-
mindone/transformers/__init__.py | 3 +-
.../transformers/feature_extraction_utils.py | 18 +-
.../transformers/image_processing_utils.py | 13 +-
mindone/transformers/image_transforms.py | 8 +-
mindone/transformers/image_utils.py | 14 +-
.../transformers/models/bert/modeling_bert.py | 28 +-
.../transformers/models/minicpm_v/__init__.py | 2 +-
.../models/minicpm_v/configuration_minicpm.py | 3 +-
.../minicpm_v/image_processing_minicpmv.py | 79 ++--
.../models/minicpm_v/modeling_minicpmv.py | 150 ++++----
.../models/minicpm_v/modeling_navit_siglip.py | 56 ++-
.../models/minicpm_v/processing_minicpmv.py | 99 ++---
.../models/minicpm_v/resampler.py | 355 +++++++++++-------
.../minicpm_v/tokenization_minicpmv_fast.py | 3 +-
mindone/transformers/models/qwen2/__init__.py | 1 -
.../models/qwen2/modeling_qwen2.py | 48 +--
.../models/qwen2/tokenization_qwen2.py | 4 +-
mindone/transformers/processing_utils.py | 20 +-
mindone/transformers/utils/generic.py | 6 +-
20 files changed, 500 insertions(+), 427 deletions(-)
diff --git a/examples/minicpm_v/inference/inference.py b/examples/minicpm_v/inference/inference.py
index 30e8373800..227046c5f5 100644
--- a/examples/minicpm_v/inference/inference.py
+++ b/examples/minicpm_v/inference/inference.py
@@ -1,18 +1,21 @@
-import mindspore as ms
-
from PIL import Image
from transformers import AutoTokenizer
+
+import mindspore as ms
+
from mindone.transformers import MiniCPMV_v2_6
-model = MiniCPMV_v2_6.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True, attn_implementation='eager', mindspore_dtype=ms.float32)
+model = MiniCPMV_v2_6.from_pretrained(
+ "openbmb/MiniCPM-V-2_6", trust_remote_code=True, attn_implementation="eager", mindspore_dtype=ms.float32
+)
-tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-V-2_6", trust_remote_code=True)
-image = Image.open('airplane.jepg').convert('RGB')
+image = Image.open("airplane.jepg").convert("RGB")
# First Round Chat
question = "Tell me the model of this aircraft"
-msgs = [{"role": 'user', 'content': [image, question]}]
+msgs = [{"role": "user", "content": [image, question]}]
answer = model.chat(image=image, msgs=msgs, tokenizer=tokenizer)
print(answer)
@@ -22,4 +25,4 @@
msgs.append({"role": "user", "content": ["Introduce something about Airbus A380."]})
answer = model.chat(image=None, msgs=msgs, tokenizer=tokenizer)
-print(answer)
\ No newline at end of file
+print(answer)
diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py
index ed39ee68b3..dc307e448e 100644
--- a/mindone/transformers/__init__.py
+++ b/mindone/transformers/__init__.py
@@ -37,6 +37,7 @@
GemmaModel,
GemmaPreTrainedModel,
)
+from .models.minicpm_v import MiniCPMV_v2_6, MiniCPMVImageProcessor
from .models.mt5 import (
MT5_PRETRAINED_MODEL_ARCHIVE_LIST,
MT5EncoderModel,
@@ -52,5 +53,3 @@
T5PreTrainedModel,
)
from .models.xlm_roberta import XLMRobertaModel, XLMRobertaPreTrainedModel
-
-from .models.minicpm_v import MiniCPMV_v2_6, MiniCPMVImageProcessor
\ No newline at end of file
diff --git a/mindone/transformers/feature_extraction_utils.py b/mindone/transformers/feature_extraction_utils.py
index deb703ce75..06c3e8ff53 100644
--- a/mindone/transformers/feature_extraction_utils.py
+++ b/mindone/transformers/feature_extraction_utils.py
@@ -26,8 +26,6 @@
import numpy as np
-FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
-from .utils.generic import TensorType
from transformers.utils import (
cached_file,
download_url,
@@ -41,10 +39,9 @@
import mindspore
from mindspore import ops
-# if is_mindspore_available():
-# import mindspore
-# from mindspore import ops
+from .utils.generic import TensorType
+FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
logger = logging.get_logger(__name__)
@@ -124,6 +121,7 @@ def as_tensor(value):
is_tensor = ops.is_tensor
else:
+
def as_tensor(value, dtype=None):
if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
value_lens = [len(val) for val in value]
@@ -159,7 +157,9 @@ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = Non
self[key] = tensor
except Exception as exc: # noqa E722
if key == "overflowing_values":
- raise ValueError("Unable to create tensor returning overflowing values of different lengths. ") from exc
+ raise ValueError(
+ "Unable to create tensor returning overflowing values of different lengths. "
+ ) from exc
raise ValueError(
"Unable to create tensor, you should probably activate padding "
"with 'padding=True' to have batched tensors with the same length."
@@ -205,7 +205,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
return self
-class FeatureExtractionMixin():
+class FeatureExtractionMixin:
"""
This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature
extractors.
@@ -383,7 +383,6 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
self.to_json_file(output_feature_extractor_file)
logger.info(f"Feature extractor saved in {output_feature_extractor_file}")
-
return [output_feature_extractor_file]
@classmethod
@@ -408,7 +407,7 @@ def get_feature_extractor_dict(
local_files_only = kwargs.pop("local_files_only", False)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
- revision = kwargs.pop('revision', 'main')
+ revision = kwargs.pop("revision", "main")
user_agent = {"file_type": "feature extractor", "from_auto_class": from_auto_class}
if from_pipeline is not None:
@@ -474,7 +473,6 @@ def get_feature_extractor_dict(
f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
)
-
return feature_extractor_dict, kwargs
@classmethod
diff --git a/mindone/transformers/image_processing_utils.py b/mindone/transformers/image_processing_utils.py
index 4f2232db16..f743b8e457 100644
--- a/mindone/transformers/image_processing_utils.py
+++ b/mindone/transformers/image_processing_utils.py
@@ -24,17 +24,18 @@
import numpy as np
import requests
-FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
-IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
+from transformers.utils import cached_file, download_url, is_offline_mode, is_remote_url, is_vision_available, logging
from .feature_extraction_utils import BatchFeature as BaseBatchFeature
from .image_transforms import center_crop, normalize, rescale
from .image_utils import ChannelDimension
-from transformers.utils import cached_file, download_url, is_offline_mode, is_remote_url, is_vision_available, logging
if is_vision_available():
from PIL import Image
+FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
+IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
+
logger = logging.get_logger(__name__)
@@ -53,12 +54,15 @@ class BatchFeature(BaseBatchFeature):
You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
initialization.
"""
+
+
# TODO: (Amy) - factor out the common parts of this and the feature extractor
class ImageProcessingMixin:
"""
This is an image processor mixin used to provide saving/loading functionality for sequential and image feature
extractors.
"""
+
_auto_class = None
def __init__(self, **kwargs):
@@ -490,6 +494,7 @@ def register_for_auto_class(cls, auto_class="AutoImageProcessor"):
auto_class = auto_class.__name__
import mindnlp.transformers.models.auto as auto_module
+
if not hasattr(auto_module, auto_class):
raise ValueError(f"{auto_class} is not a valid auto class.")
@@ -537,6 +542,7 @@ class BaseImageProcessor(ImageProcessingMixin):
normalize(self, image, mean, std, data_format=None, input_data_format=None, **kwargs) -> np.ndarray: Normalize an image using mean and standard deviation.
center_crop(self, image, size, data_format=None, input_data_format=None, **kwargs) -> np.ndarray: Center crop an image to a specified size.
"""
+
def __call__(self, images, **kwargs) -> BatchFeature:
"""Preprocess an image or a batch of images."""
return self.preprocess(images, **kwargs)
@@ -826,6 +832,7 @@ def get_size_dict(
)
return size_dict
+
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
"""
Selects the best resolution from a list of possible resolutions based on the original size.
diff --git a/mindone/transformers/image_transforms.py b/mindone/transformers/image_transforms.py
index 1b68f1aad5..ee7ed99897 100644
--- a/mindone/transformers/image_transforms.py
+++ b/mindone/transformers/image_transforms.py
@@ -2,7 +2,6 @@
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
-
import PIL
import mindspore
@@ -16,6 +15,7 @@
infer_channel_dimension_format,
)
+
def to_channel_dimension_format(
image: np.ndarray,
channel_dim: Union[ChannelDimension, str],
@@ -54,6 +54,7 @@ def to_channel_dimension_format(
return image
+
def _rescale_for_pil_conversion(image):
"""
Detects whether or not the image needs to be rescaled before being converted to a PIL image.
@@ -128,6 +129,7 @@ def to_pil_image(
image = image.astype(np.uint8)
return PIL.Image.fromarray(image)
+
def center_crop(
image: np.ndarray,
size: Tuple[int, int],
@@ -225,6 +227,7 @@ def center_crop(
return new_image
+
def normalize(
image: np.ndarray,
mean: Union[float, Iterable[float]],
@@ -284,6 +287,7 @@ def normalize(
image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
return image
+
def rescale(
image: np.ndarray,
scale: float,
@@ -319,4 +323,4 @@ def rescale(
rescaled_image = rescaled_image.astype(dtype)
- return rescaled_image
\ No newline at end of file
+ return rescaled_image
diff --git a/mindone/transformers/image_utils.py b/mindone/transformers/image_utils.py
index e4facfbd03..fab986a5a7 100644
--- a/mindone/transformers/image_utils.py
+++ b/mindone/transformers/image_utils.py
@@ -9,14 +9,22 @@
from .utils.generic import ExplicitEnum
+
class ChannelDimension(ExplicitEnum):
FIRST = "channels_first"
LAST = "channels_last"
+
ImageInput = Union[
- "PIL.Image.Image", np.ndarray, "mindspore.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["mindspore.Tensor"]
+ "PIL.Image.Image",
+ np.ndarray,
+ "mindspore.Tensor",
+ List["PIL.Image.Image"],
+ List[np.ndarray],
+ List["mindspore.Tensor"],
] # noqa
+
def get_channel_dimension_axis(
image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
) -> int:
@@ -40,6 +48,7 @@ def get_channel_dimension_axis(
return image.ndim - 1
raise ValueError(f"Unsupported data format: {input_data_format}")
+
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
"""
Returns the (height, width) dimensions of the image.
@@ -63,6 +72,7 @@ def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> T
else:
raise ValueError(f"Unsupported data format: {channel_dim}")
+
def infer_channel_dimension_format(
image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
) -> ChannelDimension:
@@ -92,4 +102,4 @@ def infer_channel_dimension_format(
return ChannelDimension.FIRST
elif image.shape[last_dim] in num_channels:
return ChannelDimension.LAST
- raise ValueError("Unable to infer channel dimension format")
\ No newline at end of file
+ raise ValueError("Unable to infer channel dimension format")
diff --git a/mindone/transformers/models/bert/modeling_bert.py b/mindone/transformers/models/bert/modeling_bert.py
index 14c571c3b3..85849bda4b 100644
--- a/mindone/transformers/models/bert/modeling_bert.py
+++ b/mindone/transformers/models/bert/modeling_bert.py
@@ -292,7 +292,7 @@ def construct(
if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
logger.warning_once(
- "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+ "BertSdpaSelfAttention is used but `scaled_dot_product_attention` does not support "
"non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
"the manual attention implementation, but specifying the manual implementation will be required from "
"Transformers version v5.0.0 onwards. This warning can be removed using the argument "
@@ -339,8 +339,6 @@ def construct(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
# a causal mask in case tgt_len == 1.
# is_causal = (
@@ -1067,12 +1065,13 @@ def construct(
```python
>>> from transformers import AutoTokenizer, BertForPreTraining
- >>> import torch
+ >>> import mindspore as ms
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = BertForPreTraining.from_pretrained("google-bert/bert-base-uncased")
- >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> inputs = tokenizer("Hello, my dog is cute")
+ >>> inputs = {k:ms.Tensor(v) for k, v in inputs.items()}
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.prediction_logits
@@ -1399,14 +1398,14 @@ def construct(
```python
>>> from transformers import AutoTokenizer, BertForNextSentencePrediction
- >>> import torch
+ >>> import mindspore as ms
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = BertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
- >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
+ >>> encoding = tokenizer(prompt, next_sentence)
>>> outputs = model(**encoding, labels=ms.Tensor([1]))
>>> logits = outputs.logits
@@ -1517,24 +1516,27 @@ def construct(
if labels is not None:
if self.problem_type is None:
if self.num_labels == 1:
- self.problem_type = "regression"
+ problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == ms.int64 or labels.dtype == ms.int32):
- self.problem_type = "single_label_classification"
+ problem_type = "single_label_classification"
else:
- self.problem_type = "multi_label_classification"
+ problem_type = "multi_label_classification"
+ else:
+ problem_type = self.problem_type
- if self.problem_type == "regression":
+ if problem_type == "regression":
loss_fct = nn.MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
- elif self.problem_type == "single_label_classification":
+ elif problem_type == "single_label_classification":
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1).int())
- elif self.problem_type == "multi_label_classification":
+ elif problem_type == "multi_label_classification":
loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
+
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
diff --git a/mindone/transformers/models/minicpm_v/__init__.py b/mindone/transformers/models/minicpm_v/__init__.py
index c973fe341f..3d30bf0084 100644
--- a/mindone/transformers/models/minicpm_v/__init__.py
+++ b/mindone/transformers/models/minicpm_v/__init__.py
@@ -1,2 +1,2 @@
+from .image_processing_minicpmv import MiniCPMVImageProcessor
from .modeling_minicpmv import MiniCPMV_v2_6
-from .image_processing_minicpmv import MiniCPMVImageProcessor
\ No newline at end of file
diff --git a/mindone/transformers/models/minicpm_v/configuration_minicpm.py b/mindone/transformers/models/minicpm_v/configuration_minicpm.py
index 063cfee91b..db1f383fc0 100644
--- a/mindone/transformers/models/minicpm_v/configuration_minicpm.py
+++ b/mindone/transformers/models/minicpm_v/configuration_minicpm.py
@@ -45,7 +45,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
return cls.from_dict(config_dict, **kwargs)
-
class MiniCPMVConfig(Qwen2Config):
model_type = "minicpmv"
keys_to_ignore_at_inference = ["past_key_values"]
@@ -58,7 +57,7 @@ class MiniCPMVConfig(Qwen2Config):
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
- "attn_implementation": "flash_attention"
+ "attn_implementation": "flash_attention",
}
def __init__(
diff --git a/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py b/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
index 7626964f1b..1aad29b744 100644
--- a/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
+++ b/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
@@ -5,7 +5,6 @@
import PIL
import PIL.Image
import PIL.ImageSequence
-from ...image_processing_utils import BaseImageProcessor, BatchFeature
from PIL import Image
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import (
@@ -23,6 +22,8 @@
import mindspore as ms
from mindspore import Parameter, Tensor, nn, ops
+from ...image_processing_utils import BaseImageProcessor, BatchFeature
+
def recursive_converter(converter, value):
if isinstance(value, list):
@@ -104,12 +105,7 @@ def cast_tensor(v):
class MiniCPMVImageProcessor(BaseImageProcessor):
model_input_names = ["pixel_values"]
- def __init__(
- self,
- max_slice_nums=9,
- scale_resolution=448,
- patch_size=14,
- **kwargs):
+ def __init__(self, max_slice_nums=9, scale_resolution=448, patch_size=14, **kwargs):
super().__init__(**kwargs)
self.max_slice_nums = max_slice_nums
self.scale_resolution = scale_resolution
@@ -131,14 +127,9 @@ def __init__(
def ensure_divide(self, length, patch_size):
return max(round(length / patch_size) * patch_size, patch_size)
- def find_best_resize(self,
- original_size,
- scale_resolution,
- patch_size,
- allow_upscale=False):
+ def find_best_resize(self, original_size, scale_resolution, patch_size, allow_upscale=False):
width, height = original_size
- if (width * height >
- scale_resolution * scale_resolution) or allow_upscale:
+ if (width * height > scale_resolution * scale_resolution) or allow_upscale:
r = width / height
height = int(scale_resolution / math.sqrt(r))
width = int(height * r)
@@ -146,12 +137,7 @@ def find_best_resize(self,
best_height = self.ensure_divide(height, patch_size)
return (best_width, best_height)
- def get_refine_size(self,
- original_size,
- grid,
- scale_resolution,
- patch_size,
- allow_upscale=False):
+ def get_refine_size(self, original_size, grid, scale_resolution, patch_size, allow_upscale=False):
width, height = original_size
grid_x, grid_y = grid
@@ -161,10 +147,9 @@ def get_refine_size(self,
grid_width = refine_width / grid_x
grid_height = refine_height / grid_y
- best_grid_size = self.find_best_resize((grid_width, grid_height),
- scale_resolution,
- patch_size,
- allow_upscale=allow_upscale)
+ best_grid_size = self.find_best_resize(
+ (grid_width, grid_height), scale_resolution, patch_size, allow_upscale=allow_upscale
+ )
refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
return refine_size
@@ -182,9 +167,7 @@ def split_to_patches(self, image, grid):
patches.append(images)
return patches
- def slice_image(
- self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
- ):
+ def slice_image(self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False):
original_size = image.size
source_image = None
best_grid = self.get_sliced_grid(original_size, max_slice_nums, never_split)
@@ -192,9 +175,7 @@ def slice_image(
if best_grid is None:
# dont need to slice, upsample
- best_size = self.find_best_resize(
- original_size, scale_resolution, patch_size, allow_upscale=True
- )
+ best_size = self.find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True)
source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC)
else:
# source image, down-sampling and ensure divided by patch_size
@@ -212,9 +193,7 @@ def get_grid_placeholder(self, grid):
if grid is None:
return ""
slice_image_placeholder = (
- self.slice_start_token
- + self.unk_token * self.image_feature_size
- + self.slice_end_token
+ self.slice_start_token + self.unk_token * self.image_feature_size + self.slice_end_token
)
cols = grid[0]
@@ -241,10 +220,7 @@ def get_sliced_images(self, image, max_slice_nums=None):
max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
assert max_slice_nums > 0
source_image, patches, sliced_grid = self.slice_image(
- image,
- max_slice_nums, # default: 9
- self.scale_resolution, # default: 448
- self.patch_size # default: 14
+ image, max_slice_nums, self.scale_resolution, self.patch_size # default: 9 # default: 448 # default: 14
)
slice_images.append(source_image)
@@ -290,11 +266,7 @@ def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=No
assert max_slice_nums > 0
grid = self.get_sliced_grid(image_size=image_size, max_slice_nums=max_slice_nums)
- image_placeholder = (
- self.im_start_token
- + self.unk_token * self.image_feature_size
- + self.im_end_token
- )
+ image_placeholder = self.im_start_token + self.unk_token * self.image_feature_size + self.im_end_token
use_image_id = self.use_image_id if use_image_id is None else bool(use_image_id)
if use_image_id:
final_placeholder = self.get_image_id_placeholder(image_idx) + image_placeholder
@@ -349,11 +321,7 @@ def reshape_by_patch(self, image):
w = image.shape[2]
image = image.reshape(1, c, h, w)
- patches = ops.unfold(
- image,
- (patch_size, patch_size),
- stride=(patch_size, patch_size)
- )
+ patches = ops.unfold(image, (patch_size, patch_size), stride=(patch_size, patch_size))
image = image.squeeze(axis=0)
@@ -362,12 +330,12 @@ def reshape_by_patch(self, image):
return patches.numpy()
def preprocess(
- self,
- images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
- do_pad: Optional[bool] = True, # TODO: add pad for MiniCPM-Llama3-V-2_5
- max_slice_nums: int = None,
- return_tensors: Optional[Union[str, TensorType]] = None,
- **kwargs
+ self,
+ images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
+ do_pad: Optional[bool] = True, # TODO: add pad for MiniCPM-Llama3-V-2_5
+ max_slice_nums: int = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs,
) -> MiniCPMVBatchFeature:
if isinstance(images, Image.Image):
images_list = [[images]]
@@ -412,7 +380,8 @@ def preprocess(
for slice_image in image_patches:
new_images.append(self.reshape_by_patch(slice_image))
tgt_sizes.append(
- np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size)))
+ np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size))
+ )
if tgt_sizes:
tgt_sizes = np.vstack(tgt_sizes)
@@ -422,7 +391,7 @@ def preprocess(
tgt_sizes_list.append(tgt_sizes)
return MiniCPMVBatchFeature(
data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list},
- tensor_type=return_tensors
+ tensor_type=return_tensors,
)
diff --git a/mindone/transformers/models/minicpm_v/modeling_minicpmv.py b/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
index 43d35d0990..92beb0c126 100644
--- a/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
+++ b/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
@@ -2,23 +2,21 @@
import math
from copy import deepcopy
from threading import Thread
-from typing import List, Optional
-from transformers import TextIteratorStreamer
from PIL import Image
+from transformers import TextIteratorStreamer
import mindspore as ms
-from mindspore import Parameter, Tensor, nn, ops
+from mindspore import Parameter, Tensor, _no_grad, nn, ops
from ..qwen2 import Qwen2ForCausalLM, Qwen2PreTrainedModel
from .configuration_minicpm import MiniCPMVConfig
+from .image_processing_minicpmv import MiniCPMVImageProcessor
from .modeling_navit_siglip import SiglipVisionTransformer
from .processing_minicpmv import MiniCPMVProcessor
-from .image_processing_minicpmv import MiniCPMVImageProcessor
from .resampler import Resampler
-from .tokenization_minicpmv_fast import MiniCPMVTokenizerFast
+# from .tokenization_minicpmv_fast import MiniCPMVTokenizerFast
-from mindspore import _no_grad
class MiniCPMVPreTrainedModel(Qwen2PreTrainedModel):
config_class = MiniCPMVConfig
@@ -34,21 +32,21 @@ def __init__(self, config):
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
self.processor = None
- self.terminators = ['<|im_end|>', '<|endoftext|>']
+ self.terminators = ["<|im_end|>", "<|endoftext|>"]
def init_vision_module(self):
# same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes
- if self.config._attn_implementation == 'flash_attention_2':
- self.config.vision_config._attn_implementation = 'flash_attention_2'
+ if self.config._attn_implementation == "flash_attention_2":
+ self.config.vision_config._attn_implementation = "flash_attention_2"
else:
# not suport sdpa
- self.config.vision_config._attn_implementation = 'eager'
+ self.config.vision_config._attn_implementation = "eager"
model = SiglipVisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
- setattr(model, 'embed_dim', model.embeddings.embed_dim)
- setattr(model, 'patch_size', model.embeddings.patch_size)
+ setattr(model, "embed_dim", model.embeddings.embed_dim)
+ setattr(model, "patch_size", model.embeddings.patch_size)
return model
@@ -58,7 +56,7 @@ def init_resampler(self, embed_dim, vision_dim):
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
- adaptive=True
+ adaptive=True,
)
def get_input_embeddings(self):
@@ -80,11 +78,11 @@ def get_decoder(self):
return self.llm
def get_vllm_embedding(self, data):
- if 'vision_hidden_states' not in data:
+ if "vision_hidden_states" not in data:
dtype = self.llm.model.embed_tokens.embedding_table.dtype
device = None
- tgt_sizes = data['tgt_sizes']
- pixel_values_list = data['pixel_values']
+ tgt_sizes = data["tgt_sizes"]
+ pixel_values_list = data["pixel_values"]
vision_hidden_states = []
all_pixel_values = []
img_cnt = []
@@ -107,7 +105,16 @@ def get_vllm_embedding(self, data):
max_length_w = max([i.shape[1] for i in all_pixel_values])
for i in range(len(all_pixel_values)):
if all_pixel_values[i].shape[0] < max_length_h or all_pixel_values[i].shape[1] < max_length_w:
- all_pixel_values[i] = ops.pad(all_pixel_values[i], (0, max_length_w - all_pixel_values[i].shape[1], 0, max_length_h - all_pixel_values[i].shape[0]), value=0.0)
+ all_pixel_values[i] = ops.pad(
+ all_pixel_values[i],
+ (
+ 0,
+ max_length_w - all_pixel_values[i].shape[1],
+ 0,
+ max_length_h - all_pixel_values[i].shape[0],
+ ),
+ value=0.0,
+ )
all_pixel_values = ops.stack(all_pixel_values)
B, L, _ = all_pixel_values.shape
@@ -115,7 +122,7 @@ def get_vllm_embedding(self, data):
patch_attn_mask = ops.zeros(Tensor((B, 1, int(max_patches))), dtype=ms.bool_)
for i in range(B):
- patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+ patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
vision_batch_size = self.config.vision_batch_size
all_pixel_values = all_pixel_values.astype(dtype)
@@ -124,28 +131,33 @@ def get_vllm_embedding(self, data):
for i in range(0, B, vision_batch_size):
start_idx = i
end_idx = i + vision_batch_size
- tmp_hs = self.vpm(all_pixel_values[start_idx:end_idx], patch_attention_mask=patch_attn_mask[start_idx:end_idx], tgt_sizes=tgt_sizes[start_idx:end_idx]).last_hidden_state
+ tmp_hs = self.vpm(
+ all_pixel_values[start_idx:end_idx],
+ patch_attention_mask=patch_attn_mask[start_idx:end_idx],
+ tgt_sizes=tgt_sizes[start_idx:end_idx],
+ ).last_hidden_state
hs.append(tmp_hs)
vision_embedding = ops.cat(hs, axis=0)
else:
- vision_embedding = self.vpm(all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes).last_hidden_state
+ vision_embedding = self.vpm(
+ all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes
+ ).last_hidden_state
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
start = 0
for pixel_values in pixel_values_list:
img_cnt = len(pixel_values)
if img_cnt > 0:
- vision_hidden_states.append(vision_embedding[start: start + img_cnt])
+ vision_hidden_states.append(vision_embedding[start : start + img_cnt])
start += img_cnt
else:
vision_hidden_states.append([])
- else: # no image
+ else: # no image
if self.training:
- dummy_image = ops.zeros(
- (1, 3, 224, 224),
- dtype=dtype
- )
- tgt_sizes = ms.Tensor([[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]).astype(ms.int32)
+ dummy_image = ops.zeros((1, 3, 224, 224), dtype=dtype)
+ tgt_sizes = ms.Tensor(
+ [[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]
+ ).astype(ms.int32)
dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
else:
dummy_feature = []
@@ -153,15 +165,16 @@ def get_vllm_embedding(self, data):
vision_hidden_states.append(dummy_feature)
else:
- vision_hidden_states = data['vision_hidden_states']
+ vision_hidden_states = data["vision_hidden_states"]
- if hasattr(self.llm.config, 'scale_emb'):
- vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
+ if hasattr(self.llm.config, "scale_emb"):
+ vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb
else:
- vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])
+ vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
- vision_hidden_states = [i.astype(vllm_embedding.dtype) if isinstance(
- i, ms.Tensor) else i for i in vision_hidden_states]
+ vision_hidden_states = [
+ i.astype(vllm_embedding.dtype) if isinstance(i, ms.Tensor) else i for i in vision_hidden_states
+ ]
# bs = len(data['input_ids'])
# for i in range(bs):
@@ -188,12 +201,7 @@ def construct(self, data, **kwargs):
position_ids = position_ids.long()
with _no_grad():
- return self.llm(
- input_ids=None,
- position_ids=position_ids,
- inputs_embeds=vllm_embedding,
- **kwargs
- )
+ return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs)
def _decode(self, inputs_embeds, tokenizer, attention_mask, decode_text=False, **kwargs):
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
@@ -202,7 +210,7 @@ def _decode(self, inputs_embeds, tokenizer, attention_mask, decode_text=False, *
pad_token_id=0,
eos_token_id=terminators,
attention_mask=attention_mask,
- **kwargs
+ **kwargs,
)
if decode_text:
return self._decode_text(output, tokenizer)
@@ -212,10 +220,10 @@ def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
streamer = TextIteratorStreamer(tokenizer=tokenizer)
generation_kwargs = {
- 'inputs_embeds': inputs_embeds,
- 'pad_token_id': 0,
- 'eos_token_id': terminators,
- 'streamer': streamer
+ "inputs_embeds": inputs_embeds,
+ "pad_token_id": 0,
+ "eos_token_id": terminators,
+ "streamer": streamer,
}
generation_kwargs.update(kwargs)
@@ -248,7 +256,7 @@ def generate(
return_vision_hidden_states=False,
stream=False,
decode_text=False,
- **kwargs
+ **kwargs,
):
assert input_ids is not None
assert len(input_ids) == len(pixel_values)
@@ -260,7 +268,7 @@ def generate(
if vision_hidden_states is None:
model_inputs["pixel_values"] = pixel_values
- model_inputs['tgt_sizes'] = tgt_sizes
+ model_inputs["tgt_sizes"] = tgt_sizes
else:
model_inputs["vision_hidden_states"] = vision_hidden_states
@@ -273,7 +281,9 @@ def generate(
if stream:
result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
else:
- result = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, decode_text=decode_text, **kwargs)
+ result = self._decode(
+ model_inputs["inputs_embeds"], tokenizer, attention_mask, decode_text=decode_text, **kwargs
+ )
if return_vision_hidden_states:
return result, vision_hidden_states
@@ -291,11 +301,11 @@ def chat(
min_new_tokens=0,
sampling=True,
max_inp_length=8192,
- system_prompt='',
+ system_prompt="",
stream=False,
max_slice_nums=None,
use_image_id=None,
- **kwargs
+ **kwargs,
):
if isinstance(msgs[0], list):
batched = True
@@ -318,11 +328,21 @@ def chat(
self.processor = MiniCPMVProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
processor = self.processor
- assert self.config.query_num == processor.image_processor.image_feature_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
- assert self.config.patch_size == processor.image_processor.patch_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
- assert self.config.use_image_id == processor.image_processor.use_image_id, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
- assert self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
- assert self.config.slice_mode == processor.image_processor.slice_mode, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert (
+ self.config.query_num == processor.image_processor.image_feature_size
+ ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert (
+ self.config.patch_size == processor.image_processor.patch_size
+ ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert (
+ self.config.use_image_id == processor.image_processor.use_image_id
+ ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert (
+ self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums
+ ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert (
+ self.config.slice_mode == processor.image_processor.slice_mode
+ ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
prompts_lists = []
input_images_lists = []
@@ -356,10 +376,12 @@ def chat(
msg["content"] = "\n".join(cur_msgs)
if system_prompt:
- sys_msg = {'role': 'system', 'content': system_prompt}
+ sys_msg = {"role": "system", "content": system_prompt}
copy_msgs = [sys_msg] + copy_msgs
- prompts_lists.append(processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True))
+ prompts_lists.append(
+ processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True)
+ )
input_images_lists.append(images)
inputs = processor(
@@ -369,7 +391,7 @@ def chat(
use_image_id=use_image_id,
return_tensors="ms",
max_length=max_inp_length,
- image_processor=image_processor
+ image_processor=image_processor,
)
if sampling:
@@ -378,7 +400,7 @@ def chat(
"top_k": 100,
"temperature": 0.7,
"do_sample": True,
- "repetition_penalty": 1.05
+ "repetition_penalty": 1.05,
}
else:
generation_config = {
@@ -387,11 +409,9 @@ def chat(
}
if min_new_tokens > 0:
- generation_config['min_new_tokens'] = min_new_tokens
+ generation_config["min_new_tokens"] = min_new_tokens
- generation_config.update(
- (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
- )
+ generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
inputs.pop("image_sizes")
# with torch.inference_mode():
@@ -402,15 +422,17 @@ def chat(
vision_hidden_states=vision_hidden_states,
stream=stream,
decode_text=True,
- **generation_config
+ **generation_config,
)
if stream:
+
def stream_gen():
for text in res:
for term in self.terminators:
- text = text.replace(term, '')
+ text = text.replace(term, "")
yield text
+
return stream_gen()
else:
diff --git a/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py b/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
index 9b946ad76d..b6ebf327a7 100644
--- a/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
+++ b/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
@@ -20,7 +20,7 @@
import os
import warnings
from dataclasses import dataclass
-from typing import Any, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
import numpy as np
from transformers.configuration_utils import PretrainedConfig
@@ -42,6 +42,7 @@
logger = logging.get_logger(__name__)
+
class SiglipVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
@@ -229,13 +230,16 @@ def trunc_normal_tf_(
_trunc_normal_(tensor, 0, 1.0, a, b)
tensor.mul_(std).add_(mean)
+
def _calculate_fan_in_and_fan_out(arr):
# 计算fan_in和fan_out。fan_in是 `arr` 中输入单元的数量,fan_out是 `arr` 中输出单元的数量。
shape = arr.shape
dimensions = len(shape)
if dimensions < 2:
- raise ValueError("'fan_in' and 'fan_out' can not be computed for arr with fewer than"
- " 2 dimensions, but got dimensions {}.".format(dimensions))
+ raise ValueError(
+ "'fan_in' and 'fan_out' can not be computed for arr with fewer than"
+ " 2 dimensions, but got dimensions {}.".format(dimensions)
+ )
if dimensions == 2: # Linear
fan_in = shape[1]
fan_out = shape[0]
@@ -249,6 +253,7 @@ def _calculate_fan_in_and_fan_out(arr):
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
+
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in":
@@ -333,7 +338,9 @@ def __init__(self, config: SiglipVisionConfig):
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
- def construct(self, pixel_values: ms.Tensor, patch_attention_mask: ms.Tensor, tgt_sizes: Optional[ms.Tensor]=None) -> ms.Tensor:
+ def construct(
+ self, pixel_values: ms.Tensor, patch_attention_mask: ms.Tensor, tgt_sizes: Optional[ms.Tensor] = None
+ ) -> ms.Tensor:
batch_size = pixel_values.shape[0]
patch_embeds = self.patch_embedding(pixel_values)
@@ -388,7 +395,7 @@ def __init__(self, config):
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
- self.scale = self.head_dim ** -0.5
+ self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Dense(self.embed_dim, self.embed_dim)
@@ -416,8 +423,8 @@ def construct(
k_v_seq_len = key_states.shape[-2]
- query_states = ops.mul(query_states, self.scale ** 0.5)
- key_states = ops.mul(key_states, self.scale ** 0.5)
+ query_states = ops.mul(query_states, self.scale**0.5)
+ key_states = ops.mul(key_states, self.scale**0.5)
attn_weights = ops.matmul(query_states, key_states.swapaxes(2, 3))
@@ -621,9 +628,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
- cu_seqlens_q = ops.arange(
- batch_size + 1, dtype=ms.int32
- ) # There is a memcpy here, that is very bad.
+ cu_seqlens_q = ops.arange(batch_size + 1, dtype=ms.int32) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
@@ -640,22 +645,21 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
+
class SiglipFlashAttention(SiglipAttention):
"""
Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
+
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False # Hack to make sure we don't use a causal mask
dropout_rate = self.dropout if self.training else 0.0
self.flash_attention = FlashAttention(
- scale_value=self.head_dim**-0.5,
- head_num=self.head_dim,
- input_layout="BSH",
- keep_prob=1-dropout_rate
+ scale_value=self.head_dim**-0.5, head_num=self.head_dim, input_layout="BSH", keep_prob=1 - dropout_rate
)
def construct(
@@ -729,9 +733,7 @@ def construct(
value_states = value_states.to(target_dtype)
# implement flash attention
- attn_output = self.flash_attention(
- query_states, key_states, value_states, None, None, None, attention_mask
- )[3]
+ attn_output = self.flash_attention(query_states, key_states, value_states, None, None, None, attention_mask)[3]
# attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
@@ -741,6 +743,7 @@ def construct(
return attn_output, attn_weights
+
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Cell):
def __init__(self, config):
@@ -764,11 +767,7 @@ def __init__(self, config: SiglipVisionConfig):
self.embed_dim = config.hidden_size
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_flash_attention = config._attn_implementation == "flash_attention"
- self.self_attn = (
- SiglipAttention(config)
- if not self._use_flash_attention
- else SiglipFlashAttention(config)
- )
+ self.self_attn = SiglipAttention(config) if not self._use_flash_attention else SiglipFlashAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
@@ -975,9 +974,8 @@ def construct(
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
- return BaseModelOutput(
- last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
- )
+ return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions)
+
class SiglipVisionTransformer(SiglipPreTrainedModel):
config_class = SiglipVisionConfig
@@ -1001,11 +999,9 @@ def __init__(self, config: SiglipVisionConfig):
# recompute
# self.encoder.recompute()
-
def get_input_embeddings(self) -> nn.Cell:
return self.embeddings.patch_embedding
-
def construct(
self,
pixel_values,
@@ -1035,14 +1031,16 @@ def construct(
dtype=ms.bool_,
)
- hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes)
+ hidden_states = self.embeddings(
+ pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes
+ )
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
# The call to `_upad_input` in `_flash_attention_forward` is expensive
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
if not ops.any(~patch_attention_mask):
- attention_mask=None
+ attention_mask = None
else:
attention_mask = (
_prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
diff --git a/mindone/transformers/models/minicpm_v/processing_minicpmv.py b/mindone/transformers/models/minicpm_v/processing_minicpmv.py
index 36438fd5f6..2cc0a57140 100644
--- a/mindone/transformers/models/minicpm_v/processing_minicpmv.py
+++ b/mindone/transformers/models/minicpm_v/processing_minicpmv.py
@@ -17,19 +17,20 @@
"""
import re
-from typing import Any, Dict, List, Optional, Union
+from typing import List, Optional, Union
import numpy as np
-from transformers.utils import TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+from transformers.utils import TensorType
import mindspore as ms
-from mindspore import Parameter, Tensor, nn, ops
+from mindspore import Tensor, ops
from ...processing_utils import ProcessorMixin
from .image_processing_minicpmv import MiniCPMVBatchFeature, MiniCPMVImageProcessor
+
class MiniCPMVProcessor(ProcessorMixin):
r"""
Constructs a MiniCPMV processor which wraps a MiniCPMV image processor and a MiniCPMV tokenizer into a single processor.
@@ -61,12 +62,21 @@ def __call__(
use_image_id: bool = None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
image_processor=None,
- **kwargs
+ **kwargs,
) -> MiniCPMVBatchFeature:
-
if images is not None:
- image_inputs = image_processor.preprocess(images=images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors)
- return self._convert_images_texts_to_inputs(image_inputs, text, max_slice_nums=max_slice_nums, use_image_id=use_image_id, max_length=max_length, image_processor=image_processor, **kwargs)
+ image_inputs = image_processor.preprocess(
+ images=images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors
+ )
+ return self._convert_images_texts_to_inputs(
+ image_inputs,
+ text,
+ max_slice_nums=max_slice_nums,
+ use_image_id=use_image_id,
+ max_length=max_length,
+ image_processor=image_processor,
+ **kwargs,
+ )
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
@@ -96,13 +106,13 @@ def decode(self, *args, **kwargs):
result = result[result != 0]
if result[0] == self.tokenizer.bos_id:
result = result[1:]
- if result[-1] == self.tokenizer.eos_id or (hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id):
+ if result[-1] == self.tokenizer.eos_id or (
+ hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id
+ ):
result = result[:-1]
return self.tokenizer.decode(result, *args[1:], **kwargs).strip()
- def _convert(
- self, input_str, max_inp_length: Optional[int] = None
- ):
+ def _convert(self, input_str, max_inp_length: Optional[int] = None):
if self.version > 2.5 or not getattr(self.tokenizer, "add_bos_token", False):
input_ids = self.tokenizer.encode(input_str)
else:
@@ -140,19 +150,21 @@ def _convert(
return input_ids, image_bounds
def _convert_images_texts_to_inputs(
- self,
- images,
- texts: Union[str, List[str]],
- truncation=None,
- max_length=None,
- max_slice_nums=None,
- use_image_id=None,
- return_tensors=None,
- image_processor=None,
- **kwargs
- ):
+ self,
+ images,
+ texts: Union[str, List[str]],
+ truncation=None,
+ max_length=None,
+ max_slice_nums=None,
+ use_image_id=None,
+ return_tensors=None,
+ image_processor=None,
+ **kwargs,
+ ):
if images is None or not len(images):
- model_inputs = self.tokenizer(texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs)
+ model_inputs = self.tokenizer(
+ texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs
+ )
return MiniCPMVBatchFeature(data={**model_inputs})
pattern = "(./)"
@@ -168,33 +180,32 @@ def _convert_images_texts_to_inputs(
text_chunks = text.split(pattern)
final_text = ""
for i in range(len(image_tags)):
- final_text = final_text + text_chunks[i] + \
- image_processor.get_slice_image_placeholder(
- image_sizes[index][i],
- i,
- max_slice_nums,
- use_image_id
+ final_text = (
+ final_text
+ + text_chunks[i]
+ + image_processor.get_slice_image_placeholder(
+ image_sizes[index][i], i, max_slice_nums, use_image_id
)
+ )
final_text += text_chunks[-1]
input_ids, image_bounds = self._convert(final_text, max_length)
input_ids_list.append(input_ids)
image_bounds_list.append(image_bounds)
- padded_input_ids, padding_lengths = self.pad(
- input_ids_list,
- padding_side="left"
- )
+ padded_input_ids, padding_lengths = self.pad(input_ids_list, padding_side="left")
for i, length in enumerate(padding_lengths):
image_bounds_list[i] = image_bounds_list[i] + length
attention_mask = padded_input_ids.ne(0)
- return MiniCPMVBatchFeature(data={
- "input_ids": padded_input_ids,
- "attention_mask": attention_mask,
- "pixel_values": images,
- "image_sizes": image_sizes,
- "image_bound": image_bounds_list,
- "tgt_sizes": tgt_sizes
- })
+ return MiniCPMVBatchFeature(
+ data={
+ "input_ids": padded_input_ids,
+ "attention_mask": attention_mask,
+ "pixel_values": images,
+ "image_sizes": image_sizes,
+ "image_bound": image_bounds_list,
+ "tgt_sizes": tgt_sizes,
+ }
+ )
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
@@ -203,7 +214,6 @@ def model_input_names(self):
image_processor_input_names = MiniCPMVImageProcessor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
-
def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
items = []
if isinstance(inputs[0], list):
@@ -232,10 +242,7 @@ def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
return ops.stack([item for item in items], axis=0), [0] * batch_size
tensor = ops.zeros((batch_size, max_length), dtype=dtype) + padding_value
else:
- tensor = (
- ops.zeros((batch_size, max_length, shape[-1]), dtype=dtype)
- + padding_value
- )
+ tensor = ops.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
padding_length = []
for i, item in enumerate(items):
diff --git a/mindone/transformers/models/minicpm_v/resampler.py b/mindone/transformers/models/minicpm_v/resampler.py
index 0cea2d7e6e..d414422479 100644
--- a/mindone/transformers/models/minicpm_v/resampler.py
+++ b/mindone/transformers/models/minicpm_v/resampler.py
@@ -56,10 +56,10 @@ def get_1d_sincos_pos_embed_from_grid_new(embed_dim, pos):
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
- omega /= embed_dim / 2.
- omega = 1. / 10000 ** omega # (D/2,)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
- out = np.einsum('hw,d->hwd', pos, omega) # (H, W, D/2), outer product
+ out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
emb_sin = np.sin(out) # (H, W, D/2)
emb_cos = np.cos(out) # (H, W, D/2)
@@ -77,14 +77,14 @@ class Resampler(nn.Cell):
"""
def __init__(
- self,
- num_queries,
- embed_dim,
- num_heads,
- kv_dim=None,
- norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
- adaptive=False,
- max_size=(70, 70),
+ self,
+ num_queries,
+ embed_dim,
+ num_heads,
+ kv_dim=None,
+ norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
+ adaptive=False,
+ max_size=(70, 70),
):
super().__init__()
self.num_queries = num_queries
@@ -105,7 +105,7 @@ def __init__(
self.ln_kv = norm_layer((embed_dim,))
self.ln_post = norm_layer((embed_dim,))
- self.proj = Parameter((embed_dim ** -0.5) * ops.randn(embed_dim, embed_dim))
+ self.proj = Parameter((embed_dim**-0.5) * ops.randn(embed_dim, embed_dim))
self._set_2d_pos_cache(self.max_size)
@@ -125,7 +125,7 @@ def _adjust_pos_cache(self, tgt_sizes):
def _init_weights(self, m):
if isinstance(m, nn.Dense):
- trunc_normal_(m.weight, std=.02)
+ trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Dense) and m.bias is not None:
Zero(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
@@ -150,8 +150,9 @@ def construct(self, x, tgt_sizes=None):
tgt_h, tgt_w = tgt_sizes[i]
shape_0 = tgt_h * tgt_w
pos_embed.append(
- self.pos_embed[:tgt_h, :tgt_w, :].reshape((int(shape_0.asnumpy()), -1)).to(dtype)) # patches * D
- key_padding_mask[i, patch_len[i]:] = True
+ self.pos_embed[:tgt_h, :tgt_w, :].reshape((int(shape_0.asnumpy()), -1)).to(dtype)
+ ) # patches * D
+ key_padding_mask[i, patch_len[i] :] = True
# FIXME how to replace torch.nn.utils.rnn.pad_sequence
# pos_embed = torch.nn.utils.rnn.pad_sequence(
@@ -160,9 +161,11 @@ def construct(self, x, tgt_sizes=None):
max_length_w = max([i.shape[1] for i in pos_embed])
for i in range(len(pos_embed)):
if pos_embed[i].shape[0] < max_length_h or pos_embed[i].shape[1] < max_length_w:
- pos_embed[i] = ops.pad(pos_embed[i], (
- 0, max_length_w - pos_embed[i].shape[1], 0, max_length_h - pos_embed[i].shape[0]),
- value=0.0)
+ pos_embed[i] = ops.pad(
+ pos_embed[i],
+ (0, max_length_w - pos_embed[i].shape[1], 0, max_length_h - pos_embed[i].shape[0]),
+ value=0.0,
+ )
pos_embed = ops.stack(pos_embed)
pos_embed = pos_embed.permute(1, 0, 2)
@@ -175,7 +178,8 @@ def construct(self, x, tgt_sizes=None):
self._repeat(q, bs), # Q * B * D
x + pos_embed, # L * B * D + L * B * D
x,
- key_padding_mask=key_padding_mask)[0]
+ key_padding_mask=key_padding_mask,
+ )[0]
# out: Q * B * D
x = out.permute(1, 0, 2) # B * Q * D
@@ -188,25 +192,38 @@ def _repeat(self, query, N: int):
class MultiheadAttention(nn.MultiheadAttention):
- def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False,
- add_zero_attn=False, kdim=None, vdim=None, batch_first=False, dtype=None):
- super().__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first,
- dtype)
+ def __init__(
+ self,
+ embed_dim,
+ num_heads,
+ dropout=0.0,
+ bias=True,
+ add_bias_kv=False,
+ add_zero_attn=False,
+ kdim=None,
+ vdim=None,
+ batch_first=False,
+ dtype=None,
+ ):
+ super().__init__(
+ embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, dtype
+ )
# rewrite out_proj layer,with nn.Linear
self.out_proj = nn.Dense(embed_dim, embed_dim, has_bias=bias)
def construct(
- self,
- query: Tensor,
- key: Tensor,
- value: Tensor,
- key_padding_mask: Optional[Tensor] = None,
- need_weights: bool = True,
- attn_mask: Optional[Tensor] = None,
- average_attn_weights: bool = True,
- is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
- why_not_fast_path = ''
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ average_attn_weights: bool = True,
+ is_causal: bool = False,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ why_not_fast_path = ""
# if ((attn_mask is not None and torch.is_floating_point(attn_mask))
# or (key_padding_mask is not None) and torch.is_floating_point(key_padding_mask)):
# why_not_fast_path = "floating-point masks are not supported for fast path."
@@ -218,7 +235,7 @@ def construct(
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
- target_type=query.dtype
+ target_type=query.dtype,
)
attn_mask = _canonical_mask(
@@ -238,12 +255,16 @@ def construct(
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
- why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
+ why_not_fast_path = (
+ f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
+ )
elif self.in_proj_weight is None:
why_not_fast_path = "in_proj_weight was None"
elif query.dtype != self.in_proj_weight.dtype:
# this case will fail anyway, but at least they'll get a useful error message.
- why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
+ why_not_fast_path = (
+ f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
+ )
elif self.training:
why_not_fast_path = "training is enabled"
elif (self.num_heads % 2) != 0:
@@ -304,7 +325,8 @@ def construct(
merged_mask,
need_weights,
average_attn_weights,
- mask_type)
+ mask_type,
+ )
# any_nested = query.is_nested or key.is_nested or value.is_nested
# assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
@@ -323,62 +345,84 @@ def construct(
if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = self.multi_head_attention_forward(
- query, key, value, self.embed_dim, self.num_heads,
- self.in_proj_weight, self.in_proj_bias,
- self.bias_k, self.bias_v, self.add_zero_attn,
- self.dropout, self.out_proj.weight, self.out_proj.bias,
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias,
+ self.bias_k,
+ self.bias_v,
+ self.add_zero_attn,
+ self.dropout,
+ self.out_proj.weight,
+ self.out_proj.bias,
training=self.training,
- key_padding_mask=key_padding_mask, need_weights=need_weights,
+ key_padding_mask=key_padding_mask,
+ need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=True,
- q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
+ q_proj_weight=self.q_proj_weight,
+ k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,
- is_causal=is_causal)
+ is_causal=is_causal,
+ )
else:
attn_output, attn_output_weights = self.multi_head_attention_forward(
- query, key, value, self.embed_dim, self.num_heads,
- self.in_proj_weight, self.in_proj_bias,
- self.bias_k, self.bias_v, self.add_zero_attn,
- self.dropout, self.out_proj.weight, self.out_proj.bias,
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias,
+ self.bias_k,
+ self.bias_v,
+ self.add_zero_attn,
+ self.dropout,
+ self.out_proj.weight,
+ self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
average_attn_weights=average_attn_weights,
- is_causal=is_causal)
+ is_causal=is_causal,
+ )
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights
else:
return attn_output, attn_output_weights
def multi_head_attention_forward(
- self,
- query: Tensor,
- key: Tensor,
- value: Tensor,
- embed_dim_to_check: int,
- num_heads: int,
- in_proj_weight: Optional[Tensor],
- in_proj_bias: Optional[Tensor],
- bias_k: Optional[Tensor],
- bias_v: Optional[Tensor],
- add_zero_attn: bool,
- dropout_p: float,
- out_proj_weight: Tensor,
- out_proj_bias: Optional[Tensor],
- training: bool = True,
- key_padding_mask: Optional[Tensor] = None,
- need_weights: bool = True,
- attn_mask: Optional[Tensor] = None,
- use_separate_proj_weight: bool = False,
- q_proj_weight: Optional[Tensor] = None,
- k_proj_weight: Optional[Tensor] = None,
- v_proj_weight: Optional[Tensor] = None,
- static_k: Optional[Tensor] = None,
- static_v: Optional[Tensor] = None,
- average_attn_weights: bool = True,
- is_causal: bool = False,
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ embed_dim_to_check: int,
+ num_heads: int,
+ in_proj_weight: Optional[Tensor],
+ in_proj_bias: Optional[Tensor],
+ bias_k: Optional[Tensor],
+ bias_v: Optional[Tensor],
+ add_zero_attn: bool,
+ dropout_p: float,
+ out_proj_weight: Tensor,
+ out_proj_bias: Optional[Tensor],
+ training: bool = True,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ use_separate_proj_weight: bool = False,
+ q_proj_weight: Optional[Tensor] = None,
+ k_proj_weight: Optional[Tensor] = None,
+ v_proj_weight: Optional[Tensor] = None,
+ static_k: Optional[Tensor] = None,
+ static_v: Optional[Tensor] = None,
+ average_attn_weights: bool = True,
+ is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
# FIXME: logic passed
@@ -435,7 +479,7 @@ def multi_head_attention_forward(
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
- target_type=query.dtype
+ target_type=query.dtype,
)
if is_causal and attn_mask is None:
@@ -466,18 +510,20 @@ def multi_head_attention_forward(
# longer causal.
is_causal = False
- assert embed_dim == embed_dim_to_check, \
- f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
+ assert (
+ embed_dim == embed_dim_to_check
+ ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
if isinstance(embed_dim, ms.Tensor):
# embed_dim can be a tensor when JIT tracing
- head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
+ head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
else:
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
- assert key.shape[:2] == value.shape[:2], \
- f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
+ assert (
+ key.shape[:2] == value.shape[:2]
+ ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
else:
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
@@ -505,13 +551,15 @@ def multi_head_attention_forward(
correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(
- f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
+ f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
+ )
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size:
raise RuntimeError(
- f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
+ f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
+ )
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
@@ -539,19 +587,23 @@ def multi_head_attention_forward(
k = k.view(k.shape[0], bsz * num_heads, head_dim).permute(1, 0, 2)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
- assert static_k.shape[0] == bsz * num_heads, \
- f"expecting static_k.shape[0] of {bsz * num_heads}, but got {static_k.shape[0]}"
- assert static_k.shape[2] == head_dim, \
- f"expecting static_k.shape[2] of {head_dim}, but got {static_k.shape[2]}"
+ assert (
+ static_k.shape[0] == bsz * num_heads
+ ), f"expecting static_k.shape[0] of {bsz * num_heads}, but got {static_k.shape[0]}"
+ assert (
+ static_k.shape[2] == head_dim
+ ), f"expecting static_k.shape[2] of {head_dim}, but got {static_k.shape[2]}"
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).permute(1, 0, 2)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
- assert static_v.shape[0] == bsz * num_heads, \
- f"expecting static_v.shape[0] of {bsz * num_heads}, but got {static_v.shape[0]}"
- assert static_v.shape[2] == head_dim, \
- f"expecting static_v.shape[2] of {head_dim}, but got {static_v.shape[2]}"
+ assert (
+ static_v.shape[0] == bsz * num_heads
+ ), f"expecting static_v.shape[0] of {bsz * num_heads}, but got {static_v.shape[0]}"
+ assert (
+ static_v.shape[2] == head_dim
+ ), f"expecting static_v.shape[2] of {head_dim}, but got {static_v.shape[2]}"
v = static_v
# add zero attention along batch dimension (now first)
@@ -569,10 +621,15 @@ def multi_head_attention_forward(
# merge key padding and attention masks
if key_padding_mask is not None:
- assert key_padding_mask.shape == (bsz, src_len), \
- f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
- key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
- expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
+ assert key_padding_mask.shape == (
+ bsz,
+ src_len,
+ ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
+ key_padding_mask = (
+ key_padding_mask.view(bsz, 1, 1, src_len)
+ .expand(-1, num_heads, -1, -1)
+ .reshape(bsz * num_heads, 1, src_len)
+ )
if attn_mask is None:
attn_mask = key_padding_mask
else:
@@ -641,8 +698,14 @@ def multi_head_attention_forward(
return attn_output, None
-def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
- key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], num_heads: int):
+def _mha_shape_check(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ key_padding_mask: Optional[Tensor],
+ attn_mask: Optional[Tensor],
+ num_heads: int,
+):
# Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
# and returns if the input is batched or not.
# Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.
@@ -651,58 +714,65 @@ def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
if query.dim() == 3:
# Batched Inputs
is_batched = True
- assert key.dim() == 3 and value.dim() == 3, \
- ("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
- f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
+ assert key.dim() == 3 and value.dim() == 3, (
+ "For batched (3-D) `query`, expected `key` and `value` to be 3-D"
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
+ )
if key_padding_mask is not None:
- assert key_padding_mask.dim() == 2, \
- ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
- f" but found {key_padding_mask.dim()}-D tensor instead")
+ assert key_padding_mask.dim() == 2, (
+ "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
+ f" but found {key_padding_mask.dim()}-D tensor instead"
+ )
if attn_mask is not None:
- assert attn_mask.dim() in (2, 3), \
- ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
- f" but found {attn_mask.dim()}-D tensor instead")
+ assert attn_mask.dim() in (2, 3), (
+ "For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
+ f" but found {attn_mask.dim()}-D tensor instead"
+ )
elif query.dim() == 2:
# Unbatched Inputs
is_batched = False
- assert key.dim() == 2 and value.dim() == 2, \
- ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
- f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
+ assert key.dim() == 2 and value.dim() == 2, (
+ "For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
+ )
if key_padding_mask is not None:
- assert key_padding_mask.dim() == 1, \
- ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
- f" but found {key_padding_mask.dim()}-D tensor instead")
+ assert key_padding_mask.dim() == 1, (
+ "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
+ f" but found {key_padding_mask.dim()}-D tensor instead"
+ )
if attn_mask is not None:
- assert attn_mask.dim() in (2, 3), \
- ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
- f" but found {attn_mask.dim()}-D tensor instead")
+ assert attn_mask.dim() in (2, 3), (
+ "For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
+ f" but found {attn_mask.dim()}-D tensor instead"
+ )
if attn_mask.dim() == 3:
expected_shape = (num_heads, query.shape[0], key.shape[0])
- assert attn_mask.shape == expected_shape, \
- (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
+ assert (
+ attn_mask.shape == expected_shape
+ ), f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}"
else:
raise AssertionError(
- f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
+ f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor"
+ )
return is_batched
def _canonical_mask(
- mask: Optional[Tensor],
- mask_name: str,
- other_type: Optional,
- other_name: str,
- target_type: None,
- check_other: bool = True,
+ mask: Optional[Tensor],
+ mask_name: str,
+ other_type: Optional,
+ other_name: str,
+ target_type: None,
+ check_other: bool = True,
) -> Optional[Tensor]:
if mask is not None:
_mask_dtype = mask.dtype
_mask_is_float = ops.is_floating_point(mask)
if _mask_dtype != ms.bool_ and not _mask_is_float:
- raise AssertionError(
- f"only bool and floating types of {mask_name} are supported")
+ raise AssertionError(f"only bool and floating types of {mask_name} are supported")
if check_other and other_type is not None:
if _mask_dtype != other_type:
warnings.warn(
@@ -710,10 +780,7 @@ def _canonical_mask(
"is deprecated. Use same type for both instead."
)
if not _mask_is_float:
- mask = (
- ops.zeros_like(mask, dtype=target_type)
- .masked_fill(mask, float("-inf"))
- )
+ mask = ops.zeros_like(mask, dtype=target_type).masked_fill(mask, float("-inf"))
return mask
@@ -726,11 +793,11 @@ def _none_or_dtype(input: Optional[Tensor]) -> Optional:
def _in_projection_packed(
- q: Tensor,
- k: Tensor,
- v: Tensor,
- w: Tensor,
- b: Optional[Tensor] = None,
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ w: Tensor,
+ b: Optional[Tensor] = None,
) -> List[Tensor]:
r"""
Performs the in-projection step of the attention operation, using packed weights.
@@ -786,15 +853,15 @@ def _in_projection_packed(
def _in_projection(
- q: Tensor,
- k: Tensor,
- v: Tensor,
- w_q: Tensor,
- w_k: Tensor,
- w_v: Tensor,
- b_q: Optional[Tensor] = None,
- b_k: Optional[Tensor] = None,
- b_v: Optional[Tensor] = None,
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ w_q: Tensor,
+ w_k: Tensor,
+ w_v: Tensor,
+ b_q: Optional[Tensor] = None,
+ b_k: Optional[Tensor] = None,
+ b_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
r"""
Performs the in-projection step of the attention operation. This is simply
diff --git a/mindone/transformers/models/minicpm_v/tokenization_minicpmv_fast.py b/mindone/transformers/models/minicpm_v/tokenization_minicpmv_fast.py
index 1707f105fb..e41bff8bca 100644
--- a/mindone/transformers/models/minicpm_v/tokenization_minicpmv_fast.py
+++ b/mindone/transformers/models/minicpm_v/tokenization_minicpmv_fast.py
@@ -57,7 +57,7 @@ def im_id_end_id(self):
@property
def newline_id(self):
- return self.convert_tokens_to_ids('\n')
+ return self.convert_tokens_to_ids("\n")
@staticmethod
def escape(text: str) -> str:
@@ -67,4 +67,5 @@ def escape(text: str) -> str:
def unescape(text: str) -> str:
return text
+
AutoTokenizer.register("MiniCPMVTokenizerFast", MiniCPMVTokenizerFast)
diff --git a/mindone/transformers/models/qwen2/__init__.py b/mindone/transformers/models/qwen2/__init__.py
index be2d5916fd..aa7e109cf0 100644
--- a/mindone/transformers/models/qwen2/__init__.py
+++ b/mindone/transformers/models/qwen2/__init__.py
@@ -21,7 +21,6 @@
}
-
_import_structure["tokenization_qwen2_fast"] = ["Qwen2TokenizerFast"]
_import_structure["modeling_qwen2"] = [
diff --git a/mindone/transformers/models/qwen2/modeling_qwen2.py b/mindone/transformers/models/qwen2/modeling_qwen2.py
index b813e208a4..839a671f25 100644
--- a/mindone/transformers/models/qwen2/modeling_qwen2.py
+++ b/mindone/transformers/models/qwen2/modeling_qwen2.py
@@ -23,14 +23,16 @@
from typing import List, Optional, Tuple, Union
import numpy as np
+from transformers import logging
import mindspore as ms
-from mindspore import nn, ops, Tensor, Parameter
+from mindspore import Parameter, Tensor, nn, ops
from mindspore.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -40,11 +42,6 @@
from ...modeling_utils import MSPreTrainedModel
from .configuration_qwen2 import Qwen2Config
-from transformers import logging
-
-from ...modeling_flash_attention_utils import _flash_attention_forward
-
-
logger = logging.get_logger(__name__)
@@ -70,7 +67,6 @@ def dtype_to_min(dtype):
raise ValueError(f"Only support get minimum value of (float16, ), but got {dtype}")
-
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: ms.Tensor,
@@ -126,9 +122,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
causal_mask = causal_mask.masked_fill(padding_mask, min_dtype)
else:
causal_mask = ops.cat(
- [ops.narrow(causal_mask, -1, 0, mask_length).masked_fill(padding_mask, min_dtype),
- ops.narrow(causal_mask, -1, mask_length, causal_mask.shape[-1] - mask_length)],
- axis=-1
+ [
+ ops.narrow(causal_mask, -1, 0, mask_length).masked_fill(padding_mask, min_dtype),
+ ops.narrow(causal_mask, -1, mask_length, causal_mask.shape[-1] - mask_length),
+ ],
+ axis=-1,
)
return causal_mask
@@ -167,9 +165,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
self.inv_freq = inv_freq
# Build here to make `torch.jit.trace` work.
- self._set_cos_sin_cache(
- seq_len=max_position_embeddings, device=None, dtype=ms.float32
- )
+ self._set_cos_sin_cache(seq_len=max_position_embeddings, device=None, dtype=ms.float32)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
@@ -300,7 +296,7 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
base=self.rope_theta,
)
- self.scale = self.head_dim ** -0.5
+ self.scale = self.head_dim**-0.5
def construct(
self,
@@ -342,8 +338,8 @@ def construct(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
- query_states = ops.mul(query_states, self.scale ** 0.5)
- key_states = ops.mul(key_states, self.scale ** 0.5)
+ query_states = ops.mul(query_states, self.scale**0.5)
+ key_states = ops.mul(key_states, self.scale**0.5)
attn_weights = ops.matmul(query_states, key_states.swapaxes(2, 3))
@@ -859,7 +855,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.embed_tokens = value
-
def construct(
self,
input_ids: ms.Tensor = None,
@@ -907,9 +902,7 @@ def construct(
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- cache_position = ops.arange(
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]
- )
+ cache_position = ops.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
@@ -1036,11 +1029,7 @@ def _update_causal_mask(
batch_size=input_tensor.shape[0],
)
- if (
- self.config._attn_implementation == "sdpa"
- and attention_mask is not None
- and not output_attentions
- ):
+ if self.config._attn_implementation == "sdpa" and attention_mask is not None and not output_attentions:
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
@@ -1079,7 +1068,6 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.model
-
def construct(
self,
input_ids: ms.Tensor = None,
@@ -1191,7 +1179,7 @@ def prepare_inputs_for_generation(
# input_ids = input_ids[:, :cache_position.shape[0]]
if inputs_embeds is not None: # Exception 1
if 0 not in input_ids.shape:
- input_ids = input_ids[:, -cache_position.shape[0]:]
+ input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = ops.index_select(input_ids, -1, cache_position)
@@ -1245,7 +1233,6 @@ def prepare_inputs_for_generation(
return model_inputs
-
class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1262,7 +1249,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.embed_tokens = value
-
def construct(
self,
input_ids: ms.Tensor = None,
@@ -1353,7 +1339,6 @@ def construct(
)
-
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2
class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
def __init__(self, config):
@@ -1378,7 +1363,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.embed_tokens = value
-
def construct(
self,
input_ids: Optional[ms.Tensor] = None,
@@ -1429,4 +1413,4 @@ def construct(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- )
\ No newline at end of file
+ )
diff --git a/mindone/transformers/models/qwen2/tokenization_qwen2.py b/mindone/transformers/models/qwen2/tokenization_qwen2.py
index c5cff300a2..b13046fbbf 100644
--- a/mindone/transformers/models/qwen2/tokenization_qwen2.py
+++ b/mindone/transformers/models/qwen2/tokenization_qwen2.py
@@ -49,9 +49,7 @@ def bytes_to_unicode():
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
tables between utf-8 bytes and unicode strings.
"""
- bs = (
- list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
- )
+ bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
diff --git a/mindone/transformers/processing_utils.py b/mindone/transformers/processing_utils.py
index c258ae580d..adc6cc1222 100644
--- a/mindone/transformers/processing_utils.py
+++ b/mindone/transformers/processing_utils.py
@@ -36,10 +36,12 @@
"AutoImageProcessor": "ImageProcessingMixin",
}
+
class ProcessorMixin:
"""
This is a mixin used to provide saving/loading functionality for all processor classes.
"""
+
attributes = ["feature_extractor", "tokenizer"]
# Names need to be attr_class for attr in attributes
feature_extractor_class = None
@@ -89,6 +91,7 @@ def __init__(self, *args, **kwargs):
proper_class = tuple(getattr(transformers_module, n) for n in class_name if n is not None)
elif class_name == "MiniCPMVImageProcessor":
from mindone.transformers import MiniCPMVImageProcessor
+
proper_class = MiniCPMVImageProcessor
else:
proper_class = getattr(transformers_module, class_name)
@@ -157,14 +160,14 @@ def save_pretrained(self, save_directory, **kwargs):
@classmethod
def from_pretrained(
- cls,
- pretrained_model_name_or_path: Union[str, os.PathLike],
- cache_dir: Optional[Union[str, os.PathLike]] = None,
- force_download: bool = False,
- local_files_only: bool = False,
- token: Optional[Union[str, bool]] = None,
- revision: str = "main",
- **kwargs,
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ **kwargs,
):
r"""
Instantiate a processor associated with a pretrained model.
@@ -269,6 +272,7 @@ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
attribute_class = classes[0]
elif class_name == "MiniCPMVImageProcessor":
from mindone.transformers import MiniCPMVImageProcessor
+
attribute_class = MiniCPMVImageProcessor
else:
attribute_class = getattr(transformers_module, class_name)
diff --git a/mindone/transformers/utils/generic.py b/mindone/transformers/utils/generic.py
index dd5f2c6f78..acd2a77516 100644
--- a/mindone/transformers/utils/generic.py
+++ b/mindone/transformers/utils/generic.py
@@ -1,7 +1,7 @@
import inspect
-
from enum import Enum
+
def can_return_loss(model_class):
"""
Check if a given model can return loss.
@@ -33,6 +33,7 @@ def find_labels(model_class):
else:
return [p for p in signature.parameters if "label" in p]
+
class ExplicitEnum(str, Enum):
"""
Enum with more explicit error message for missing values.
@@ -44,6 +45,7 @@ def _missing_(cls, value):
f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
)
+
class TensorType(ExplicitEnum):
"""
Possible values for the `return_tensors` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for
@@ -55,4 +57,4 @@ class TensorType(ExplicitEnum):
NUMPY = "np"
JAX = "jax"
MLX = "mlx"
- MINDSPORE = "ms"
\ No newline at end of file
+ MINDSPORE = "ms"
From 6ba878217d561179c055849a09389a18c38a8034 Mon Sep 17 00:00:00 2001
From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com>
Date: Tue, 19 Nov 2024 20:41:12 +0800
Subject: [PATCH 3/9] feat(minicpm-v): Support MiniCPM-V inference pipeline
---
mindone/transformers/feature_extraction_utils.py | 1 -
mindone/transformers/image_processing_utils.py | 1 -
mindone/transformers/image_utils.py | 7 +------
.../models/minicpm_v/image_processing_minicpmv.py | 5 +----
mindone/transformers/models/minicpm_v/modeling_minicpmv.py | 1 +
.../transformers/models/minicpm_v/modeling_navit_siglip.py | 3 +--
mindone/transformers/models/qwen2/modeling_qwen2.py | 3 +--
7 files changed, 5 insertions(+), 16 deletions(-)
diff --git a/mindone/transformers/feature_extraction_utils.py b/mindone/transformers/feature_extraction_utils.py
index 06c3e8ff53..4c8be548c5 100644
--- a/mindone/transformers/feature_extraction_utils.py
+++ b/mindone/transformers/feature_extraction_utils.py
@@ -25,7 +25,6 @@
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
-
from transformers.utils import (
cached_file,
download_url,
diff --git a/mindone/transformers/image_processing_utils.py b/mindone/transformers/image_processing_utils.py
index f743b8e457..194c99bf9a 100644
--- a/mindone/transformers/image_processing_utils.py
+++ b/mindone/transformers/image_processing_utils.py
@@ -23,7 +23,6 @@
import numpy as np
import requests
-
from transformers.utils import cached_file, download_url, is_offline_mode, is_remote_url, is_vision_available, logging
from .feature_extraction_utils import BatchFeature as BaseBatchFeature
diff --git a/mindone/transformers/image_utils.py b/mindone/transformers/image_utils.py
index fab986a5a7..87858e2e60 100644
--- a/mindone/transformers/image_utils.py
+++ b/mindone/transformers/image_utils.py
@@ -1,11 +1,6 @@
-import base64
-import os
-from io import BytesIO
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
import numpy as np
-import requests
-from packaging import version
from .utils.generic import ExplicitEnum
diff --git a/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py b/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
index 1aad29b744..96d60720c1 100644
--- a/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
+++ b/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
@@ -9,18 +9,15 @@
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import (
ChannelDimension,
- ImageInput,
infer_channel_dimension_format,
- is_batched,
is_torch_tensor,
- make_list_of_images,
to_numpy_array,
valid_images,
)
from transformers.utils import TensorType, is_torch_device, is_torch_dtype, requires_backends
import mindspore as ms
-from mindspore import Parameter, Tensor, nn, ops
+from mindspore import ops
from ...image_processing_utils import BaseImageProcessor, BatchFeature
diff --git a/mindone/transformers/models/minicpm_v/modeling_minicpmv.py b/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
index 92beb0c126..52c1f67ea2 100644
--- a/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
+++ b/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
@@ -15,6 +15,7 @@
from .modeling_navit_siglip import SiglipVisionTransformer
from .processing_minicpmv import MiniCPMVProcessor
from .resampler import Resampler
+
# from .tokenization_minicpmv_fast import MiniCPMVTokenizerFast
diff --git a/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py b/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
index b6ebf327a7..2abd16cd80 100644
--- a/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
+++ b/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
@@ -22,14 +22,13 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Union
-import numpy as np
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
import mindspore as ms
# import torch.utils.checkpoint
-from mindspore import Parameter, Tensor, nn, ops
+from mindspore import nn, ops
from mindspore.ops.operations.nn_ops import FlashAttentionScore as FlashAttention
from ...activations import ACT2FN
diff --git a/mindone/transformers/models/qwen2/modeling_qwen2.py b/mindone/transformers/models/qwen2/modeling_qwen2.py
index 839a671f25..08df790475 100644
--- a/mindone/transformers/models/qwen2/modeling_qwen2.py
+++ b/mindone/transformers/models/qwen2/modeling_qwen2.py
@@ -19,14 +19,13 @@
# limitations under the License.
"""Mindspore Qwen2 model."""
-import math
from typing import List, Optional, Tuple, Union
import numpy as np
from transformers import logging
import mindspore as ms
-from mindspore import Parameter, Tensor, nn, ops
+from mindspore import Parameter, nn, ops
from mindspore.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
From 1dde75a56e8b8178b74b981f6b3ee9d7b051c373 Mon Sep 17 00:00:00 2001
From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com>
Date: Tue, 19 Nov 2024 20:54:55 +0800
Subject: [PATCH 4/9] feat(minicpm-v): Support MiniCPM-V Training pipeline
---
examples/minicpm_v/finetune/dataset.py | 647 ++++++++++++++++++
examples/minicpm_v/finetune/finetune.py | 423 ++++++++++++
examples/minicpm_v/finetune/finetune.sh | 43 ++
examples/minicpm_v/finetune/finetune_8p.sh | 46 ++
.../transformers/models/minicpm_v/__init__.py | 2 +-
.../models/minicpm_v/configuration_minicpm.py | 3 +-
.../minicpm_v/image_processing_minicpmv.py | 84 ++-
.../models/minicpm_v/modeling_minicpmv.py | 150 ++--
.../models/minicpm_v/modeling_navit_siglip.py | 70 +-
.../models/minicpm_v/processing_minicpmv.py | 103 ++-
.../models/minicpm_v/resampler.py | 357 ++++------
.../minicpm_v/tokenization_minicpmv_fast.py | 5 +-
mindone/transformers/models/qwen2/__init__.py | 1 +
.../models/qwen2/modeling_qwen2.py | 57 +-
.../models/qwen2/tokenization_qwen2.py | 4 +-
.../models/qwen2/tokenization_qwen2_fast.py | 4 +-
16 files changed, 1562 insertions(+), 437 deletions(-)
create mode 100644 examples/minicpm_v/finetune/dataset.py
create mode 100644 examples/minicpm_v/finetune/finetune.py
create mode 100644 examples/minicpm_v/finetune/finetune.sh
create mode 100644 examples/minicpm_v/finetune/finetune_8p.sh
diff --git a/examples/minicpm_v/finetune/dataset.py b/examples/minicpm_v/finetune/dataset.py
new file mode 100644
index 0000000000..fa5a3b312c
--- /dev/null
+++ b/examples/minicpm_v/finetune/dataset.py
@@ -0,0 +1,647 @@
+import copy
+import json
+import logging
+import math
+import os
+import random
+import re
+import sys
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional
+
+import numpy as np
+from datasets import load_dataset
+from PIL import Image
+from transformers import AutoTokenizer
+
+import mindspore as ms
+from mindspore import ops
+
+# from torch.nn.utils.rnn import pad_sequence
+from mindspore.dataset import Dataset
+
+mindone_lib_path = os.path.abspath(os.path.abspath("../../../"))
+sys.path.insert(0, mindone_lib_path)
+
+import logging
+
+from mindone.transformers.models.minicpm_v2_6.processing_minicpmv import MiniCPMVProcessor
+
+logger = logging.getLogger(__name__)
+
+llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
+
+class SupervisedDataset:
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(
+ self,
+ raw_data,
+ transform,
+ tokenizer,
+ slice_config,
+ llm_type="minicpm",
+ patch_size=14,
+ query_nums=64,
+ batch_vision=False,
+ max_length=2048,
+ ):
+ super(SupervisedDataset, self).__init__()
+ self.raw_data = raw_data
+ self.tokenizer = tokenizer
+ self.transform = transform
+ self.slice_config = slice_config
+ self.llm_type = llm_type
+ self.patch_size = patch_size
+ self.query_nums=query_nums
+ self.batch_vision = batch_vision
+ self.max_length = max_length
+ # self.dataset_column_names = ["input_ids", "position_ids", "labels", "attention_mask", "pixel_values", "tgt_sizes", "image_bound"]
+ # self.dataset_output_column_names = ["input_ids", "position_ids", "labels", "attention_mask", "pixel_values", "tgt_sizes", "image_bound"]
+ self.dataset_column_names = ["item"]
+
+ def __len__(self):
+ return len(self.raw_data)
+
+ def __getitem__(self, idx, retry_count=3):
+ try:
+ if isinstance(self.raw_data[idx]["image"], str):
+ images_dict = { "" : Image.open(self.raw_data[idx]["image"]).convert("RGB") }
+ elif isinstance(self.raw_data[idx]["image"], Dict):
+ ### for multi-images input, the template for every image is , such as ,
+ images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[idx]["image"].items()}
+
+ ret = preprocess(
+ images_dict,
+ self.raw_data[idx]["conversations"],
+ self.tokenizer,
+ self.transform,
+ query_nums=self.query_nums,
+ slice_config=self.slice_config,
+ llm_type=self.llm_type,
+ patch_size=self.patch_size,
+ batch_vision=self.batch_vision,
+ max_length=self.max_length
+ )
+ ret = dict(
+ input_ids=ret["input_ids"],
+ position_ids=ret["position_ids"],
+ labels=ret["target"],
+ attention_mask=np.ones_like(ret["input_ids"], dtype=np.bool_),
+ pixel_values=ret["pixel_values"],
+ tgt_sizes=ret["tgt_sizes"],
+ image_bound=ret["image_bound"],
+ )
+
+ ret = data_collator(ret, max_length = self.max_length)
+
+ except (EOFError, ValueError, OSError) as e:
+ # Log and handle EOFError and other file-related errors
+ logger.error(f"Data fetch error at index {idx}: {e}")
+
+ if retry_count > 0:
+ logger.info(f"Retrying with a different sample. {retry_count} retries left.")
+ retry_idx = random.randint(0, len(self) - 1)
+ return self.__getitem__(retry_idx, retry_count - 1)
+ else:
+ # If max retries reached, return a blank or default item
+ logger.warning("Max retries reached. Returning a blank entry.")
+ return None
+
+ # except:
+ # logger.error(f"data fetch error")
+ # # return self.__getitem__(random.randint(0, len(self)))
+ # return (ret["input_ids"], ret["position_ids"], ret["labels"], np.ones_like(ret["input_ids"], dtype=np.bool_), ret["pixel_values"], ret["tgt_sizes"], ret["image_bound"])
+ return ret
+
+def data_collator(examples, padding_value=0, max_length=2048):
+ def trim_and_pad(seq, batch_first, padding_value):
+ # return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
+ # return np.stack([s[:max_length] for s in seq])
+ return seq
+
+ input_ids = trim_and_pad(
+ examples["input_ids"],
+ batch_first=True,
+ padding_value=padding_value,
+ )
+ position_ids = trim_and_pad(
+ examples["position_ids"],
+ batch_first=True,
+ padding_value=padding_value,
+ )
+ targets = trim_and_pad(
+ examples["labels"],
+ batch_first=True,
+ padding_value=-100,
+ )
+ attention_mask = trim_and_pad(
+ examples["attention_mask"],
+ batch_first=True,
+ padding_value=padding_value,
+ )
+ pixel_values = examples["pixel_values"]
+ image_bound = examples["image_bound"]
+ tgt_sizes = examples["tgt_sizes"]
+ # return {
+ # "input_ids": input_ids,
+ # "position_ids": position_ids,
+ # "labels": targets,
+ # "attention_mask": attention_mask,
+ # "image_bound": image_bound,
+ # "tgt_sizes": tgt_sizes,
+ # "pixel_values": pixel_values,
+ # }
+ outputs = dict(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ labels=targets,
+ attention_mask=attention_mask,
+ pixel_values=pixel_values,
+ tgt_sizes=tgt_sizes,
+ image_bound=image_bound,
+ )
+
+ return outputs
+
+
+def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False, max_length=2048):
+ """
+ for single image multi-turn conversation
+ conversation: [{'role': 'user', 'content': 'Describe this image'},
+ {'role': 'assistant', 'content': 'This is a cat.'}]
+ """
+ if llm_type == "llama3":
+ input_ids, context, raw_msg = conversation_to_ids_llama3(
+ conversation, tokenizer
+ )
+ elif llm_type == "qwen2":
+ input_ids, context, raw_msg = conversation_to_ids_qwen2(
+ conversation, tokenizer
+ )
+ else:
+ input_ids, context, raw_msg = conversation_to_ids_minicpm(
+ conversation, tokenizer
+ )
+
+ ids = np.hstack(input_ids, dtype=np.int32)
+ context = np.hstack(context, dtype=np.int8)
+ if input_ids.shape[-1] > max_length:
+ ids = ids[:max_length]
+ context = context[:max_length]
+ logger.warning(f"The input length ({input_ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated")
+
+ if np.all(context):
+ logger.error("No tokens available to compute loss.")
+ raise Exception("No tokens available to compute loss.")
+
+ # build target
+ target = np.full_like(ids, -100, dtype=np.int32)
+
+ for i in range(1, len(ids)):
+ if context[i] == 0:
+ target[i - 1] = ids[i]
+ if context[i] == 1 and context[i - 1] == 0:
+ if hasattr(tokenizer, "eot_id"):
+ target[i - 1] = tokenizer.eot_id
+ else:
+ target[i - 1] = tokenizer.eos_id
+
+ # build image bound
+ if new_schema:
+ start_cond = (ids == tokenizer.im_start_id) | (ids == tokenizer.slice_start_id)
+ end_cond = (ids == tokenizer.im_end_id) | (ids == tokenizer.slice_end_id)
+ image_start_tokens = np.where(start_cond)[0]
+ image_start_tokens += 1
+ image_end_tokens = np.where(end_cond)[0]
+ else:
+ image_start_tokens = np.where(ids == tokenizer.im_start_id)[0]
+ image_start_tokens += 1
+ image_end_tokens = np.where(ids == tokenizer.im_end_id)[0]
+ if len(image_start_tokens) != len(image_end_tokens):
+ logger.error("image start token != image end tokens")
+ raise Exception("image start token != image end tokens")
+
+ if len(image_start_tokens) > 0:
+ image_bound = np.hstack(
+ [np.expand_dims(image_start_tokens, axis=-1), np.expand_dims(image_end_tokens, axis=-1)]
+ )
+ else:
+ image_bound = []
+
+ position_ids = np.arange(ids.shape[0]).astype(np.int64)
+ return {
+ "input_ids": ids,
+ "target": target,
+ "image_bound": image_bound,
+ "raw_msg": raw_msg,
+ "position_ids": position_ids
+ }
+
+
+def conversation_to_ids_minicpm(conversation, tokenizer):
+ raw_msg = ""
+ input_ids = []
+ context = []
+ for idx, msg in enumerate(conversation):
+ role = msg["role"]
+ message = msg["content"]
+ assert role in ["user", "assistant"]
+ if role == "user":
+ prefix = "<用户>"
+ else:
+ prefix = ""
+ # append eos
+ if idx == len(conversation) - 1:
+ message = message + tokenizer.eos_token
+ prefix_ids = tokenizer.encode(prefix)[1:] # remove bos
+ message_ids = tokenizer.encode(message)[1:]
+
+ input_ids.append(prefix_ids)
+ input_ids.append(message_ids)
+
+ context.append(np.ones((len(prefix_ids),), dtype=np.int8))
+ if role == "assistant":
+ context.append(np.zeros((len(message_ids),), dtype=np.int8))
+ else:
+ context.append(np.ones((len(message_ids),), dtype=np.int8))
+
+ raw_msg += prefix + message
+
+ return input_ids, context, raw_msg
+
+
+def conversation_to_ids_llama3(conversation, tokenizer):
+ raw_msg = ""
+ input_ids = []
+ context = []
+ raw_msg = tokenizer.apply_chat_template(
+ conversation, tokenize=False, add_generation_prompt=False, chat_template=llama3_chat_template,
+ )
+ input_ids = tokenizer.apply_chat_template(
+ conversation, tokenize=True, add_generation_prompt=False, chat_template=llama3_chat_template,
+ )
+ input_ids = np.array(input_ids)
+
+ start_header_idxs = np.where(
+ input_ids == tokenizer.convert_tokens_to_ids("<|start_header_id|>")
+ )[0]
+ assistant_idxs = np.where(
+ input_ids == tokenizer.convert_tokens_to_ids("assistant")
+ )[0]
+ end_header_idxs = np.where(
+ input_ids == tokenizer.convert_tokens_to_ids("<|end_header_id|>")
+ )[0]
+ eot_idxs = np.where(
+ input_ids == tokenizer.convert_tokens_to_ids("<|eot_id|>"))[0]
+
+ context = np.ones_like(input_ids, dtype=np.int8)
+
+ for assistant_idx in assistant_idxs:
+ if assistant_idx in set((start_header_idxs + end_header_idxs) / 2):
+ st = assistant_idx + 3 # assistant<|end_header_id|>\n\n
+ for eot_idx in eot_idxs:
+ if eot_idx > st:
+ context[st: eot_idx + 1] = 0
+ break
+
+ input_ids = np.hstack(input_ids)
+ context = np.hstack(context)
+
+ return input_ids, context, raw_msg
+
+
+def conversation_to_ids_qwen2(conversation, tokenizer):
+ raw_msg = ""
+ chat = []
+ context = []
+ for idx, msg in enumerate(conversation):
+ role = msg["role"]
+ message = msg["content"]
+ assert role in ["user", "assistant"]
+ if role == "user":
+ prefix = "user"
+ else:
+ prefix = "assistant"
+ chat.append({"role":prefix, "content":message})
+ raw_msg += prefix + message
+ assert set([i['role'] for i in chat]) & set(['assistant'])
+
+ ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
+ input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False)
+ input_ids = np.array(input_ids)
+
+ start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0]
+ assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0]
+ end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0]
+
+ context = np.ones_like(input_ids, dtype=np.int8)
+
+ for assistant_idx in assistant_idxs:
+ if assistant_idx-1 in set(start_idxs):
+ st = assistant_idx + 1
+ for end_idx in end_idxs:
+ if end_idx > st:
+ context[st: end_idx + 1] = 0
+ break
+
+ input_ids = np.hstack(input_ids)
+ context = np.hstack(context)
+ return input_ids, context, raw_msg
+
+def trans_fn(x):
+ x = np.asarray(x).transpose((2, 0, 1))
+ return (x-0.5*255)/(0.5*255)
+
+def preprocess(
+ images_dict,
+ conversations,
+ tokenizer,
+ transform,
+ query_nums=64,
+ slice_config=None,
+ llm_type=None,
+ patch_size=14,
+ batch_vision=False,
+ max_length=2048,
+):
+ """
+ single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
+ """
+ conversations = copy.deepcopy(conversations)
+ assert len(conversations) > 1, "conversations length must large than 2"
+ assert conversations[0]["role"] == "user", "the first role must be user"
+
+ if slice_config is not None:
+ assert isinstance(slice_config, Dict)
+ assert "patch_size" in slice_config
+ assert "max_slice_nums" in slice_config
+ assert "scale_resolution" in slice_config
+ default_image_placeholder = (
+ tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
+ )
+ new_schema = False
+ use_image_id = False
+ if llm_type=='qwen2':
+ new_schema = True
+ use_image_id = True
+ image_placeholder_dict = {}
+ images = []
+ image_id_cnt = 0
+ for img_name, image in images_dict.items():
+ if slice_config:
+ source_image, patches, best_grid = slice_image(
+ image,
+ slice_config["max_slice_nums"],
+ slice_config["scale_resolution"],
+ slice_config["patch_size"],
+ )
+ images.append(source_image)
+ image_placeholder = default_image_placeholder
+ if len(patches) > 0:
+ for i in range(len(patches)):
+ for j in range(len(patches[0])):
+ images.append(patches[i][j])
+ if use_image_id:
+ image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
+ image_id_cnt += 1
+ image_placeholder += get_grid_placeholder(
+ tokenizer, best_grid, query_nums, new_schema = new_schema)
+ image_placeholder_dict[img_name] = image_placeholder
+ else:
+ images.append(image)
+ if use_image_id:
+ image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
+ image_id_cnt += 1
+ else:
+ image_placeholder = default_image_placeholder
+ image_placeholder_dict[img_name] = image_placeholder
+
+ images = [trans_fn(i) for i in images]
+
+ if len(images_dict) == 1 and "" in images_dict:
+ if "" in conversations[0]["content"]:
+ conversations[0]["content"] = conversations[0]["content"].replace(
+ "", image_placeholder
+ )
+ else:
+ conversations[0]["content"] = (
+ image_placeholder + "\n" + conversations[0]["content"]
+ )
+ input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
+ else:
+ pattern = r''
+ new_conversations = []
+ for conversation in conversations:
+ content = conversation['content']
+ parts = re.split(f'({pattern})', content)
+ for i, part in enumerate(parts):
+ if not part.strip():
+ continue
+ if re.match(pattern, part):
+ if part in image_placeholder_dict:
+ parts[i] = image_placeholder_dict[part]
+ else:
+ raise Exception(f"not found {part} in image dict")
+ conversation['content'] = '\n'.join(parts)
+ new_conversations.append(conversation)
+ conversations = new_conversations
+
+ input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
+
+ if batch_vision:
+ tgt_sizes = []
+ reshape_images = []
+ for image in images:
+ H, W = image.shape[1:]
+ reshape_image = reshape_by_patch(image, patch_size)
+ reshape_images.append(reshape_image)
+ tgt_sizes.append([H // patch_size, W // patch_size])
+ if tgt_sizes:
+ tgt_sizes = np.array(tgt_sizes).astype(np.int32)
+
+ input_dict["pixel_values"] = reshape_images
+ input_dict["tgt_sizes"] = tgt_sizes
+
+ else:
+ input_dict["pixel_values"] = images
+ input_dict["tgt_sizes"] = []
+
+ return input_dict
+
+
+def slice_image(
+ image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
+):
+ original_size = image.size
+ original_width, original_height = original_size
+ log_ratio = math.log(original_width / original_height)
+ ratio = original_width * original_height / \
+ (scale_resolution * scale_resolution)
+ multiple = min(math.ceil(ratio), max_slice_nums)
+
+ source_image = None
+ best_grid = None
+ patches = []
+
+ if multiple <= 1 or never_split:
+ # dont need to slice, upsample
+ best_size = find_best_resize(
+ original_size, scale_resolution, patch_size, allow_upscale=True
+ )
+ source_image = image.resize(best_size, Image.Resampling.BICUBIC)
+ else:
+ candidate_split_grids_nums = []
+ for i in [multiple - 1, multiple, multiple + 1]:
+ if i == 1 or i > max_slice_nums:
+ continue
+ candidate_split_grids_nums.append(i)
+
+ # source image, down-sampling and ensure divided by patch_size
+ best_resize = find_best_resize(
+ original_size, scale_resolution, patch_size)
+ source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
+ candidate_grids = []
+
+ # find best grid
+ for split_grids_nums in candidate_split_grids_nums:
+ m = 1
+ while m <= split_grids_nums:
+ if split_grids_nums % m == 0:
+ candidate_grids.append([m, split_grids_nums // m])
+ m += 1
+
+ best_grid = [1, 1]
+ min_error = float("inf")
+ for grid in candidate_grids:
+ error = abs(log_ratio - math.log(grid[0] / grid[1]))
+ if error < min_error:
+ best_grid = grid
+ min_error = error
+
+ refine_size = get_refine_size(
+ original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
+ )
+
+ refine_image = image.resize(refine_size, Image.Resampling.BICUBIC)
+ patches = split_to_patches(refine_image, best_grid)
+
+ return source_image, patches, best_grid
+
+
+def ensure_divide(length, patch_size):
+ return max(round(length / patch_size) * patch_size, patch_size)
+
+
+def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False):
+ width, height = original_size
+ if (width * height > scale_resolution * scale_resolution) or allow_upscale:
+ r = width / height
+ height = int(scale_resolution / math.sqrt(r))
+ width = int(height * r)
+ best_width = ensure_divide(width, patch_size)
+ best_height = ensure_divide(height, patch_size)
+ return (best_width, best_height)
+
+
+def get_refine_size(
+ original_size, grid, scale_resolution, patch_size, allow_upscale=False
+):
+ width, height = original_size
+ grid_x, grid_y = grid
+
+ refine_width = ensure_divide(width, grid_x)
+ refine_height = ensure_divide(height, grid_y)
+
+ grid_width = refine_width / grid_x
+ grid_height = refine_height / grid_y
+
+ best_grid_size = find_best_resize(
+ (grid_width, grid_height),
+ scale_resolution,
+ patch_size,
+ allow_upscale=allow_upscale,
+ )
+
+ refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
+
+ return refine_size
+
+
+def split_to_patches(image, grid):
+ patches = []
+ width, height = image.size
+ grid_x = int(width / grid[0])
+ grid_y = int(height / grid[1])
+
+ for i in range(0, height, grid_y):
+ images = []
+ for j in range(0, width, grid_x):
+ box = (j, i, j + grid_x, i + grid_y)
+ patch = image.crop(box)
+ images.append(patch)
+ patches.append(images)
+
+ return patches
+
+
+def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False):
+ if new_schema:
+ image_placeholder = (
+ tokenizer.slice_start + tokenizer.unk_token * query_num + tokenizer.slice_end
+ )
+ else:
+ image_placeholder = (
+ tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
+ )
+
+ cols = grid[0]
+ rows = grid[1]
+ slices = []
+ for i in range(rows):
+ lines = []
+ for j in range(cols):
+ lines.append(image_placeholder)
+ slices.append("".join(lines))
+ if new_schema:
+ slice_placeholder = '\n'.join(slices)
+ else:
+ slice_placeholder = tokenizer.slice_start + \
+ "\n".join(slices) + tokenizer.slice_end
+ return slice_placeholder
+
+
+def reshape_by_patch(image_tensor, patch_size):
+ """
+ :param image_tensor: shape [3, H, W]
+ :param patch_size:
+ :return: [3, patch_size, HW/patch_size]
+ """
+ # image = ms.Tensor(image_tensor)
+ #
+ # c = image.shape[0]
+ # h = image.shape[1]
+ # w = image.shape[2]
+ # image = image.reshape(1, c, h, w)
+ #
+ # patches = ops.unfold(
+ # image,
+ # (patch_size, patch_size),
+ # stride=(patch_size, patch_size)
+ # )
+
+ c = image_tensor.shape[0]
+ h = image_tensor.shape[1]
+ w = image_tensor.shape[2]
+
+ v_block_num = h // patch_size
+ h_block_num = w // patch_size
+
+ patches = image_tensor.reshape(c, v_block_num, patch_size, h_block_num, patch_size)
+ patches = np.transpose(patches, (0, 2, 4, 1, 3))
+ patches = patches.reshape(c*patch_size*patch_size, -1)
+
+ patches = patches.reshape(image_tensor.shape[0], patch_size, patch_size, -1)
+ patches = patches.transpose((0, 1, 3, 2)).reshape(
+ image_tensor.shape[0], patch_size, -1)
+ return patches
diff --git a/examples/minicpm_v/finetune/finetune.py b/examples/minicpm_v/finetune/finetune.py
new file mode 100644
index 0000000000..50140ca723
--- /dev/null
+++ b/examples/minicpm_v/finetune/finetune.py
@@ -0,0 +1,423 @@
+import glob
+import json
+import logging
+import os
+import sys
+from dataclasses import dataclass, field
+from functools import partial
+from types import MethodType
+from typing import Dict, List, Literal, Optional, Tuple, Union
+
+import numpy as np
+
+import mindspore as ms
+from mindspore import Parameter, Tensor, context, dataset, nn, ops
+from mindspore.communication.management import get_group_size, get_rank, init
+from mindspore.train.amp import AMP_BLACK_LIST, _auto_black_list
+
+# init()
+# rank, rank_size, parallel_mode = get_rank(), get_group_size(), context.ParallelMode.DATA_PARALLEL
+# context.set_auto_parallel_context(
+# device_num=rank_size, parallel_mode=parallel_mode, gradients_mean=True
+# )
+
+rank, rank_size = 0, 1
+
+ms.set_context(mode=ms.context.PYNATIVE_MODE, pynative_synchronize=True, mempool_block_size="59GB", max_device_memory="59GB")
+
+import transformers
+from transformers import HfArgumentParser
+
+from mindspore.dataset import transforms, vision
+
+# from accelerate.utils import DistributedType
+
+mindone_lib_path = os.path.abspath(os.path.abspath("../../../"))
+sys.path.insert(0, mindone_lib_path)
+
+from dataset import SupervisedDataset
+from mindone.transformers.trainer import Trainer
+from transformers import AutoTokenizer
+from mindone.transformers.training_args import TrainingArguments
+
+from mindone.transformers.models.minicpm_v2_6 import MiniCPMV_v2_6
+from mindone.transformers.mindspore_adapter import MindSporeArguments
+
+# from transformers.integrations import deepspeed
+
+
+# from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+
+# ms.set_context(mode=ms.context.PYNATIVE_MODE, pynative_synchronize=True)
+# ms.set_context(mode=ms.context.PYNATIVE_MODE)
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default="openbmb/MiniCPM-V-2")
+
+
+@dataclass
+class DataArguments:
+ data_path: str = field(
+ default=None, metadata={"help": "Path to the training data."}
+ )
+ eval_data_path: str = field(
+ default=None, metadata={"help": "Path to the evaluation data."}
+ )
+
+# @dataclass
+# class TrainingArguments(TrainingArguments):
+# cache_dir: Optional[str] = field(default=None)
+# optim: str = field(default="adamw_mindspore")
+# model_max_length: int = field(
+# default=2048,
+# metadata={
+# "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+# },
+# )
+# tune_vision: Optional[bool] = field(default=True)
+# tune_llm: Optional[bool] = field(default=True)
+# llm_type: str = field(default="minicpm")
+# use_lora: Optional[bool] = field(default=False)
+# max_slice_nums: Optional[int] = field(default=9)
+# distributed: Optional[bool] = field(default=False)
+# amp_level: Optional[str] = field(default="O0")
+
+
+@dataclass
+class LoraArguments:
+ lora_r: int = 64
+ lora_alpha: int = 64
+ lora_dropout: float = 0.05
+ lora_target_modules: str = r"llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj)"
+ lora_weight_path: str = ""
+ lora_bias: str = "none"
+ q_lora: bool = False
+ lora_modules_to_save: str = ""
+ lora_layer_replication: Optional[List[Tuple[int, int]]] = None
+ lora_layers_to_transform: Optional[List[int]] = None
+ lora_layers_pattern: Optional[str] = None
+
+@dataclass
+class MyArguments(MindSporeArguments, TrainingArguments):
+ enable_flash_attention: bool = field(default=False)
+ gradient_checkpointing: bool = field(default=False)
+ is_distribute: bool = field(default=False)
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_mindspore")
+ model_max_length: int = field(
+ default=2048,
+ metadata={
+ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+ },
+ )
+ tune_vision: Optional[bool] = field(default=True)
+ tune_llm: Optional[bool] = field(default=True)
+ llm_type: str = field(default="minicpm")
+ use_lora: Optional[bool] = field(default=False)
+ max_slice_nums: Optional[int] = field(default=9)
+ distributed: Optional[bool] = field(default=False)
+ amp_level: Optional[str] = field(default="O0")
+
+local_rank = None
+def rank0_print(*args):
+ if local_rank == 0:
+ print(*args)
+
+
+def safe_save_model_for_hf_trainer(trainer, output_dir: str, bias="none"):
+ """Collects the state dict and dump to disk."""
+ if trainer.args.should_save and trainer.args.local_rank == 0:
+ trainer.save_model(output_dir,)
+
+# class ModifiedMapFunction(BaseMapFuction):
+# def __call__(self, input_ids, position_ids, labels, attention_mask):
+# return trim_and_pad(input_ids), trim_and_pad(position_ids), trim_and_pad(labels), trim_and_pad(attention_mask)
+
+
+def make_supervised_data_module(
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args,
+ transform,
+ data_collator=None,
+ llm_type="minicpm",
+ slice_config=None,
+ patch_size=14,
+ query_nums=64,
+ batch_vision=False,
+ max_length=2048,
+) -> Dict:
+ """Make dataset and collator for supervised fine-tuning."""
+
+ dataset_cls = SupervisedDataset
+
+ rank0_print("Loading data...")
+
+ train_json = json.load(open(data_args.data_path, "r"))
+ train_dataset = dataset_cls(
+ train_json,
+ transform,
+ tokenizer,
+ slice_config=slice_config,
+ llm_type=llm_type,
+ patch_size=patch_size,
+ query_nums=query_nums,
+ batch_vision=batch_vision,
+ max_length=max_length,
+ )
+
+ # train_ds = dataset.GeneratorDataset(
+ # train_dataset,
+ # column_names=train_dataset.dataset_column_names,
+ # num_parallel_workers=2,
+ # shuffle=True,
+ # python_multiprocessing=False,
+ # num_shards=rank_size,
+ # shard_id=rank
+ # )
+
+ if data_args.eval_data_path:
+ eval_json = json.load(open(data_args.eval_data_path, "r"))
+ eval_dataset = dataset_cls(
+ eval_json,
+ transform,
+ tokenizer,
+ slice_config=slice_config,
+ llm_type=llm_type,
+ patch_size=patch_size,
+ query_nums=query_nums,
+ batch_vision=batch_vision,
+ max_length=max_length,
+ )
+
+ # eval_ds = dataset.GeneratorDataset(
+ # eval_dataset,
+ # column_names=eval_dataset.dataset_column_names,
+ # num_parallel_workers=8,
+ # shuffle=False,
+ # python_multiprocessing=False,
+ # )
+ else:
+ eval_dataset = None
+
+ # def trim_and_pad(seq):
+ # # return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
+ # max_length = 2048
+ # return np.stack([s[:max_length] for s in seq])
+ #
+ # class ModifiedMapFunction(BaseMapFuction):
+ # def __call__(self, input_ids, position_ids, labels, attention_mask):
+ # return trim_and_pad(input_ids), trim_and_pad(position_ids), trim_and_pad(labels), trim_and_pad(attention_mask)
+
+ return dict(
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ )
+
+
+# def build_transform():
+# IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
+# IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD
+# return transforms.Compose(
+# [
+# vision.ToTensor(),
+# vision.Normalize(
+# mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, is_hwc=False
+# ),
+# ]
+# )
+
+def build_transform():
+ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
+ IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD
+ return transforms.Compose(
+ [
+ vision.Normalize(
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, is_hwc=False
+ ),
+ ]
+ )
+
+def get_parameter_number(model):
+ trainable_params, all_param = 0, 0
+ # for param in model.parameters():
+ # num_params = param.numel()
+ # # if using DS Zero 3 and the weights are initialized empty
+ # if num_params == 0 and hasattr(param, "ds_numel"):
+ # num_params = param.ds_numel
+ #
+ # all_param += num_params
+ # if param.requires_grad:
+ # trainable_params += num_params
+ for param in model.trainable_params():
+ num_params = np.prod(param.shape)
+ trainable_params += num_params
+
+ return {'Trainable params': trainable_params}
+
+
+local_rank = 0
+
+
+def train():
+ global local_rank
+ parser = HfArgumentParser(
+ (ModelArguments, DataArguments, MyArguments, LoraArguments)
+ )
+
+ (
+ model_args,
+ data_args,
+ training_args,
+ lora_args,
+ ) = parser.parse_args_into_dataclasses()
+
+ # if getattr(training_args, "deepspeed", None) :
+ # training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
+
+ compute_dtype = (
+ ms.float16
+ if training_args.fp16
+ else (ms.bfloat16 if training_args.bf16 else ms.float32)
+ )
+
+ # if training_args.distributed:
+ # init()
+ # data_args.rank, data_args.rank_size, parallel_mode = get_rank(), get_group_size(), context.ParallelMode.DATA_PARALLEL
+ # context.set_auto_parallel_context(
+ # device_num=data_args.rank_size, parallel_mode=parallel_mode, gradients_mean=True
+ # )
+ # else:
+ # data_args.rank, data_args.rank_size, parallel_mode = 0, 1, None
+
+ local_rank = training_args.local_rank
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
+ ddp = world_size != 1
+ device_map = None
+ if lora_args.q_lora:
+ device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
+ if len(training_args.fsdp) > 0:
+ logging.warning(
+ "FSDP or ZeRO3 are not incompatible with QLoRA."
+ )
+
+ model = MiniCPMV_v2_6.from_pretrained(
+ model_args.model_name_or_path,
+ trust_remote_code=True,
+ mindspore_dtype=compute_dtype,
+ )
+
+ if training_args.amp_level == "O2":
+ _auto_black_list(
+ model,
+ AMP_BLACK_LIST + [nn.GroupNorm, nn.SiLU],
+ ms.float16,
+ )
+ elif training_args.amp_level == "O3":
+ model.to_float(ms.float16)
+
+ # if training_args.distributed:
+ # # set grad reducer
+ # mean = ms.context.get_auto_parallel_context("gradients_mean")
+ # degree = ms.context.get_auto_parallel_context("device_num")
+ # grad_reducer = nn.DistributedGradReducer(model.trainable_params(), mean, degree)
+ # else:
+ # grad_reducer = None
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path, trust_remote_code=True
+ )
+
+ if not training_args.tune_vision:
+ # model.vpm.set_train(False)
+ for param in model.vpm.trainable_params():
+ param.requires_grad = False
+ if not training_args.tune_llm:
+ # model.llm.set_train(False)
+ for param in model.llm.trainable_params():
+ param.requires_grad = False
+
+ if training_args.use_lora:
+ if training_args.use_lora and training_args.tune_llm:
+ raise ValueError("The model cannot simultaneously adjust LLM parameters and apply LoRA.")
+
+ rank0_print("Currently using LoRA for fine-tuning the MiniCPM-V model.")
+ for name, param in model.llm.named_parameters():
+ param.requires_grad = False
+ modules_to_save = ['embed_tokens','resampler']
+ if training_args.tune_vision:
+ modules_to_save.append('vpm')
+ lora_config = LoraConfig(
+ r=lora_args.lora_r,
+ lora_alpha=lora_args.lora_alpha,
+ target_modules=lora_args.lora_target_modules,
+ lora_dropout=lora_args.lora_dropout,
+ bias=lora_args.lora_bias,
+ layers_to_transform=lora_args.lora_layers_to_transform,
+ modules_to_save=modules_to_save,
+ )
+ if not hasattr(model, 'get_input_embeddings'):
+ def get_input_embeddings(self):
+ return self.llm.get_input_embeddings()
+ model.get_input_embeddings = MethodType(get_input_embeddings, model)
+ if lora_args.q_lora:
+ model = prepare_model_for_kbit_training(
+ model, use_gradient_checkpointing=training_args.gradient_checkpointing
+ )
+ model = get_peft_model(model, lora_config)
+ if training_args.gradient_checkpointing:
+ model.enable_input_require_grads()
+
+ rank0_print(get_parameter_number(model))
+
+ llm_type = training_args.llm_type
+
+ rank0_print(f'llm_type={llm_type}')
+
+
+ # Load data
+ if hasattr(model.config, "slice_config"):
+ model.config.slice_config.max_slice_nums = training_args.max_slice_nums
+ slice_config = model.config.slice_config.to_dict()
+ else:
+ model.config.max_slice_nums = training_args.max_slice_nums
+ slice_config = model.config.to_dict()
+
+ if hasattr(model.config, "batch_vision_input"):
+ batch_vision = model.config.batch_vision_input
+ else:
+ batch_vision = False
+
+ transform_func = build_transform()
+ data_module = make_supervised_data_module(
+ tokenizer=tokenizer,
+ data_args=data_args,
+ transform=transform_func,
+ data_collator=None,
+ slice_config=slice_config,
+ llm_type=llm_type,
+ patch_size=model.config.patch_size,
+ query_nums=model.config.query_num,
+ batch_vision=batch_vision,
+ max_length=training_args.model_max_length,
+ )
+
+ training_args.gradient_checkpointing_kwargs={"use_reentrant":False}
+ trainer = Trainer(
+ model=model,
+ tokenizer=tokenizer,
+ args=training_args,
+ **data_module,
+ )
+
+ trainer.train()
+ # trainer.save_state()
+
+ safe_save_model_for_hf_trainer(
+ trainer=trainer,
+ output_dir=training_args.output_dir,
+ bias=lora_args.lora_bias)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/examples/minicpm_v/finetune/finetune.sh b/examples/minicpm_v/finetune/finetune.sh
new file mode 100644
index 0000000000..63c6acac10
--- /dev/null
+++ b/examples/minicpm_v/finetune/finetune.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+
+MODEL="openbmb/MiniCPM-V-2_6"
+# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5
+# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
+# See the section for finetuning in README for more information.
+DATA="/data3/wcr/mindone/examples/minicpm/finetune/finetune.json"
+#EVAL_DATA="path/to/test_data"
+LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
+MODEL_MAX_Length=2048 # if conduct multi-images sft, please set MODEL_MAX_Length=4096
+
+python finetune.py \
+ --model_name_or_path $MODEL \
+ --llm_type $LLM_TYPE \
+ --data_path $DATA \
+ --remove_unused_columns false \
+ --label_names "labels" \
+ --prediction_loss_only false \
+ --bf16 false \
+ --bf16_full_eval false \
+ --fp16 false \
+ --fp16_full_eval false \
+ --do_train \
+ --tune_vision true \
+ --tune_llm false \
+ --model_max_length $MODEL_MAX_Length \
+ --max_slice_nums 9 \
+ --max_steps 10000 \
+ --output_dir output/output_minicpmv26 \
+ --logging_dir output/output_minicpmv26 \
+ --logging_strategy "steps" \
+ --per_device_train_batch_size 1 \
+ --gradient_accumulation_steps 1 \
+ --save_strategy "steps" \
+ --save_steps 1000 \
+ --save_total_limit 10 \
+ --learning_rate 1e-6 \
+ --weight_decay 0.1 \
+ --adam_beta2 0.95 \
+ --warmup_ratio 0.01 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ > pynative_logs/train_vision.log 2>&1 &
diff --git a/examples/minicpm_v/finetune/finetune_8p.sh b/examples/minicpm_v/finetune/finetune_8p.sh
new file mode 100644
index 0000000000..553133e94c
--- /dev/null
+++ b/examples/minicpm_v/finetune/finetune_8p.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
+
+MODEL="openbmb/MiniCPM-V-2_6"
+# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5
+# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
+# See the section for finetuning in README for more information.
+DATA="/data3/wcr/mindone/examples/minicpm/finetune/finetune.json"
+#EVAL_DATA="path/to/test_data"
+LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
+MODEL_MAX_Length=2048 # if conduct multi-images sft, please set MODEL_MAX_Length=4096
+
+export ASCEND_RT_VISIBLE_DEVICES=4,5,6,7
+
+msrun --worker_num=4 --local_worker_num=4 --master_port=8118 --log_dir=pynative_logs --join=True --cluster_time_out=300 finetune.py \
+ --model_name_or_path $MODEL \
+ --llm_type $LLM_TYPE \
+ --data_path $DATA \
+ --remove_unused_columns false \
+ --label_names "labels" \
+ --prediction_loss_only false \
+ --bf16 false \
+ --bf16_full_eval false \
+ --fp16 false \
+ --fp16_full_eval false \
+ --do_train \
+ --tune_vision true \
+ --tune_llm false \
+ --model_max_length $MODEL_MAX_Length \
+ --max_slice_nums 9 \
+ --max_steps 10000 \
+ --output_dir output/output_minicpmv26 \
+ --logging_dir output/output_minicpmv26 \
+ --logging_strategy "steps" \
+ --per_device_train_batch_size 1 \
+ --gradient_accumulation_steps 1 \
+ --save_strategy "steps" \
+ --save_steps 1000 \
+ --save_total_limit 10 \
+ --learning_rate 1e-6 \
+ --weight_decay 0.1 \
+ --adam_beta2 0.95 \
+ --warmup_ratio 0.01 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --distributed true \
+ > pynative_logs/train_vision.log 2>&1 &
diff --git a/mindone/transformers/models/minicpm_v/__init__.py b/mindone/transformers/models/minicpm_v/__init__.py
index 3d30bf0084..d9273d2cf8 100644
--- a/mindone/transformers/models/minicpm_v/__init__.py
+++ b/mindone/transformers/models/minicpm_v/__init__.py
@@ -1,2 +1,2 @@
-from .image_processing_minicpmv import MiniCPMVImageProcessor
from .modeling_minicpmv import MiniCPMV_v2_6
+from .image_processing_minicpmv import MiniCPMVImageProcessor
diff --git a/mindone/transformers/models/minicpm_v/configuration_minicpm.py b/mindone/transformers/models/minicpm_v/configuration_minicpm.py
index db1f383fc0..063cfee91b 100644
--- a/mindone/transformers/models/minicpm_v/configuration_minicpm.py
+++ b/mindone/transformers/models/minicpm_v/configuration_minicpm.py
@@ -45,6 +45,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
return cls.from_dict(config_dict, **kwargs)
+
class MiniCPMVConfig(Qwen2Config):
model_type = "minicpmv"
keys_to_ignore_at_inference = ["past_key_values"]
@@ -57,7 +58,7 @@ class MiniCPMVConfig(Qwen2Config):
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
- "attn_implementation": "flash_attention",
+ "attn_implementation": "flash_attention"
}
def __init__(
diff --git a/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py b/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
index 96d60720c1..7626964f1b 100644
--- a/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
+++ b/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
@@ -5,21 +5,23 @@
import PIL
import PIL.Image
import PIL.ImageSequence
+from ...image_processing_utils import BaseImageProcessor, BatchFeature
from PIL import Image
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import (
ChannelDimension,
+ ImageInput,
infer_channel_dimension_format,
+ is_batched,
is_torch_tensor,
+ make_list_of_images,
to_numpy_array,
valid_images,
)
from transformers.utils import TensorType, is_torch_device, is_torch_dtype, requires_backends
import mindspore as ms
-from mindspore import ops
-
-from ...image_processing_utils import BaseImageProcessor, BatchFeature
+from mindspore import Parameter, Tensor, nn, ops
def recursive_converter(converter, value):
@@ -102,7 +104,12 @@ def cast_tensor(v):
class MiniCPMVImageProcessor(BaseImageProcessor):
model_input_names = ["pixel_values"]
- def __init__(self, max_slice_nums=9, scale_resolution=448, patch_size=14, **kwargs):
+ def __init__(
+ self,
+ max_slice_nums=9,
+ scale_resolution=448,
+ patch_size=14,
+ **kwargs):
super().__init__(**kwargs)
self.max_slice_nums = max_slice_nums
self.scale_resolution = scale_resolution
@@ -124,9 +131,14 @@ def __init__(self, max_slice_nums=9, scale_resolution=448, patch_size=14, **kwar
def ensure_divide(self, length, patch_size):
return max(round(length / patch_size) * patch_size, patch_size)
- def find_best_resize(self, original_size, scale_resolution, patch_size, allow_upscale=False):
+ def find_best_resize(self,
+ original_size,
+ scale_resolution,
+ patch_size,
+ allow_upscale=False):
width, height = original_size
- if (width * height > scale_resolution * scale_resolution) or allow_upscale:
+ if (width * height >
+ scale_resolution * scale_resolution) or allow_upscale:
r = width / height
height = int(scale_resolution / math.sqrt(r))
width = int(height * r)
@@ -134,7 +146,12 @@ def find_best_resize(self, original_size, scale_resolution, patch_size, allow_up
best_height = self.ensure_divide(height, patch_size)
return (best_width, best_height)
- def get_refine_size(self, original_size, grid, scale_resolution, patch_size, allow_upscale=False):
+ def get_refine_size(self,
+ original_size,
+ grid,
+ scale_resolution,
+ patch_size,
+ allow_upscale=False):
width, height = original_size
grid_x, grid_y = grid
@@ -144,9 +161,10 @@ def get_refine_size(self, original_size, grid, scale_resolution, patch_size, all
grid_width = refine_width / grid_x
grid_height = refine_height / grid_y
- best_grid_size = self.find_best_resize(
- (grid_width, grid_height), scale_resolution, patch_size, allow_upscale=allow_upscale
- )
+ best_grid_size = self.find_best_resize((grid_width, grid_height),
+ scale_resolution,
+ patch_size,
+ allow_upscale=allow_upscale)
refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
return refine_size
@@ -164,7 +182,9 @@ def split_to_patches(self, image, grid):
patches.append(images)
return patches
- def slice_image(self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False):
+ def slice_image(
+ self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
+ ):
original_size = image.size
source_image = None
best_grid = self.get_sliced_grid(original_size, max_slice_nums, never_split)
@@ -172,7 +192,9 @@ def slice_image(self, image, max_slice_nums=9, scale_resolution=448, patch_size=
if best_grid is None:
# dont need to slice, upsample
- best_size = self.find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True)
+ best_size = self.find_best_resize(
+ original_size, scale_resolution, patch_size, allow_upscale=True
+ )
source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC)
else:
# source image, down-sampling and ensure divided by patch_size
@@ -190,7 +212,9 @@ def get_grid_placeholder(self, grid):
if grid is None:
return ""
slice_image_placeholder = (
- self.slice_start_token + self.unk_token * self.image_feature_size + self.slice_end_token
+ self.slice_start_token
+ + self.unk_token * self.image_feature_size
+ + self.slice_end_token
)
cols = grid[0]
@@ -217,7 +241,10 @@ def get_sliced_images(self, image, max_slice_nums=None):
max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
assert max_slice_nums > 0
source_image, patches, sliced_grid = self.slice_image(
- image, max_slice_nums, self.scale_resolution, self.patch_size # default: 9 # default: 448 # default: 14
+ image,
+ max_slice_nums, # default: 9
+ self.scale_resolution, # default: 448
+ self.patch_size # default: 14
)
slice_images.append(source_image)
@@ -263,7 +290,11 @@ def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=No
assert max_slice_nums > 0
grid = self.get_sliced_grid(image_size=image_size, max_slice_nums=max_slice_nums)
- image_placeholder = self.im_start_token + self.unk_token * self.image_feature_size + self.im_end_token
+ image_placeholder = (
+ self.im_start_token
+ + self.unk_token * self.image_feature_size
+ + self.im_end_token
+ )
use_image_id = self.use_image_id if use_image_id is None else bool(use_image_id)
if use_image_id:
final_placeholder = self.get_image_id_placeholder(image_idx) + image_placeholder
@@ -318,7 +349,11 @@ def reshape_by_patch(self, image):
w = image.shape[2]
image = image.reshape(1, c, h, w)
- patches = ops.unfold(image, (patch_size, patch_size), stride=(patch_size, patch_size))
+ patches = ops.unfold(
+ image,
+ (patch_size, patch_size),
+ stride=(patch_size, patch_size)
+ )
image = image.squeeze(axis=0)
@@ -327,12 +362,12 @@ def reshape_by_patch(self, image):
return patches.numpy()
def preprocess(
- self,
- images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
- do_pad: Optional[bool] = True, # TODO: add pad for MiniCPM-Llama3-V-2_5
- max_slice_nums: int = None,
- return_tensors: Optional[Union[str, TensorType]] = None,
- **kwargs,
+ self,
+ images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
+ do_pad: Optional[bool] = True, # TODO: add pad for MiniCPM-Llama3-V-2_5
+ max_slice_nums: int = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs
) -> MiniCPMVBatchFeature:
if isinstance(images, Image.Image):
images_list = [[images]]
@@ -377,8 +412,7 @@ def preprocess(
for slice_image in image_patches:
new_images.append(self.reshape_by_patch(slice_image))
tgt_sizes.append(
- np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size))
- )
+ np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size)))
if tgt_sizes:
tgt_sizes = np.vstack(tgt_sizes)
@@ -388,7 +422,7 @@ def preprocess(
tgt_sizes_list.append(tgt_sizes)
return MiniCPMVBatchFeature(
data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list},
- tensor_type=return_tensors,
+ tensor_type=return_tensors
)
diff --git a/mindone/transformers/models/minicpm_v/modeling_minicpmv.py b/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
index 52c1f67ea2..003333ddcd 100644
--- a/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
+++ b/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
@@ -2,22 +2,23 @@
import math
from copy import deepcopy
from threading import Thread
+from typing import List, Optional
-from PIL import Image
from transformers import TextIteratorStreamer
+from PIL import Image
import mindspore as ms
-from mindspore import Parameter, Tensor, _no_grad, nn, ops
+from mindspore import Parameter, Tensor, nn, ops
from ..qwen2 import Qwen2ForCausalLM, Qwen2PreTrainedModel
from .configuration_minicpm import MiniCPMVConfig
-from .image_processing_minicpmv import MiniCPMVImageProcessor
from .modeling_navit_siglip import SiglipVisionTransformer
from .processing_minicpmv import MiniCPMVProcessor
+from .image_processing_minicpmv import MiniCPMVImageProcessor
from .resampler import Resampler
-
# from .tokenization_minicpmv_fast import MiniCPMVTokenizerFast
+from mindspore import _no_grad
class MiniCPMVPreTrainedModel(Qwen2PreTrainedModel):
config_class = MiniCPMVConfig
@@ -33,21 +34,21 @@ def __init__(self, config):
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
self.processor = None
- self.terminators = ["<|im_end|>", "<|endoftext|>"]
+ self.terminators = ['<|im_end|>', '<|endoftext|>']
def init_vision_module(self):
# same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes
- if self.config._attn_implementation == "flash_attention_2":
- self.config.vision_config._attn_implementation = "flash_attention_2"
+ if self.config._attn_implementation == 'flash_attention_2':
+ self.config.vision_config._attn_implementation = 'flash_attention_2'
else:
# not suport sdpa
- self.config.vision_config._attn_implementation = "eager"
+ self.config.vision_config._attn_implementation = 'eager'
model = SiglipVisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
- setattr(model, "embed_dim", model.embeddings.embed_dim)
- setattr(model, "patch_size", model.embeddings.patch_size)
+ setattr(model, 'embed_dim', model.embeddings.embed_dim)
+ setattr(model, 'patch_size', model.embeddings.patch_size)
return model
@@ -57,7 +58,7 @@ def init_resampler(self, embed_dim, vision_dim):
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
- adaptive=True,
+ adaptive=True
)
def get_input_embeddings(self):
@@ -79,11 +80,11 @@ def get_decoder(self):
return self.llm
def get_vllm_embedding(self, data):
- if "vision_hidden_states" not in data:
+ if 'vision_hidden_states' not in data:
dtype = self.llm.model.embed_tokens.embedding_table.dtype
device = None
- tgt_sizes = data["tgt_sizes"]
- pixel_values_list = data["pixel_values"]
+ tgt_sizes = data['tgt_sizes']
+ pixel_values_list = data['pixel_values']
vision_hidden_states = []
all_pixel_values = []
img_cnt = []
@@ -106,16 +107,7 @@ def get_vllm_embedding(self, data):
max_length_w = max([i.shape[1] for i in all_pixel_values])
for i in range(len(all_pixel_values)):
if all_pixel_values[i].shape[0] < max_length_h or all_pixel_values[i].shape[1] < max_length_w:
- all_pixel_values[i] = ops.pad(
- all_pixel_values[i],
- (
- 0,
- max_length_w - all_pixel_values[i].shape[1],
- 0,
- max_length_h - all_pixel_values[i].shape[0],
- ),
- value=0.0,
- )
+ all_pixel_values[i] = ops.pad(all_pixel_values[i], (0, max_length_w - all_pixel_values[i].shape[1], 0, max_length_h - all_pixel_values[i].shape[0]), value=0.0)
all_pixel_values = ops.stack(all_pixel_values)
B, L, _ = all_pixel_values.shape
@@ -123,7 +115,7 @@ def get_vllm_embedding(self, data):
patch_attn_mask = ops.zeros(Tensor((B, 1, int(max_patches))), dtype=ms.bool_)
for i in range(B):
- patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+ patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
vision_batch_size = self.config.vision_batch_size
all_pixel_values = all_pixel_values.astype(dtype)
@@ -132,33 +124,28 @@ def get_vllm_embedding(self, data):
for i in range(0, B, vision_batch_size):
start_idx = i
end_idx = i + vision_batch_size
- tmp_hs = self.vpm(
- all_pixel_values[start_idx:end_idx],
- patch_attention_mask=patch_attn_mask[start_idx:end_idx],
- tgt_sizes=tgt_sizes[start_idx:end_idx],
- ).last_hidden_state
+ tmp_hs = self.vpm(all_pixel_values[start_idx:end_idx], patch_attention_mask=patch_attn_mask[start_idx:end_idx], tgt_sizes=tgt_sizes[start_idx:end_idx]).last_hidden_state
hs.append(tmp_hs)
vision_embedding = ops.cat(hs, axis=0)
else:
- vision_embedding = self.vpm(
- all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes
- ).last_hidden_state
+ vision_embedding = self.vpm(all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes).last_hidden_state
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
start = 0
for pixel_values in pixel_values_list:
img_cnt = len(pixel_values)
if img_cnt > 0:
- vision_hidden_states.append(vision_embedding[start : start + img_cnt])
+ vision_hidden_states.append(vision_embedding[start: start + img_cnt])
start += img_cnt
else:
vision_hidden_states.append([])
- else: # no image
+ else: # no image
if self.training:
- dummy_image = ops.zeros((1, 3, 224, 224), dtype=dtype)
- tgt_sizes = ms.Tensor(
- [[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]
- ).astype(ms.int32)
+ dummy_image = ops.zeros(
+ (1, 3, 224, 224),
+ dtype=dtype
+ )
+ tgt_sizes = ms.Tensor([[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]).astype(ms.int32)
dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
else:
dummy_feature = []
@@ -166,16 +153,15 @@ def get_vllm_embedding(self, data):
vision_hidden_states.append(dummy_feature)
else:
- vision_hidden_states = data["vision_hidden_states"]
+ vision_hidden_states = data['vision_hidden_states']
- if hasattr(self.llm.config, "scale_emb"):
- vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb
+ if hasattr(self.llm.config, 'scale_emb'):
+ vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
else:
- vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
+ vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])
- vision_hidden_states = [
- i.astype(vllm_embedding.dtype) if isinstance(i, ms.Tensor) else i for i in vision_hidden_states
- ]
+ vision_hidden_states = [i.astype(vllm_embedding.dtype) if isinstance(
+ i, ms.Tensor) else i for i in vision_hidden_states]
# bs = len(data['input_ids'])
# for i in range(bs):
@@ -202,7 +188,13 @@ def construct(self, data, **kwargs):
position_ids = position_ids.long()
with _no_grad():
- return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs)
+ return self.llm(
+ input_ids=None,
+ position_ids=position_ids,
+ inputs_embeds=vllm_embedding,
+ labels=data["labels"],
+ **kwargs
+ )
def _decode(self, inputs_embeds, tokenizer, attention_mask, decode_text=False, **kwargs):
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
@@ -211,7 +203,7 @@ def _decode(self, inputs_embeds, tokenizer, attention_mask, decode_text=False, *
pad_token_id=0,
eos_token_id=terminators,
attention_mask=attention_mask,
- **kwargs,
+ **kwargs
)
if decode_text:
return self._decode_text(output, tokenizer)
@@ -221,10 +213,10 @@ def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
streamer = TextIteratorStreamer(tokenizer=tokenizer)
generation_kwargs = {
- "inputs_embeds": inputs_embeds,
- "pad_token_id": 0,
- "eos_token_id": terminators,
- "streamer": streamer,
+ 'inputs_embeds': inputs_embeds,
+ 'pad_token_id': 0,
+ 'eos_token_id': terminators,
+ 'streamer': streamer
}
generation_kwargs.update(kwargs)
@@ -257,7 +249,7 @@ def generate(
return_vision_hidden_states=False,
stream=False,
decode_text=False,
- **kwargs,
+ **kwargs
):
assert input_ids is not None
assert len(input_ids) == len(pixel_values)
@@ -269,7 +261,7 @@ def generate(
if vision_hidden_states is None:
model_inputs["pixel_values"] = pixel_values
- model_inputs["tgt_sizes"] = tgt_sizes
+ model_inputs['tgt_sizes'] = tgt_sizes
else:
model_inputs["vision_hidden_states"] = vision_hidden_states
@@ -282,9 +274,7 @@ def generate(
if stream:
result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
else:
- result = self._decode(
- model_inputs["inputs_embeds"], tokenizer, attention_mask, decode_text=decode_text, **kwargs
- )
+ result = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, decode_text=decode_text, **kwargs)
if return_vision_hidden_states:
return result, vision_hidden_states
@@ -302,11 +292,11 @@ def chat(
min_new_tokens=0,
sampling=True,
max_inp_length=8192,
- system_prompt="",
+ system_prompt='',
stream=False,
max_slice_nums=None,
use_image_id=None,
- **kwargs,
+ **kwargs
):
if isinstance(msgs[0], list):
batched = True
@@ -329,21 +319,11 @@ def chat(
self.processor = MiniCPMVProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
processor = self.processor
- assert (
- self.config.query_num == processor.image_processor.image_feature_size
- ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
- assert (
- self.config.patch_size == processor.image_processor.patch_size
- ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
- assert (
- self.config.use_image_id == processor.image_processor.use_image_id
- ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
- assert (
- self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums
- ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
- assert (
- self.config.slice_mode == processor.image_processor.slice_mode
- ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert self.config.query_num == processor.image_processor.image_feature_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert self.config.patch_size == processor.image_processor.patch_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert self.config.use_image_id == processor.image_processor.use_image_id, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert self.config.slice_mode == processor.image_processor.slice_mode, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
prompts_lists = []
input_images_lists = []
@@ -377,12 +357,10 @@ def chat(
msg["content"] = "\n".join(cur_msgs)
if system_prompt:
- sys_msg = {"role": "system", "content": system_prompt}
+ sys_msg = {'role': 'system', 'content': system_prompt}
copy_msgs = [sys_msg] + copy_msgs
- prompts_lists.append(
- processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True)
- )
+ prompts_lists.append(processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True))
input_images_lists.append(images)
inputs = processor(
@@ -392,7 +370,7 @@ def chat(
use_image_id=use_image_id,
return_tensors="ms",
max_length=max_inp_length,
- image_processor=image_processor,
+ image_processor=image_processor
)
if sampling:
@@ -401,7 +379,7 @@ def chat(
"top_k": 100,
"temperature": 0.7,
"do_sample": True,
- "repetition_penalty": 1.05,
+ "repetition_penalty": 1.05
}
else:
generation_config = {
@@ -410,9 +388,11 @@ def chat(
}
if min_new_tokens > 0:
- generation_config["min_new_tokens"] = min_new_tokens
+ generation_config['min_new_tokens'] = min_new_tokens
- generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
+ generation_config.update(
+ (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
+ )
inputs.pop("image_sizes")
# with torch.inference_mode():
@@ -423,17 +403,15 @@ def chat(
vision_hidden_states=vision_hidden_states,
stream=stream,
decode_text=True,
- **generation_config,
+ **generation_config
)
if stream:
-
def stream_gen():
for text in res:
for term in self.terminators:
- text = text.replace(term, "")
+ text = text.replace(term, '')
yield text
-
return stream_gen()
else:
diff --git a/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py b/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
index 2abd16cd80..99d2f0ab49 100644
--- a/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
+++ b/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
@@ -20,15 +20,16 @@
import os
import warnings
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Any, Optional, Tuple, Union
+import numpy as np
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
import mindspore as ms
# import torch.utils.checkpoint
-from mindspore import nn, ops
+from mindspore import Parameter, Tensor, nn, ops
from mindspore.ops.operations.nn_ops import FlashAttentionScore as FlashAttention
from ...activations import ACT2FN
@@ -36,12 +37,13 @@
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
from ...modeling_utils import MSPreTrainedModel
+from ...mindspore_adapter import recompute_except_output
+
# from torch.nn.init import _calculate_fan_in_and_fan_out
logger = logging.get_logger(__name__)
-
class SiglipVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
@@ -229,16 +231,13 @@ def trunc_normal_tf_(
_trunc_normal_(tensor, 0, 1.0, a, b)
tensor.mul_(std).add_(mean)
-
def _calculate_fan_in_and_fan_out(arr):
# 计算fan_in和fan_out。fan_in是 `arr` 中输入单元的数量,fan_out是 `arr` 中输出单元的数量。
shape = arr.shape
dimensions = len(shape)
if dimensions < 2:
- raise ValueError(
- "'fan_in' and 'fan_out' can not be computed for arr with fewer than"
- " 2 dimensions, but got dimensions {}.".format(dimensions)
- )
+ raise ValueError("'fan_in' and 'fan_out' can not be computed for arr with fewer than"
+ " 2 dimensions, but got dimensions {}.".format(dimensions))
if dimensions == 2: # Linear
fan_in = shape[1]
fan_out = shape[0]
@@ -252,7 +251,6 @@ def _calculate_fan_in_and_fan_out(arr):
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
-
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in":
@@ -337,9 +335,7 @@ def __init__(self, config: SiglipVisionConfig):
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
- def construct(
- self, pixel_values: ms.Tensor, patch_attention_mask: ms.Tensor, tgt_sizes: Optional[ms.Tensor] = None
- ) -> ms.Tensor:
+ def construct(self, pixel_values: ms.Tensor, patch_attention_mask: ms.Tensor, tgt_sizes: Optional[ms.Tensor]=None) -> ms.Tensor:
batch_size = pixel_values.shape[0]
patch_embeds = self.patch_embedding(pixel_values)
@@ -394,7 +390,7 @@ def __init__(self, config):
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
- self.scale = self.head_dim**-0.5
+ self.scale = self.head_dim ** -0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Dense(self.embed_dim, self.embed_dim)
@@ -422,8 +418,8 @@ def construct(
k_v_seq_len = key_states.shape[-2]
- query_states = ops.mul(query_states, self.scale**0.5)
- key_states = ops.mul(key_states, self.scale**0.5)
+ query_states = ops.mul(query_states, self.scale ** 0.5)
+ key_states = ops.mul(key_states, self.scale ** 0.5)
attn_weights = ops.matmul(query_states, key_states.swapaxes(2, 3))
@@ -627,7 +623,9 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
- cu_seqlens_q = ops.arange(batch_size + 1, dtype=ms.int32) # There is a memcpy here, that is very bad.
+ cu_seqlens_q = ops.arange(
+ batch_size + 1, dtype=ms.int32
+ ) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
@@ -644,21 +642,22 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
-
class SiglipFlashAttention(SiglipAttention):
"""
Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
-
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False # Hack to make sure we don't use a causal mask
dropout_rate = self.dropout if self.training else 0.0
self.flash_attention = FlashAttention(
- scale_value=self.head_dim**-0.5, head_num=self.head_dim, input_layout="BSH", keep_prob=1 - dropout_rate
+ scale_value=self.head_dim**-0.5,
+ head_num=self.head_dim,
+ input_layout="BSH",
+ keep_prob=1-dropout_rate
)
def construct(
@@ -732,7 +731,9 @@ def construct(
value_states = value_states.to(target_dtype)
# implement flash attention
- attn_output = self.flash_attention(query_states, key_states, value_states, None, None, None, attention_mask)[3]
+ attn_output = self.flash_attention(
+ query_states, key_states, value_states, None, None, None, attention_mask
+ )[3]
# attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
@@ -742,7 +743,6 @@ def construct(
return attn_output, attn_weights
-
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Cell):
def __init__(self, config):
@@ -766,10 +766,14 @@ def __init__(self, config: SiglipVisionConfig):
self.embed_dim = config.hidden_size
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_flash_attention = config._attn_implementation == "flash_attention"
- self.self_attn = SiglipAttention(config) if not self._use_flash_attention else SiglipFlashAttention(config)
- self.layer_norm1 = nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
+ self.self_attn = (
+ SiglipAttention(config)
+ if not self._use_flash_attention
+ else SiglipFlashAttention(config)
+ )
+ self.layer_norm1 = nn.LayerNorm((self.embed_dim,), epsilon=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
- self.layer_norm2 = nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
+ self.layer_norm2 = nn.LayerNorm((self.embed_dim,), epsilon=config.layer_norm_eps)
# add recompute
# self.self_attn.recompute()
@@ -906,6 +910,9 @@ def __init__(self, config: SiglipVisionConfig):
# recompute
for layer in self.layers:
layer.recompute()
+ # for layer in self.layers:
+ # for name, cell in layer.name_cells().items():
+ # recompute_except_output(cell)
# Ignore copy
def construct(
@@ -973,8 +980,9 @@ def construct(
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
- return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions)
-
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
class SiglipVisionTransformer(SiglipPreTrainedModel):
config_class = SiglipVisionConfig
@@ -988,7 +996,7 @@ def __init__(self, config: SiglipVisionConfig):
self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(config)
- self.post_layernorm = nn.LayerNorm(embed_dim, epsilon=config.layer_norm_eps)
+ self.post_layernorm = nn.LayerNorm((embed_dim,), epsilon=config.layer_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_flash_attention = config._attn_implementation == "flash_attention"
@@ -998,9 +1006,11 @@ def __init__(self, config: SiglipVisionConfig):
# recompute
# self.encoder.recompute()
+
def get_input_embeddings(self) -> nn.Cell:
return self.embeddings.patch_embedding
+
def construct(
self,
pixel_values,
@@ -1030,16 +1040,14 @@ def construct(
dtype=ms.bool_,
)
- hidden_states = self.embeddings(
- pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes
- )
+ hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes)
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
# The call to `_upad_input` in `_flash_attention_forward` is expensive
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
if not ops.any(~patch_attention_mask):
- attention_mask = None
+ attention_mask=None
else:
attention_mask = (
_prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
diff --git a/mindone/transformers/models/minicpm_v/processing_minicpmv.py b/mindone/transformers/models/minicpm_v/processing_minicpmv.py
index 2cc0a57140..d9da9f7882 100644
--- a/mindone/transformers/models/minicpm_v/processing_minicpmv.py
+++ b/mindone/transformers/models/minicpm_v/processing_minicpmv.py
@@ -17,20 +17,19 @@
"""
import re
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
import numpy as np
+from transformers.utils import TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
-from transformers.utils import TensorType
import mindspore as ms
-from mindspore import Tensor, ops
+from mindspore import Parameter, Tensor, nn, ops
from ...processing_utils import ProcessorMixin
from .image_processing_minicpmv import MiniCPMVBatchFeature, MiniCPMVImageProcessor
-
class MiniCPMVProcessor(ProcessorMixin):
r"""
Constructs a MiniCPMV processor which wraps a MiniCPMV image processor and a MiniCPMV tokenizer into a single processor.
@@ -62,21 +61,12 @@ def __call__(
use_image_id: bool = None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
image_processor=None,
- **kwargs,
+ **kwargs
) -> MiniCPMVBatchFeature:
+
if images is not None:
- image_inputs = image_processor.preprocess(
- images=images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors
- )
- return self._convert_images_texts_to_inputs(
- image_inputs,
- text,
- max_slice_nums=max_slice_nums,
- use_image_id=use_image_id,
- max_length=max_length,
- image_processor=image_processor,
- **kwargs,
- )
+ image_inputs = image_processor.preprocess(images=images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors)
+ return self._convert_images_texts_to_inputs(image_inputs, text, max_slice_nums=max_slice_nums, use_image_id=use_image_id, max_length=max_length, image_processor=image_processor, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
@@ -106,13 +96,13 @@ def decode(self, *args, **kwargs):
result = result[result != 0]
if result[0] == self.tokenizer.bos_id:
result = result[1:]
- if result[-1] == self.tokenizer.eos_id or (
- hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id
- ):
+ if result[-1] == self.tokenizer.eos_id or (hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id):
result = result[:-1]
return self.tokenizer.decode(result, *args[1:], **kwargs).strip()
- def _convert(self, input_str, max_inp_length: Optional[int] = None):
+ def _convert(
+ self, input_str, max_inp_length: Optional[int] = None
+ ):
if self.version > 2.5 or not getattr(self.tokenizer, "add_bos_token", False):
input_ids = self.tokenizer.encode(input_str)
else:
@@ -122,8 +112,8 @@ def _convert(self, input_str, max_inp_length: Optional[int] = None):
input_ids = ms.Tensor(input_ids, dtype=ms.int32)
# FIXME ops.where issue
- start_cond = (input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id)
- end_cond = (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id)
+ # start_cond = (input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id)
+ # end_cond = (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id)
# image_start_tokens = ops.where(start_cond)[0]
# image_start_tokens += 1
# image_end_tokens = ops.where(end_cond)[0]
@@ -150,21 +140,19 @@ def _convert(self, input_str, max_inp_length: Optional[int] = None):
return input_ids, image_bounds
def _convert_images_texts_to_inputs(
- self,
- images,
- texts: Union[str, List[str]],
- truncation=None,
- max_length=None,
- max_slice_nums=None,
- use_image_id=None,
- return_tensors=None,
- image_processor=None,
- **kwargs,
- ):
+ self,
+ images,
+ texts: Union[str, List[str]],
+ truncation=None,
+ max_length=None,
+ max_slice_nums=None,
+ use_image_id=None,
+ return_tensors=None,
+ image_processor=None,
+ **kwargs
+ ):
if images is None or not len(images):
- model_inputs = self.tokenizer(
- texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs
- )
+ model_inputs = self.tokenizer(texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs)
return MiniCPMVBatchFeature(data={**model_inputs})
pattern = "(./)"
@@ -180,32 +168,33 @@ def _convert_images_texts_to_inputs(
text_chunks = text.split(pattern)
final_text = ""
for i in range(len(image_tags)):
- final_text = (
- final_text
- + text_chunks[i]
- + image_processor.get_slice_image_placeholder(
- image_sizes[index][i], i, max_slice_nums, use_image_id
+ final_text = final_text + text_chunks[i] + \
+ image_processor.get_slice_image_placeholder(
+ image_sizes[index][i],
+ i,
+ max_slice_nums,
+ use_image_id
)
- )
final_text += text_chunks[-1]
input_ids, image_bounds = self._convert(final_text, max_length)
input_ids_list.append(input_ids)
image_bounds_list.append(image_bounds)
- padded_input_ids, padding_lengths = self.pad(input_ids_list, padding_side="left")
+ padded_input_ids, padding_lengths = self.pad(
+ input_ids_list,
+ padding_side="left"
+ )
for i, length in enumerate(padding_lengths):
image_bounds_list[i] = image_bounds_list[i] + length
attention_mask = padded_input_ids.ne(0)
- return MiniCPMVBatchFeature(
- data={
- "input_ids": padded_input_ids,
- "attention_mask": attention_mask,
- "pixel_values": images,
- "image_sizes": image_sizes,
- "image_bound": image_bounds_list,
- "tgt_sizes": tgt_sizes,
- }
- )
+ return MiniCPMVBatchFeature(data={
+ "input_ids": padded_input_ids,
+ "attention_mask": attention_mask,
+ "pixel_values": images,
+ "image_sizes": image_sizes,
+ "image_bound": image_bounds_list,
+ "tgt_sizes": tgt_sizes
+ })
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
@@ -214,6 +203,7 @@ def model_input_names(self):
image_processor_input_names = MiniCPMVImageProcessor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
items = []
if isinstance(inputs[0], list):
@@ -242,7 +232,10 @@ def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
return ops.stack([item for item in items], axis=0), [0] * batch_size
tensor = ops.zeros((batch_size, max_length), dtype=dtype) + padding_value
else:
- tensor = ops.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
+ tensor = (
+ ops.zeros((batch_size, max_length, shape[-1]), dtype=dtype)
+ + padding_value
+ )
padding_length = []
for i, item in enumerate(items):
diff --git a/mindone/transformers/models/minicpm_v/resampler.py b/mindone/transformers/models/minicpm_v/resampler.py
index d414422479..7d0d4222f3 100644
--- a/mindone/transformers/models/minicpm_v/resampler.py
+++ b/mindone/transformers/models/minicpm_v/resampler.py
@@ -56,10 +56,10 @@ def get_1d_sincos_pos_embed_from_grid_new(embed_dim, pos):
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
- omega /= embed_dim / 2.0
- omega = 1.0 / 10000**omega # (D/2,)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000 ** omega # (D/2,)
- out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
+ out = np.einsum('hw,d->hwd', pos, omega) # (H, W, D/2), outer product
emb_sin = np.sin(out) # (H, W, D/2)
emb_cos = np.cos(out) # (H, W, D/2)
@@ -77,14 +77,14 @@ class Resampler(nn.Cell):
"""
def __init__(
- self,
- num_queries,
- embed_dim,
- num_heads,
- kv_dim=None,
- norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
- adaptive=False,
- max_size=(70, 70),
+ self,
+ num_queries,
+ embed_dim,
+ num_heads,
+ kv_dim=None,
+ norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
+ adaptive=False,
+ max_size=(70, 70),
):
super().__init__()
self.num_queries = num_queries
@@ -105,7 +105,7 @@ def __init__(
self.ln_kv = norm_layer((embed_dim,))
self.ln_post = norm_layer((embed_dim,))
- self.proj = Parameter((embed_dim**-0.5) * ops.randn(embed_dim, embed_dim))
+ self.proj = Parameter((embed_dim ** -0.5) * ops.randn(embed_dim, embed_dim))
self._set_2d_pos_cache(self.max_size)
@@ -125,7 +125,7 @@ def _adjust_pos_cache(self, tgt_sizes):
def _init_weights(self, m):
if isinstance(m, nn.Dense):
- trunc_normal_(m.weight, std=0.02)
+ trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Dense) and m.bias is not None:
Zero(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
@@ -150,9 +150,8 @@ def construct(self, x, tgt_sizes=None):
tgt_h, tgt_w = tgt_sizes[i]
shape_0 = tgt_h * tgt_w
pos_embed.append(
- self.pos_embed[:tgt_h, :tgt_w, :].reshape((int(shape_0.asnumpy()), -1)).to(dtype)
- ) # patches * D
- key_padding_mask[i, patch_len[i] :] = True
+ self.pos_embed[:tgt_h, :tgt_w, :].reshape((int(shape_0.asnumpy()), -1)).to(dtype)) # patches * D
+ key_padding_mask[i, patch_len[i]:] = True
# FIXME how to replace torch.nn.utils.rnn.pad_sequence
# pos_embed = torch.nn.utils.rnn.pad_sequence(
@@ -161,11 +160,9 @@ def construct(self, x, tgt_sizes=None):
max_length_w = max([i.shape[1] for i in pos_embed])
for i in range(len(pos_embed)):
if pos_embed[i].shape[0] < max_length_h or pos_embed[i].shape[1] < max_length_w:
- pos_embed[i] = ops.pad(
- pos_embed[i],
- (0, max_length_w - pos_embed[i].shape[1], 0, max_length_h - pos_embed[i].shape[0]),
- value=0.0,
- )
+ pos_embed[i] = ops.pad(pos_embed[i], (
+ 0, max_length_w - pos_embed[i].shape[1], 0, max_length_h - pos_embed[i].shape[0]),
+ value=0.0)
pos_embed = ops.stack(pos_embed)
pos_embed = pos_embed.permute(1, 0, 2)
@@ -175,11 +172,10 @@ def construct(self, x, tgt_sizes=None):
q = self.ln_q(self.query) # Q * D
out = self.attn(
- self._repeat(q, bs), # Q * B * D
+ q.unsqueeze(1).tile((1, bs, 1)), # Q * B * D
x + pos_embed, # L * B * D + L * B * D
x,
- key_padding_mask=key_padding_mask,
- )[0]
+ key_padding_mask=key_padding_mask)[0]
# out: Q * B * D
x = out.permute(1, 0, 2) # B * Q * D
@@ -192,38 +188,25 @@ def _repeat(self, query, N: int):
class MultiheadAttention(nn.MultiheadAttention):
- def __init__(
- self,
- embed_dim,
- num_heads,
- dropout=0.0,
- bias=True,
- add_bias_kv=False,
- add_zero_attn=False,
- kdim=None,
- vdim=None,
- batch_first=False,
- dtype=None,
- ):
- super().__init__(
- embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, dtype
- )
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False,
+ add_zero_attn=False, kdim=None, vdim=None, batch_first=False, dtype=None):
+ super().__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first,
+ dtype)
# rewrite out_proj layer,with nn.Linear
self.out_proj = nn.Dense(embed_dim, embed_dim, has_bias=bias)
def construct(
- self,
- query: Tensor,
- key: Tensor,
- value: Tensor,
- key_padding_mask: Optional[Tensor] = None,
- need_weights: bool = True,
- attn_mask: Optional[Tensor] = None,
- average_attn_weights: bool = True,
- is_causal: bool = False,
- ) -> Tuple[Tensor, Optional[Tensor]]:
- why_not_fast_path = ""
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ average_attn_weights: bool = True,
+ is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
+ why_not_fast_path = ''
# if ((attn_mask is not None and torch.is_floating_point(attn_mask))
# or (key_padding_mask is not None) and torch.is_floating_point(key_padding_mask)):
# why_not_fast_path = "floating-point masks are not supported for fast path."
@@ -235,7 +218,7 @@ def construct(
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
- target_type=query.dtype,
+ target_type=query.dtype
)
attn_mask = _canonical_mask(
@@ -255,16 +238,12 @@ def construct(
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
- why_not_fast_path = (
- f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
- )
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif self.in_proj_weight is None:
why_not_fast_path = "in_proj_weight was None"
elif query.dtype != self.in_proj_weight.dtype:
# this case will fail anyway, but at least they'll get a useful error message.
- why_not_fast_path = (
- f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
- )
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
elif self.training:
why_not_fast_path = "training is enabled"
elif (self.num_heads % 2) != 0:
@@ -325,8 +304,7 @@ def construct(
merged_mask,
need_weights,
average_attn_weights,
- mask_type,
- )
+ mask_type)
# any_nested = query.is_nested or key.is_nested or value.is_nested
# assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
@@ -345,84 +323,62 @@ def construct(
if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = self.multi_head_attention_forward(
- query,
- key,
- value,
- self.embed_dim,
- self.num_heads,
- self.in_proj_weight,
- self.in_proj_bias,
- self.bias_k,
- self.bias_v,
- self.add_zero_attn,
- self.dropout,
- self.out_proj.weight,
- self.out_proj.bias,
+ query, key, value, self.embed_dim, self.num_heads,
+ self.in_proj_weight, self.in_proj_bias,
+ self.bias_k, self.bias_v, self.add_zero_attn,
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
- key_padding_mask=key_padding_mask,
- need_weights=need_weights,
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=True,
- q_proj_weight=self.q_proj_weight,
- k_proj_weight=self.k_proj_weight,
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,
- is_causal=is_causal,
- )
+ is_causal=is_causal)
else:
attn_output, attn_output_weights = self.multi_head_attention_forward(
- query,
- key,
- value,
- self.embed_dim,
- self.num_heads,
- self.in_proj_weight,
- self.in_proj_bias,
- self.bias_k,
- self.bias_v,
- self.add_zero_attn,
- self.dropout,
- self.out_proj.weight,
- self.out_proj.bias,
+ query, key, value, self.embed_dim, self.num_heads,
+ self.in_proj_weight, self.in_proj_bias,
+ self.bias_k, self.bias_v, self.add_zero_attn,
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
average_attn_weights=average_attn_weights,
- is_causal=is_causal,
- )
+ is_causal=is_causal)
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights
else:
return attn_output, attn_output_weights
def multi_head_attention_forward(
- self,
- query: Tensor,
- key: Tensor,
- value: Tensor,
- embed_dim_to_check: int,
- num_heads: int,
- in_proj_weight: Optional[Tensor],
- in_proj_bias: Optional[Tensor],
- bias_k: Optional[Tensor],
- bias_v: Optional[Tensor],
- add_zero_attn: bool,
- dropout_p: float,
- out_proj_weight: Tensor,
- out_proj_bias: Optional[Tensor],
- training: bool = True,
- key_padding_mask: Optional[Tensor] = None,
- need_weights: bool = True,
- attn_mask: Optional[Tensor] = None,
- use_separate_proj_weight: bool = False,
- q_proj_weight: Optional[Tensor] = None,
- k_proj_weight: Optional[Tensor] = None,
- v_proj_weight: Optional[Tensor] = None,
- static_k: Optional[Tensor] = None,
- static_v: Optional[Tensor] = None,
- average_attn_weights: bool = True,
- is_causal: bool = False,
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ embed_dim_to_check: int,
+ num_heads: int,
+ in_proj_weight: Optional[Tensor],
+ in_proj_bias: Optional[Tensor],
+ bias_k: Optional[Tensor],
+ bias_v: Optional[Tensor],
+ add_zero_attn: bool,
+ dropout_p: float,
+ out_proj_weight: Tensor,
+ out_proj_bias: Optional[Tensor],
+ training: bool = True,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ use_separate_proj_weight: bool = False,
+ q_proj_weight: Optional[Tensor] = None,
+ k_proj_weight: Optional[Tensor] = None,
+ v_proj_weight: Optional[Tensor] = None,
+ static_k: Optional[Tensor] = None,
+ static_v: Optional[Tensor] = None,
+ average_attn_weights: bool = True,
+ is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
# FIXME: logic passed
@@ -479,7 +435,7 @@ def multi_head_attention_forward(
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
- target_type=query.dtype,
+ target_type=query.dtype
)
if is_causal and attn_mask is None:
@@ -510,20 +466,18 @@ def multi_head_attention_forward(
# longer causal.
is_causal = False
- assert (
- embed_dim == embed_dim_to_check
- ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
+ assert embed_dim == embed_dim_to_check, \
+ f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
if isinstance(embed_dim, ms.Tensor):
# embed_dim can be a tensor when JIT tracing
- head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
+ head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
else:
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
- assert (
- key.shape[:2] == value.shape[:2]
- ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
+ assert key.shape[:2] == value.shape[:2], \
+ f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
else:
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
@@ -551,15 +505,13 @@ def multi_head_attention_forward(
correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(
- f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
- )
+ f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size:
raise RuntimeError(
- f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
- )
+ f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
@@ -587,23 +539,19 @@ def multi_head_attention_forward(
k = k.view(k.shape[0], bsz * num_heads, head_dim).permute(1, 0, 2)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
- assert (
- static_k.shape[0] == bsz * num_heads
- ), f"expecting static_k.shape[0] of {bsz * num_heads}, but got {static_k.shape[0]}"
- assert (
- static_k.shape[2] == head_dim
- ), f"expecting static_k.shape[2] of {head_dim}, but got {static_k.shape[2]}"
+ assert static_k.shape[0] == bsz * num_heads, \
+ f"expecting static_k.shape[0] of {bsz * num_heads}, but got {static_k.shape[0]}"
+ assert static_k.shape[2] == head_dim, \
+ f"expecting static_k.shape[2] of {head_dim}, but got {static_k.shape[2]}"
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).permute(1, 0, 2)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
- assert (
- static_v.shape[0] == bsz * num_heads
- ), f"expecting static_v.shape[0] of {bsz * num_heads}, but got {static_v.shape[0]}"
- assert (
- static_v.shape[2] == head_dim
- ), f"expecting static_v.shape[2] of {head_dim}, but got {static_v.shape[2]}"
+ assert static_v.shape[0] == bsz * num_heads, \
+ f"expecting static_v.shape[0] of {bsz * num_heads}, but got {static_v.shape[0]}"
+ assert static_v.shape[2] == head_dim, \
+ f"expecting static_v.shape[2] of {head_dim}, but got {static_v.shape[2]}"
v = static_v
# add zero attention along batch dimension (now first)
@@ -621,15 +569,10 @@ def multi_head_attention_forward(
# merge key padding and attention masks
if key_padding_mask is not None:
- assert key_padding_mask.shape == (
- bsz,
- src_len,
- ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
- key_padding_mask = (
- key_padding_mask.view(bsz, 1, 1, src_len)
- .expand(-1, num_heads, -1, -1)
- .reshape(bsz * num_heads, 1, src_len)
- )
+ assert key_padding_mask.shape == (bsz, src_len), \
+ f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
+ key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
+ broadcast_to((-1, num_heads, -1, -1)).reshape(bsz * num_heads, 1, src_len)
if attn_mask is None:
attn_mask = key_padding_mask
else:
@@ -698,14 +641,8 @@ def multi_head_attention_forward(
return attn_output, None
-def _mha_shape_check(
- query: Tensor,
- key: Tensor,
- value: Tensor,
- key_padding_mask: Optional[Tensor],
- attn_mask: Optional[Tensor],
- num_heads: int,
-):
+def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
+ key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], num_heads: int):
# Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
# and returns if the input is batched or not.
# Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.
@@ -714,65 +651,58 @@ def _mha_shape_check(
if query.dim() == 3:
# Batched Inputs
is_batched = True
- assert key.dim() == 3 and value.dim() == 3, (
- "For batched (3-D) `query`, expected `key` and `value` to be 3-D"
- f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
- )
+ assert key.dim() == 3 and value.dim() == 3, \
+ ("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
if key_padding_mask is not None:
- assert key_padding_mask.dim() == 2, (
- "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
- f" but found {key_padding_mask.dim()}-D tensor instead"
- )
+ assert key_padding_mask.dim() == 2, \
+ ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
+ f" but found {key_padding_mask.dim()}-D tensor instead")
if attn_mask is not None:
- assert attn_mask.dim() in (2, 3), (
- "For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
- f" but found {attn_mask.dim()}-D tensor instead"
- )
+ assert attn_mask.dim() in (2, 3), \
+ ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
+ f" but found {attn_mask.dim()}-D tensor instead")
elif query.dim() == 2:
# Unbatched Inputs
is_batched = False
- assert key.dim() == 2 and value.dim() == 2, (
- "For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
- f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
- )
+ assert key.dim() == 2 and value.dim() == 2, \
+ ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
if key_padding_mask is not None:
- assert key_padding_mask.dim() == 1, (
- "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
- f" but found {key_padding_mask.dim()}-D tensor instead"
- )
+ assert key_padding_mask.dim() == 1, \
+ ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
+ f" but found {key_padding_mask.dim()}-D tensor instead")
if attn_mask is not None:
- assert attn_mask.dim() in (2, 3), (
- "For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
- f" but found {attn_mask.dim()}-D tensor instead"
- )
+ assert attn_mask.dim() in (2, 3), \
+ ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
+ f" but found {attn_mask.dim()}-D tensor instead")
if attn_mask.dim() == 3:
expected_shape = (num_heads, query.shape[0], key.shape[0])
- assert (
- attn_mask.shape == expected_shape
- ), f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}"
+ assert attn_mask.shape == expected_shape, \
+ (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
else:
raise AssertionError(
- f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor"
- )
+ f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
return is_batched
def _canonical_mask(
- mask: Optional[Tensor],
- mask_name: str,
- other_type: Optional,
- other_name: str,
- target_type: None,
- check_other: bool = True,
+ mask: Optional[Tensor],
+ mask_name: str,
+ other_type: Optional,
+ other_name: str,
+ target_type: None,
+ check_other: bool = True,
) -> Optional[Tensor]:
if mask is not None:
_mask_dtype = mask.dtype
_mask_is_float = ops.is_floating_point(mask)
if _mask_dtype != ms.bool_ and not _mask_is_float:
- raise AssertionError(f"only bool and floating types of {mask_name} are supported")
+ raise AssertionError(
+ f"only bool and floating types of {mask_name} are supported")
if check_other and other_type is not None:
if _mask_dtype != other_type:
warnings.warn(
@@ -780,7 +710,10 @@ def _canonical_mask(
"is deprecated. Use same type for both instead."
)
if not _mask_is_float:
- mask = ops.zeros_like(mask, dtype=target_type).masked_fill(mask, float("-inf"))
+ mask = (
+ ops.zeros_like(mask, dtype=target_type)
+ .masked_fill(mask, float("-inf"))
+ )
return mask
@@ -793,11 +726,11 @@ def _none_or_dtype(input: Optional[Tensor]) -> Optional:
def _in_projection_packed(
- q: Tensor,
- k: Tensor,
- v: Tensor,
- w: Tensor,
- b: Optional[Tensor] = None,
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ w: Tensor,
+ b: Optional[Tensor] = None,
) -> List[Tensor]:
r"""
Performs the in-projection step of the attention operation, using packed weights.
@@ -853,15 +786,15 @@ def _in_projection_packed(
def _in_projection(
- q: Tensor,
- k: Tensor,
- v: Tensor,
- w_q: Tensor,
- w_k: Tensor,
- w_v: Tensor,
- b_q: Optional[Tensor] = None,
- b_k: Optional[Tensor] = None,
- b_v: Optional[Tensor] = None,
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ w_q: Tensor,
+ w_k: Tensor,
+ w_v: Tensor,
+ b_q: Optional[Tensor] = None,
+ b_k: Optional[Tensor] = None,
+ b_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
r"""
Performs the in-projection step of the attention operation. This is simply
diff --git a/mindone/transformers/models/minicpm_v/tokenization_minicpmv_fast.py b/mindone/transformers/models/minicpm_v/tokenization_minicpmv_fast.py
index e41bff8bca..f6d84bd25c 100644
--- a/mindone/transformers/models/minicpm_v/tokenization_minicpmv_fast.py
+++ b/mindone/transformers/models/minicpm_v/tokenization_minicpmv_fast.py
@@ -1,4 +1,4 @@
-from mindnlp.transformers import AutoTokenizer
+from transformers import AutoTokenizer
from ..qwen2 import Qwen2TokenizerFast
@@ -57,7 +57,7 @@ def im_id_end_id(self):
@property
def newline_id(self):
- return self.convert_tokens_to_ids("\n")
+ return self.convert_tokens_to_ids('\n')
@staticmethod
def escape(text: str) -> str:
@@ -67,5 +67,4 @@ def escape(text: str) -> str:
def unescape(text: str) -> str:
return text
-
AutoTokenizer.register("MiniCPMVTokenizerFast", MiniCPMVTokenizerFast)
diff --git a/mindone/transformers/models/qwen2/__init__.py b/mindone/transformers/models/qwen2/__init__.py
index aa7e109cf0..be2d5916fd 100644
--- a/mindone/transformers/models/qwen2/__init__.py
+++ b/mindone/transformers/models/qwen2/__init__.py
@@ -21,6 +21,7 @@
}
+
_import_structure["tokenization_qwen2_fast"] = ["Qwen2TokenizerFast"]
_import_structure["modeling_qwen2"] = [
diff --git a/mindone/transformers/models/qwen2/modeling_qwen2.py b/mindone/transformers/models/qwen2/modeling_qwen2.py
index 08df790475..e1eb9df146 100644
--- a/mindone/transformers/models/qwen2/modeling_qwen2.py
+++ b/mindone/transformers/models/qwen2/modeling_qwen2.py
@@ -19,19 +19,18 @@
# limitations under the License.
"""Mindspore Qwen2 model."""
+import math
from typing import List, Optional, Tuple, Union
import numpy as np
-from transformers import logging
import mindspore as ms
-from mindspore import Parameter, nn, ops
+from mindspore import nn, ops, Tensor, Parameter
from mindspore.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...cache_utils import Cache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -41,6 +40,11 @@
from ...modeling_utils import MSPreTrainedModel
from .configuration_qwen2 import Qwen2Config
+from transformers import logging
+
+from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
logger = logging.get_logger(__name__)
@@ -66,6 +70,7 @@ def dtype_to_min(dtype):
raise ValueError(f"Only support get minimum value of (float16, ), but got {dtype}")
+
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: ms.Tensor,
@@ -107,7 +112,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
if sequence_length != 1:
causal_mask = ops.triu(causal_mask, diagonal=1)
causal_mask *= ops.arange(target_length) > cache_position.reshape(-1, 1)
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ causal_mask = causal_mask[None, None, :, :].broadcast_to((batch_size, 1, -1, -1))
if attention_mask is not None:
# causal_mask = causal_mask # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
@@ -121,11 +126,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
causal_mask = causal_mask.masked_fill(padding_mask, min_dtype)
else:
causal_mask = ops.cat(
- [
- ops.narrow(causal_mask, -1, 0, mask_length).masked_fill(padding_mask, min_dtype),
- ops.narrow(causal_mask, -1, mask_length, causal_mask.shape[-1] - mask_length),
- ],
- axis=-1,
+ [ops.narrow(causal_mask, -1, 0, mask_length).masked_fill(padding_mask, min_dtype),
+ ops.narrow(causal_mask, -1, mask_length, causal_mask.shape[-1] - mask_length)],
+ axis=-1
)
return causal_mask
@@ -164,7 +167,9 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
self.inv_freq = inv_freq
# Build here to make `torch.jit.trace` work.
- self._set_cos_sin_cache(seq_len=max_position_embeddings, device=None, dtype=ms.float32)
+ self._set_cos_sin_cache(
+ seq_len=max_position_embeddings, device=None, dtype=ms.float32
+ )
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
@@ -248,7 +253,7 @@ def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ hidden_states = hidden_states[:, :, None, :, :].broadcast_to((batch, num_key_value_heads, n_rep, slen, head_dim))
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
@@ -295,7 +300,7 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
base=self.rope_theta,
)
- self.scale = self.head_dim**-0.5
+ self.scale = self.head_dim ** -0.5
def construct(
self,
@@ -337,8 +342,8 @@ def construct(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
- query_states = ops.mul(query_states, self.scale**0.5)
- key_states = ops.mul(key_states, self.scale**0.5)
+ query_states = ops.mul(query_states, self.scale ** 0.5)
+ key_states = ops.mul(key_states, self.scale ** 0.5)
attn_weights = ops.matmul(query_states, key_states.swapaxes(2, 3))
@@ -854,6 +859,7 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.embed_tokens = value
+
def construct(
self,
input_ids: ms.Tensor = None,
@@ -901,7 +907,9 @@ def construct(
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- cache_position = ops.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
+ cache_position = ops.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]
+ )
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
@@ -1028,7 +1036,11 @@ def _update_causal_mask(
batch_size=input_tensor.shape[0],
)
- if self.config._attn_implementation == "sdpa" and attention_mask is not None and not output_attentions:
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and not output_attentions
+ ):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
@@ -1067,6 +1079,7 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.model
+
def construct(
self,
input_ids: ms.Tensor = None,
@@ -1178,13 +1191,13 @@ def prepare_inputs_for_generation(
# input_ids = input_ids[:, :cache_position.shape[0]]
if inputs_embeds is not None: # Exception 1
if 0 not in input_ids.shape:
- input_ids = input_ids[:, -cache_position.shape[0] :]
+ input_ids = input_ids[:, -cache_position.shape[0]:]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = ops.index_select(input_ids, -1, cache_position)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids = attention_mask.cumsum(-1).long() - 1
position_ids.masked_fill(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
@@ -1232,6 +1245,7 @@ def prepare_inputs_for_generation(
return model_inputs
+
class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1248,6 +1262,7 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.embed_tokens = value
+
def construct(
self,
input_ids: ms.Tensor = None,
@@ -1338,6 +1353,7 @@ def construct(
)
+
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2
class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
def __init__(self, config):
@@ -1362,6 +1378,7 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.embed_tokens = value
+
def construct(
self,
input_ids: Optional[ms.Tensor] = None,
@@ -1412,4 +1429,4 @@ def construct(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- )
+ )
\ No newline at end of file
diff --git a/mindone/transformers/models/qwen2/tokenization_qwen2.py b/mindone/transformers/models/qwen2/tokenization_qwen2.py
index b13046fbbf..c5cff300a2 100644
--- a/mindone/transformers/models/qwen2/tokenization_qwen2.py
+++ b/mindone/transformers/models/qwen2/tokenization_qwen2.py
@@ -49,7 +49,9 @@ def bytes_to_unicode():
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
tables between utf-8 bytes and unicode strings.
"""
- bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+ bs = (
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+ )
cs = bs[:]
n = 0
for b in range(2**8):
diff --git a/mindone/transformers/models/qwen2/tokenization_qwen2_fast.py b/mindone/transformers/models/qwen2/tokenization_qwen2_fast.py
index 5ed3c74bc4..176a37344c 100644
--- a/mindone/transformers/models/qwen2/tokenization_qwen2_fast.py
+++ b/mindone/transformers/models/qwen2/tokenization_qwen2_fast.py
@@ -16,8 +16,8 @@
from typing import Optional, Tuple
-from mindnlp.transformers.tokenization_utils import AddedToken
-from mindnlp.transformers.tokenization_utils_fast import PreTrainedTokenizerFast
+from transformers.tokenization_utils import AddedToken
+from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import logging
from .tokenization_qwen2 import Qwen2Tokenizer
From 70d38f306f525aa064a6186fa915c930b5aaa4e8 Mon Sep 17 00:00:00 2001
From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com>
Date: Wed, 20 Nov 2024 10:10:28 +0800
Subject: [PATCH 5/9] feat(minicpm-v): Support MiniCPM-V Training pipeline
---
examples/minicpm_v/finetune/dataset.py | 149 ++++----
examples/minicpm_v/finetune/finetune.py | 85 ++---
.../transformers/models/minicpm_v/__init__.py | 2 +-
.../models/minicpm_v/configuration_minicpm.py | 3 +-
.../minicpm_v/image_processing_minicpmv.py | 79 ++--
.../models/minicpm_v/modeling_minicpmv.py | 147 +++++---
.../models/minicpm_v/modeling_navit_siglip.py | 58 ++-
.../models/minicpm_v/processing_minicpmv.py | 95 ++---
.../models/minicpm_v/resampler.py | 355 +++++++++++-------
.../minicpm_v/tokenization_minicpmv_fast.py | 3 +-
mindone/transformers/models/qwen2/__init__.py | 1 -
.../models/qwen2/modeling_qwen2.py | 49 +--
.../models/qwen2/tokenization_qwen2.py | 4 +-
13 files changed, 524 insertions(+), 506 deletions(-)
diff --git a/examples/minicpm_v/finetune/dataset.py b/examples/minicpm_v/finetune/dataset.py
index fa5a3b312c..eb82b9056b 100644
--- a/examples/minicpm_v/finetune/dataset.py
+++ b/examples/minicpm_v/finetune/dataset.py
@@ -31,6 +31,7 @@
llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
+
class SupervisedDataset:
"""Dataset for supervised fine-tuning."""
@@ -53,7 +54,7 @@ def __init__(
self.slice_config = slice_config
self.llm_type = llm_type
self.patch_size = patch_size
- self.query_nums=query_nums
+ self.query_nums = query_nums
self.batch_vision = batch_vision
self.max_length = max_length
# self.dataset_column_names = ["input_ids", "position_ids", "labels", "attention_mask", "pixel_values", "tgt_sizes", "image_bound"]
@@ -66,10 +67,13 @@ def __len__(self):
def __getitem__(self, idx, retry_count=3):
try:
if isinstance(self.raw_data[idx]["image"], str):
- images_dict = { "" : Image.open(self.raw_data[idx]["image"]).convert("RGB") }
+ images_dict = {"": Image.open(self.raw_data[idx]["image"]).convert("RGB")}
elif isinstance(self.raw_data[idx]["image"], Dict):
### for multi-images input, the template for every image is , such as ,
- images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[idx]["image"].items()}
+ images_dict = {
+ img_name: Image.open(img_path).convert("RGB")
+ for img_name, img_path in self.raw_data[idx]["image"].items()
+ }
ret = preprocess(
images_dict,
@@ -81,7 +85,7 @@ def __getitem__(self, idx, retry_count=3):
llm_type=self.llm_type,
patch_size=self.patch_size,
batch_vision=self.batch_vision,
- max_length=self.max_length
+ max_length=self.max_length,
)
ret = dict(
input_ids=ret["input_ids"],
@@ -93,7 +97,7 @@ def __getitem__(self, idx, retry_count=3):
image_bound=ret["image_bound"],
)
- ret = data_collator(ret, max_length = self.max_length)
+ ret = data_collator(ret, max_length=self.max_length)
except (EOFError, ValueError, OSError) as e:
# Log and handle EOFError and other file-related errors
@@ -114,6 +118,7 @@ def __getitem__(self, idx, retry_count=3):
# return (ret["input_ids"], ret["position_ids"], ret["labels"], np.ones_like(ret["input_ids"], dtype=np.bool_), ret["pixel_values"], ret["tgt_sizes"], ret["image_bound"])
return ret
+
def data_collator(examples, padding_value=0, max_length=2048):
def trim_and_pad(seq, batch_first, padding_value):
# return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
@@ -172,24 +177,20 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
{'role': 'assistant', 'content': 'This is a cat.'}]
"""
if llm_type == "llama3":
- input_ids, context, raw_msg = conversation_to_ids_llama3(
- conversation, tokenizer
- )
+ input_ids, context, raw_msg = conversation_to_ids_llama3(conversation, tokenizer)
elif llm_type == "qwen2":
- input_ids, context, raw_msg = conversation_to_ids_qwen2(
- conversation, tokenizer
- )
+ input_ids, context, raw_msg = conversation_to_ids_qwen2(conversation, tokenizer)
else:
- input_ids, context, raw_msg = conversation_to_ids_minicpm(
- conversation, tokenizer
- )
+ input_ids, context, raw_msg = conversation_to_ids_minicpm(conversation, tokenizer)
ids = np.hstack(input_ids, dtype=np.int32)
context = np.hstack(context, dtype=np.int8)
if input_ids.shape[-1] > max_length:
ids = ids[:max_length]
context = context[:max_length]
- logger.warning(f"The input length ({input_ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated")
+ logger.warning(
+ f"The input length ({input_ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated"
+ )
if np.all(context):
logger.error("No tokens available to compute loss.")
@@ -235,7 +236,7 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
"target": target,
"image_bound": image_bound,
"raw_msg": raw_msg,
- "position_ids": position_ids
+ "position_ids": position_ids,
}
@@ -276,24 +277,23 @@ def conversation_to_ids_llama3(conversation, tokenizer):
input_ids = []
context = []
raw_msg = tokenizer.apply_chat_template(
- conversation, tokenize=False, add_generation_prompt=False, chat_template=llama3_chat_template,
+ conversation,
+ tokenize=False,
+ add_generation_prompt=False,
+ chat_template=llama3_chat_template,
)
input_ids = tokenizer.apply_chat_template(
- conversation, tokenize=True, add_generation_prompt=False, chat_template=llama3_chat_template,
+ conversation,
+ tokenize=True,
+ add_generation_prompt=False,
+ chat_template=llama3_chat_template,
)
input_ids = np.array(input_ids)
- start_header_idxs = np.where(
- input_ids == tokenizer.convert_tokens_to_ids("<|start_header_id|>")
- )[0]
- assistant_idxs = np.where(
- input_ids == tokenizer.convert_tokens_to_ids("assistant")
- )[0]
- end_header_idxs = np.where(
- input_ids == tokenizer.convert_tokens_to_ids("<|end_header_id|>")
- )[0]
- eot_idxs = np.where(
- input_ids == tokenizer.convert_tokens_to_ids("<|eot_id|>"))[0]
+ start_header_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids("<|start_header_id|>"))[0]
+ assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids("assistant"))[0]
+ end_header_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids("<|end_header_id|>"))[0]
+ eot_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids("<|eot_id|>"))[0]
context = np.ones_like(input_ids, dtype=np.int8)
@@ -302,7 +302,7 @@ def conversation_to_ids_llama3(conversation, tokenizer):
st = assistant_idx + 3 # assistant<|end_header_id|>\n\n
for eot_idx in eot_idxs:
if eot_idx > st:
- context[st: eot_idx + 1] = 0
+ context[st : eot_idx + 1] = 0
break
input_ids = np.hstack(input_ids)
@@ -323,35 +323,37 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
prefix = "user"
else:
prefix = "assistant"
- chat.append({"role":prefix, "content":message})
+ chat.append({"role": prefix, "content": message})
raw_msg += prefix + message
- assert set([i['role'] for i in chat]) & set(['assistant'])
+ assert set([i["role"] for i in chat]) & set(["assistant"])
ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False)
input_ids = np.array(input_ids)
- start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0]
- assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0]
- end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0]
+ start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids("<|im_start|>"))[0]
+ assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids("assistant"))[0]
+ end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids("<|im_end|>"))[0]
context = np.ones_like(input_ids, dtype=np.int8)
for assistant_idx in assistant_idxs:
- if assistant_idx-1 in set(start_idxs):
+ if assistant_idx - 1 in set(start_idxs):
st = assistant_idx + 1
for end_idx in end_idxs:
if end_idx > st:
- context[st: end_idx + 1] = 0
+ context[st : end_idx + 1] = 0
break
input_ids = np.hstack(input_ids)
context = np.hstack(context)
return input_ids, context, raw_msg
+
def trans_fn(x):
x = np.asarray(x).transpose((2, 0, 1))
- return (x-0.5*255)/(0.5*255)
+ return (x - 0.5 * 255) / (0.5 * 255)
+
def preprocess(
images_dict,
@@ -377,12 +379,10 @@ def preprocess(
assert "patch_size" in slice_config
assert "max_slice_nums" in slice_config
assert "scale_resolution" in slice_config
- default_image_placeholder = (
- tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
- )
+ default_image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
new_schema = False
use_image_id = False
- if llm_type=='qwen2':
+ if llm_type == "qwen2":
new_schema = True
use_image_id = True
image_placeholder_dict = {}
@@ -403,15 +403,16 @@ def preprocess(
for j in range(len(patches[0])):
images.append(patches[i][j])
if use_image_id:
- image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
+ image_placeholder = (
+ f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + image_placeholder
+ )
image_id_cnt += 1
- image_placeholder += get_grid_placeholder(
- tokenizer, best_grid, query_nums, new_schema = new_schema)
+ image_placeholder += get_grid_placeholder(tokenizer, best_grid, query_nums, new_schema=new_schema)
image_placeholder_dict[img_name] = image_placeholder
else:
images.append(image)
if use_image_id:
- image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
+ image_placeholder = f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + image_placeholder
image_id_cnt += 1
else:
image_placeholder = default_image_placeholder
@@ -421,20 +422,16 @@ def preprocess(
if len(images_dict) == 1 and "" in images_dict:
if "" in conversations[0]["content"]:
- conversations[0]["content"] = conversations[0]["content"].replace(
- "", image_placeholder
- )
+ conversations[0]["content"] = conversations[0]["content"].replace("", image_placeholder)
else:
- conversations[0]["content"] = (
- image_placeholder + "\n" + conversations[0]["content"]
- )
+ conversations[0]["content"] = image_placeholder + "\n" + conversations[0]["content"]
input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
else:
- pattern = r''
+ pattern = r""
new_conversations = []
for conversation in conversations:
- content = conversation['content']
- parts = re.split(f'({pattern})', content)
+ content = conversation["content"]
+ parts = re.split(f"({pattern})", content)
for i, part in enumerate(parts):
if not part.strip():
continue
@@ -443,7 +440,7 @@ def preprocess(
parts[i] = image_placeholder_dict[part]
else:
raise Exception(f"not found {part} in image dict")
- conversation['content'] = '\n'.join(parts)
+ conversation["content"] = "\n".join(parts)
new_conversations.append(conversation)
conversations = new_conversations
@@ -470,14 +467,11 @@ def preprocess(
return input_dict
-def slice_image(
- image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
-):
+def slice_image(image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False):
original_size = image.size
original_width, original_height = original_size
log_ratio = math.log(original_width / original_height)
- ratio = original_width * original_height / \
- (scale_resolution * scale_resolution)
+ ratio = original_width * original_height / (scale_resolution * scale_resolution)
multiple = min(math.ceil(ratio), max_slice_nums)
source_image = None
@@ -486,9 +480,7 @@ def slice_image(
if multiple <= 1 or never_split:
# dont need to slice, upsample
- best_size = find_best_resize(
- original_size, scale_resolution, patch_size, allow_upscale=True
- )
+ best_size = find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True)
source_image = image.resize(best_size, Image.Resampling.BICUBIC)
else:
candidate_split_grids_nums = []
@@ -498,8 +490,7 @@ def slice_image(
candidate_split_grids_nums.append(i)
# source image, down-sampling and ensure divided by patch_size
- best_resize = find_best_resize(
- original_size, scale_resolution, patch_size)
+ best_resize = find_best_resize(original_size, scale_resolution, patch_size)
source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
candidate_grids = []
@@ -519,9 +510,7 @@ def slice_image(
best_grid = grid
min_error = error
- refine_size = get_refine_size(
- original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
- )
+ refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, allow_upscale=True)
refine_image = image.resize(refine_size, Image.Resampling.BICUBIC)
patches = split_to_patches(refine_image, best_grid)
@@ -544,9 +533,7 @@ def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=
return (best_width, best_height)
-def get_refine_size(
- original_size, grid, scale_resolution, patch_size, allow_upscale=False
-):
+def get_refine_size(original_size, grid, scale_resolution, patch_size, allow_upscale=False):
width, height = original_size
grid_x, grid_y = grid
@@ -587,13 +574,9 @@ def split_to_patches(image, grid):
def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False):
if new_schema:
- image_placeholder = (
- tokenizer.slice_start + tokenizer.unk_token * query_num + tokenizer.slice_end
- )
+ image_placeholder = tokenizer.slice_start + tokenizer.unk_token * query_num + tokenizer.slice_end
else:
- image_placeholder = (
- tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
- )
+ image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
cols = grid[0]
rows = grid[1]
@@ -604,10 +587,9 @@ def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False):
lines.append(image_placeholder)
slices.append("".join(lines))
if new_schema:
- slice_placeholder = '\n'.join(slices)
+ slice_placeholder = "\n".join(slices)
else:
- slice_placeholder = tokenizer.slice_start + \
- "\n".join(slices) + tokenizer.slice_end
+ slice_placeholder = tokenizer.slice_start + "\n".join(slices) + tokenizer.slice_end
return slice_placeholder
@@ -639,9 +621,8 @@ def reshape_by_patch(image_tensor, patch_size):
patches = image_tensor.reshape(c, v_block_num, patch_size, h_block_num, patch_size)
patches = np.transpose(patches, (0, 2, 4, 1, 3))
- patches = patches.reshape(c*patch_size*patch_size, -1)
+ patches = patches.reshape(c * patch_size * patch_size, -1)
patches = patches.reshape(image_tensor.shape[0], patch_size, patch_size, -1)
- patches = patches.transpose((0, 1, 3, 2)).reshape(
- image_tensor.shape[0], patch_size, -1)
+ patches = patches.transpose((0, 1, 3, 2)).reshape(image_tensor.shape[0], patch_size, -1)
return patches
diff --git a/examples/minicpm_v/finetune/finetune.py b/examples/minicpm_v/finetune/finetune.py
index 50140ca723..a40502d30c 100644
--- a/examples/minicpm_v/finetune/finetune.py
+++ b/examples/minicpm_v/finetune/finetune.py
@@ -23,7 +23,9 @@
rank, rank_size = 0, 1
-ms.set_context(mode=ms.context.PYNATIVE_MODE, pynative_synchronize=True, mempool_block_size="59GB", max_device_memory="59GB")
+ms.set_context(
+ mode=ms.context.PYNATIVE_MODE, pynative_synchronize=True, mempool_block_size="59GB", max_device_memory="59GB"
+)
import transformers
from transformers import HfArgumentParser
@@ -36,12 +38,12 @@
sys.path.insert(0, mindone_lib_path)
from dataset import SupervisedDataset
-from mindone.transformers.trainer import Trainer
from transformers import AutoTokenizer
-from mindone.transformers.training_args import TrainingArguments
-from mindone.transformers.models.minicpm_v2_6 import MiniCPMV_v2_6
from mindone.transformers.mindspore_adapter import MindSporeArguments
+from mindone.transformers.models.minicpm_v2_6 import MiniCPMV_v2_6
+from mindone.transformers.trainer import Trainer
+from mindone.transformers.training_args import TrainingArguments
# from transformers.integrations import deepspeed
@@ -51,6 +53,7 @@
# ms.set_context(mode=ms.context.PYNATIVE_MODE, pynative_synchronize=True)
# ms.set_context(mode=ms.context.PYNATIVE_MODE)
+
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="openbmb/MiniCPM-V-2")
@@ -58,12 +61,9 @@ class ModelArguments:
@dataclass
class DataArguments:
- data_path: str = field(
- default=None, metadata={"help": "Path to the training data."}
- )
- eval_data_path: str = field(
- default=None, metadata={"help": "Path to the evaluation data."}
- )
+ data_path: str = field(default=None, metadata={"help": "Path to the training data."})
+ eval_data_path: str = field(default=None, metadata={"help": "Path to the evaluation data."})
+
# @dataclass
# class TrainingArguments(TrainingArguments):
@@ -98,6 +98,7 @@ class LoraArguments:
lora_layers_to_transform: Optional[List[int]] = None
lora_layers_pattern: Optional[str] = None
+
@dataclass
class MyArguments(MindSporeArguments, TrainingArguments):
enable_flash_attention: bool = field(default=False)
@@ -107,9 +108,7 @@ class MyArguments(MindSporeArguments, TrainingArguments):
optim: str = field(default="adamw_mindspore")
model_max_length: int = field(
default=2048,
- metadata={
- "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
- },
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
)
tune_vision: Optional[bool] = field(default=True)
tune_llm: Optional[bool] = field(default=True)
@@ -119,7 +118,10 @@ class MyArguments(MindSporeArguments, TrainingArguments):
distributed: Optional[bool] = field(default=False)
amp_level: Optional[str] = field(default="O0")
+
local_rank = None
+
+
def rank0_print(*args):
if local_rank == 0:
print(*args)
@@ -128,7 +130,10 @@ def rank0_print(*args):
def safe_save_model_for_hf_trainer(trainer, output_dir: str, bias="none"):
"""Collects the state dict and dump to disk."""
if trainer.args.should_save and trainer.args.local_rank == 0:
- trainer.save_model(output_dir,)
+ trainer.save_model(
+ output_dir,
+ )
+
# class ModifiedMapFunction(BaseMapFuction):
# def __call__(self, input_ids, position_ids, labels, attention_mask):
@@ -227,16 +232,16 @@ def make_supervised_data_module(
# ]
# )
+
def build_transform():
- IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
+ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD
return transforms.Compose(
- [
- vision.Normalize(
- mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, is_hwc=False
- ),
- ]
- )
+ [
+ vision.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, is_hwc=False),
+ ]
+ )
+
def get_parameter_number(model):
trainable_params, all_param = 0, 0
@@ -253,7 +258,7 @@ def get_parameter_number(model):
num_params = np.prod(param.shape)
trainable_params += num_params
- return {'Trainable params': trainable_params}
+ return {"Trainable params": trainable_params}
local_rank = 0
@@ -261,9 +266,7 @@ def get_parameter_number(model):
def train():
global local_rank
- parser = HfArgumentParser(
- (ModelArguments, DataArguments, MyArguments, LoraArguments)
- )
+ parser = HfArgumentParser((ModelArguments, DataArguments, MyArguments, LoraArguments))
(
model_args,
@@ -275,11 +278,7 @@ def train():
# if getattr(training_args, "deepspeed", None) :
# training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
- compute_dtype = (
- ms.float16
- if training_args.fp16
- else (ms.bfloat16 if training_args.bf16 else ms.float32)
- )
+ compute_dtype = ms.float16 if training_args.fp16 else (ms.bfloat16 if training_args.bf16 else ms.float32)
# if training_args.distributed:
# init()
@@ -297,9 +296,7 @@ def train():
if lora_args.q_lora:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
if len(training_args.fsdp) > 0:
- logging.warning(
- "FSDP or ZeRO3 are not incompatible with QLoRA."
- )
+ logging.warning("FSDP or ZeRO3 are not incompatible with QLoRA.")
model = MiniCPMV_v2_6.from_pretrained(
model_args.model_name_or_path,
@@ -324,9 +321,7 @@ def train():
# else:
# grad_reducer = None
- tokenizer = AutoTokenizer.from_pretrained(
- model_args.model_name_or_path, trust_remote_code=True
- )
+ tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
if not training_args.tune_vision:
# model.vpm.set_train(False)
@@ -344,9 +339,9 @@ def train():
rank0_print("Currently using LoRA for fine-tuning the MiniCPM-V model.")
for name, param in model.llm.named_parameters():
param.requires_grad = False
- modules_to_save = ['embed_tokens','resampler']
+ modules_to_save = ["embed_tokens", "resampler"]
if training_args.tune_vision:
- modules_to_save.append('vpm')
+ modules_to_save.append("vpm")
lora_config = LoraConfig(
r=lora_args.lora_r,
lora_alpha=lora_args.lora_alpha,
@@ -356,9 +351,11 @@ def train():
layers_to_transform=lora_args.lora_layers_to_transform,
modules_to_save=modules_to_save,
)
- if not hasattr(model, 'get_input_embeddings'):
+ if not hasattr(model, "get_input_embeddings"):
+
def get_input_embeddings(self):
return self.llm.get_input_embeddings()
+
model.get_input_embeddings = MethodType(get_input_embeddings, model)
if lora_args.q_lora:
model = prepare_model_for_kbit_training(
@@ -372,8 +369,7 @@ def get_input_embeddings(self):
llm_type = training_args.llm_type
- rank0_print(f'llm_type={llm_type}')
-
+ rank0_print(f"llm_type={llm_type}")
# Load data
if hasattr(model.config, "slice_config"):
@@ -402,7 +398,7 @@ def get_input_embeddings(self):
max_length=training_args.model_max_length,
)
- training_args.gradient_checkpointing_kwargs={"use_reentrant":False}
+ training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
trainer = Trainer(
model=model,
tokenizer=tokenizer,
@@ -413,10 +409,7 @@ def get_input_embeddings(self):
trainer.train()
# trainer.save_state()
- safe_save_model_for_hf_trainer(
- trainer=trainer,
- output_dir=training_args.output_dir,
- bias=lora_args.lora_bias)
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias)
if __name__ == "__main__":
diff --git a/mindone/transformers/models/minicpm_v/__init__.py b/mindone/transformers/models/minicpm_v/__init__.py
index d9273d2cf8..3d30bf0084 100644
--- a/mindone/transformers/models/minicpm_v/__init__.py
+++ b/mindone/transformers/models/minicpm_v/__init__.py
@@ -1,2 +1,2 @@
-from .modeling_minicpmv import MiniCPMV_v2_6
from .image_processing_minicpmv import MiniCPMVImageProcessor
+from .modeling_minicpmv import MiniCPMV_v2_6
diff --git a/mindone/transformers/models/minicpm_v/configuration_minicpm.py b/mindone/transformers/models/minicpm_v/configuration_minicpm.py
index 063cfee91b..db1f383fc0 100644
--- a/mindone/transformers/models/minicpm_v/configuration_minicpm.py
+++ b/mindone/transformers/models/minicpm_v/configuration_minicpm.py
@@ -45,7 +45,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
return cls.from_dict(config_dict, **kwargs)
-
class MiniCPMVConfig(Qwen2Config):
model_type = "minicpmv"
keys_to_ignore_at_inference = ["past_key_values"]
@@ -58,7 +57,7 @@ class MiniCPMVConfig(Qwen2Config):
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
- "attn_implementation": "flash_attention"
+ "attn_implementation": "flash_attention",
}
def __init__(
diff --git a/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py b/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
index 7626964f1b..1aad29b744 100644
--- a/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
+++ b/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
@@ -5,7 +5,6 @@
import PIL
import PIL.Image
import PIL.ImageSequence
-from ...image_processing_utils import BaseImageProcessor, BatchFeature
from PIL import Image
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import (
@@ -23,6 +22,8 @@
import mindspore as ms
from mindspore import Parameter, Tensor, nn, ops
+from ...image_processing_utils import BaseImageProcessor, BatchFeature
+
def recursive_converter(converter, value):
if isinstance(value, list):
@@ -104,12 +105,7 @@ def cast_tensor(v):
class MiniCPMVImageProcessor(BaseImageProcessor):
model_input_names = ["pixel_values"]
- def __init__(
- self,
- max_slice_nums=9,
- scale_resolution=448,
- patch_size=14,
- **kwargs):
+ def __init__(self, max_slice_nums=9, scale_resolution=448, patch_size=14, **kwargs):
super().__init__(**kwargs)
self.max_slice_nums = max_slice_nums
self.scale_resolution = scale_resolution
@@ -131,14 +127,9 @@ def __init__(
def ensure_divide(self, length, patch_size):
return max(round(length / patch_size) * patch_size, patch_size)
- def find_best_resize(self,
- original_size,
- scale_resolution,
- patch_size,
- allow_upscale=False):
+ def find_best_resize(self, original_size, scale_resolution, patch_size, allow_upscale=False):
width, height = original_size
- if (width * height >
- scale_resolution * scale_resolution) or allow_upscale:
+ if (width * height > scale_resolution * scale_resolution) or allow_upscale:
r = width / height
height = int(scale_resolution / math.sqrt(r))
width = int(height * r)
@@ -146,12 +137,7 @@ def find_best_resize(self,
best_height = self.ensure_divide(height, patch_size)
return (best_width, best_height)
- def get_refine_size(self,
- original_size,
- grid,
- scale_resolution,
- patch_size,
- allow_upscale=False):
+ def get_refine_size(self, original_size, grid, scale_resolution, patch_size, allow_upscale=False):
width, height = original_size
grid_x, grid_y = grid
@@ -161,10 +147,9 @@ def get_refine_size(self,
grid_width = refine_width / grid_x
grid_height = refine_height / grid_y
- best_grid_size = self.find_best_resize((grid_width, grid_height),
- scale_resolution,
- patch_size,
- allow_upscale=allow_upscale)
+ best_grid_size = self.find_best_resize(
+ (grid_width, grid_height), scale_resolution, patch_size, allow_upscale=allow_upscale
+ )
refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
return refine_size
@@ -182,9 +167,7 @@ def split_to_patches(self, image, grid):
patches.append(images)
return patches
- def slice_image(
- self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
- ):
+ def slice_image(self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False):
original_size = image.size
source_image = None
best_grid = self.get_sliced_grid(original_size, max_slice_nums, never_split)
@@ -192,9 +175,7 @@ def slice_image(
if best_grid is None:
# dont need to slice, upsample
- best_size = self.find_best_resize(
- original_size, scale_resolution, patch_size, allow_upscale=True
- )
+ best_size = self.find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True)
source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC)
else:
# source image, down-sampling and ensure divided by patch_size
@@ -212,9 +193,7 @@ def get_grid_placeholder(self, grid):
if grid is None:
return ""
slice_image_placeholder = (
- self.slice_start_token
- + self.unk_token * self.image_feature_size
- + self.slice_end_token
+ self.slice_start_token + self.unk_token * self.image_feature_size + self.slice_end_token
)
cols = grid[0]
@@ -241,10 +220,7 @@ def get_sliced_images(self, image, max_slice_nums=None):
max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
assert max_slice_nums > 0
source_image, patches, sliced_grid = self.slice_image(
- image,
- max_slice_nums, # default: 9
- self.scale_resolution, # default: 448
- self.patch_size # default: 14
+ image, max_slice_nums, self.scale_resolution, self.patch_size # default: 9 # default: 448 # default: 14
)
slice_images.append(source_image)
@@ -290,11 +266,7 @@ def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=No
assert max_slice_nums > 0
grid = self.get_sliced_grid(image_size=image_size, max_slice_nums=max_slice_nums)
- image_placeholder = (
- self.im_start_token
- + self.unk_token * self.image_feature_size
- + self.im_end_token
- )
+ image_placeholder = self.im_start_token + self.unk_token * self.image_feature_size + self.im_end_token
use_image_id = self.use_image_id if use_image_id is None else bool(use_image_id)
if use_image_id:
final_placeholder = self.get_image_id_placeholder(image_idx) + image_placeholder
@@ -349,11 +321,7 @@ def reshape_by_patch(self, image):
w = image.shape[2]
image = image.reshape(1, c, h, w)
- patches = ops.unfold(
- image,
- (patch_size, patch_size),
- stride=(patch_size, patch_size)
- )
+ patches = ops.unfold(image, (patch_size, patch_size), stride=(patch_size, patch_size))
image = image.squeeze(axis=0)
@@ -362,12 +330,12 @@ def reshape_by_patch(self, image):
return patches.numpy()
def preprocess(
- self,
- images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
- do_pad: Optional[bool] = True, # TODO: add pad for MiniCPM-Llama3-V-2_5
- max_slice_nums: int = None,
- return_tensors: Optional[Union[str, TensorType]] = None,
- **kwargs
+ self,
+ images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
+ do_pad: Optional[bool] = True, # TODO: add pad for MiniCPM-Llama3-V-2_5
+ max_slice_nums: int = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs,
) -> MiniCPMVBatchFeature:
if isinstance(images, Image.Image):
images_list = [[images]]
@@ -412,7 +380,8 @@ def preprocess(
for slice_image in image_patches:
new_images.append(self.reshape_by_patch(slice_image))
tgt_sizes.append(
- np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size)))
+ np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size))
+ )
if tgt_sizes:
tgt_sizes = np.vstack(tgt_sizes)
@@ -422,7 +391,7 @@ def preprocess(
tgt_sizes_list.append(tgt_sizes)
return MiniCPMVBatchFeature(
data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list},
- tensor_type=return_tensors
+ tensor_type=return_tensors,
)
diff --git a/mindone/transformers/models/minicpm_v/modeling_minicpmv.py b/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
index 003333ddcd..71cd628779 100644
--- a/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
+++ b/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
@@ -4,21 +4,21 @@
from threading import Thread
from typing import List, Optional
-from transformers import TextIteratorStreamer
from PIL import Image
+from transformers import TextIteratorStreamer
import mindspore as ms
-from mindspore import Parameter, Tensor, nn, ops
+from mindspore import Parameter, Tensor, _no_grad, nn, ops
from ..qwen2 import Qwen2ForCausalLM, Qwen2PreTrainedModel
from .configuration_minicpm import MiniCPMVConfig
+from .image_processing_minicpmv import MiniCPMVImageProcessor
from .modeling_navit_siglip import SiglipVisionTransformer
from .processing_minicpmv import MiniCPMVProcessor
-from .image_processing_minicpmv import MiniCPMVImageProcessor
from .resampler import Resampler
+
# from .tokenization_minicpmv_fast import MiniCPMVTokenizerFast
-from mindspore import _no_grad
class MiniCPMVPreTrainedModel(Qwen2PreTrainedModel):
config_class = MiniCPMVConfig
@@ -34,21 +34,21 @@ def __init__(self, config):
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
self.processor = None
- self.terminators = ['<|im_end|>', '<|endoftext|>']
+ self.terminators = ["<|im_end|>", "<|endoftext|>"]
def init_vision_module(self):
# same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes
- if self.config._attn_implementation == 'flash_attention_2':
- self.config.vision_config._attn_implementation = 'flash_attention_2'
+ if self.config._attn_implementation == "flash_attention_2":
+ self.config.vision_config._attn_implementation = "flash_attention_2"
else:
# not suport sdpa
- self.config.vision_config._attn_implementation = 'eager'
+ self.config.vision_config._attn_implementation = "eager"
model = SiglipVisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
- setattr(model, 'embed_dim', model.embeddings.embed_dim)
- setattr(model, 'patch_size', model.embeddings.patch_size)
+ setattr(model, "embed_dim", model.embeddings.embed_dim)
+ setattr(model, "patch_size", model.embeddings.patch_size)
return model
@@ -58,7 +58,7 @@ def init_resampler(self, embed_dim, vision_dim):
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
- adaptive=True
+ adaptive=True,
)
def get_input_embeddings(self):
@@ -80,11 +80,11 @@ def get_decoder(self):
return self.llm
def get_vllm_embedding(self, data):
- if 'vision_hidden_states' not in data:
+ if "vision_hidden_states" not in data:
dtype = self.llm.model.embed_tokens.embedding_table.dtype
device = None
- tgt_sizes = data['tgt_sizes']
- pixel_values_list = data['pixel_values']
+ tgt_sizes = data["tgt_sizes"]
+ pixel_values_list = data["pixel_values"]
vision_hidden_states = []
all_pixel_values = []
img_cnt = []
@@ -107,7 +107,16 @@ def get_vllm_embedding(self, data):
max_length_w = max([i.shape[1] for i in all_pixel_values])
for i in range(len(all_pixel_values)):
if all_pixel_values[i].shape[0] < max_length_h or all_pixel_values[i].shape[1] < max_length_w:
- all_pixel_values[i] = ops.pad(all_pixel_values[i], (0, max_length_w - all_pixel_values[i].shape[1], 0, max_length_h - all_pixel_values[i].shape[0]), value=0.0)
+ all_pixel_values[i] = ops.pad(
+ all_pixel_values[i],
+ (
+ 0,
+ max_length_w - all_pixel_values[i].shape[1],
+ 0,
+ max_length_h - all_pixel_values[i].shape[0],
+ ),
+ value=0.0,
+ )
all_pixel_values = ops.stack(all_pixel_values)
B, L, _ = all_pixel_values.shape
@@ -115,7 +124,7 @@ def get_vllm_embedding(self, data):
patch_attn_mask = ops.zeros(Tensor((B, 1, int(max_patches))), dtype=ms.bool_)
for i in range(B):
- patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+ patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
vision_batch_size = self.config.vision_batch_size
all_pixel_values = all_pixel_values.astype(dtype)
@@ -124,28 +133,33 @@ def get_vllm_embedding(self, data):
for i in range(0, B, vision_batch_size):
start_idx = i
end_idx = i + vision_batch_size
- tmp_hs = self.vpm(all_pixel_values[start_idx:end_idx], patch_attention_mask=patch_attn_mask[start_idx:end_idx], tgt_sizes=tgt_sizes[start_idx:end_idx]).last_hidden_state
+ tmp_hs = self.vpm(
+ all_pixel_values[start_idx:end_idx],
+ patch_attention_mask=patch_attn_mask[start_idx:end_idx],
+ tgt_sizes=tgt_sizes[start_idx:end_idx],
+ ).last_hidden_state
hs.append(tmp_hs)
vision_embedding = ops.cat(hs, axis=0)
else:
- vision_embedding = self.vpm(all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes).last_hidden_state
+ vision_embedding = self.vpm(
+ all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes
+ ).last_hidden_state
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
start = 0
for pixel_values in pixel_values_list:
img_cnt = len(pixel_values)
if img_cnt > 0:
- vision_hidden_states.append(vision_embedding[start: start + img_cnt])
+ vision_hidden_states.append(vision_embedding[start : start + img_cnt])
start += img_cnt
else:
vision_hidden_states.append([])
- else: # no image
+ else: # no image
if self.training:
- dummy_image = ops.zeros(
- (1, 3, 224, 224),
- dtype=dtype
- )
- tgt_sizes = ms.Tensor([[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]).astype(ms.int32)
+ dummy_image = ops.zeros((1, 3, 224, 224), dtype=dtype)
+ tgt_sizes = ms.Tensor(
+ [[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]
+ ).astype(ms.int32)
dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
else:
dummy_feature = []
@@ -153,15 +167,16 @@ def get_vllm_embedding(self, data):
vision_hidden_states.append(dummy_feature)
else:
- vision_hidden_states = data['vision_hidden_states']
+ vision_hidden_states = data["vision_hidden_states"]
- if hasattr(self.llm.config, 'scale_emb'):
- vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
+ if hasattr(self.llm.config, "scale_emb"):
+ vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb
else:
- vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])
+ vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
- vision_hidden_states = [i.astype(vllm_embedding.dtype) if isinstance(
- i, ms.Tensor) else i for i in vision_hidden_states]
+ vision_hidden_states = [
+ i.astype(vllm_embedding.dtype) if isinstance(i, ms.Tensor) else i for i in vision_hidden_states
+ ]
# bs = len(data['input_ids'])
# for i in range(bs):
@@ -189,11 +204,7 @@ def construct(self, data, **kwargs):
with _no_grad():
return self.llm(
- input_ids=None,
- position_ids=position_ids,
- inputs_embeds=vllm_embedding,
- labels=data["labels"],
- **kwargs
+ input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, labels=data["labels"], **kwargs
)
def _decode(self, inputs_embeds, tokenizer, attention_mask, decode_text=False, **kwargs):
@@ -203,7 +214,7 @@ def _decode(self, inputs_embeds, tokenizer, attention_mask, decode_text=False, *
pad_token_id=0,
eos_token_id=terminators,
attention_mask=attention_mask,
- **kwargs
+ **kwargs,
)
if decode_text:
return self._decode_text(output, tokenizer)
@@ -213,10 +224,10 @@ def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
streamer = TextIteratorStreamer(tokenizer=tokenizer)
generation_kwargs = {
- 'inputs_embeds': inputs_embeds,
- 'pad_token_id': 0,
- 'eos_token_id': terminators,
- 'streamer': streamer
+ "inputs_embeds": inputs_embeds,
+ "pad_token_id": 0,
+ "eos_token_id": terminators,
+ "streamer": streamer,
}
generation_kwargs.update(kwargs)
@@ -249,7 +260,7 @@ def generate(
return_vision_hidden_states=False,
stream=False,
decode_text=False,
- **kwargs
+ **kwargs,
):
assert input_ids is not None
assert len(input_ids) == len(pixel_values)
@@ -261,7 +272,7 @@ def generate(
if vision_hidden_states is None:
model_inputs["pixel_values"] = pixel_values
- model_inputs['tgt_sizes'] = tgt_sizes
+ model_inputs["tgt_sizes"] = tgt_sizes
else:
model_inputs["vision_hidden_states"] = vision_hidden_states
@@ -274,7 +285,9 @@ def generate(
if stream:
result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
else:
- result = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, decode_text=decode_text, **kwargs)
+ result = self._decode(
+ model_inputs["inputs_embeds"], tokenizer, attention_mask, decode_text=decode_text, **kwargs
+ )
if return_vision_hidden_states:
return result, vision_hidden_states
@@ -292,11 +305,11 @@ def chat(
min_new_tokens=0,
sampling=True,
max_inp_length=8192,
- system_prompt='',
+ system_prompt="",
stream=False,
max_slice_nums=None,
use_image_id=None,
- **kwargs
+ **kwargs,
):
if isinstance(msgs[0], list):
batched = True
@@ -319,11 +332,21 @@ def chat(
self.processor = MiniCPMVProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
processor = self.processor
- assert self.config.query_num == processor.image_processor.image_feature_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
- assert self.config.patch_size == processor.image_processor.patch_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
- assert self.config.use_image_id == processor.image_processor.use_image_id, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
- assert self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
- assert self.config.slice_mode == processor.image_processor.slice_mode, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert (
+ self.config.query_num == processor.image_processor.image_feature_size
+ ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert (
+ self.config.patch_size == processor.image_processor.patch_size
+ ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert (
+ self.config.use_image_id == processor.image_processor.use_image_id
+ ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert (
+ self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums
+ ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+ assert (
+ self.config.slice_mode == processor.image_processor.slice_mode
+ ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
prompts_lists = []
input_images_lists = []
@@ -357,10 +380,12 @@ def chat(
msg["content"] = "\n".join(cur_msgs)
if system_prompt:
- sys_msg = {'role': 'system', 'content': system_prompt}
+ sys_msg = {"role": "system", "content": system_prompt}
copy_msgs = [sys_msg] + copy_msgs
- prompts_lists.append(processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True))
+ prompts_lists.append(
+ processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True)
+ )
input_images_lists.append(images)
inputs = processor(
@@ -370,7 +395,7 @@ def chat(
use_image_id=use_image_id,
return_tensors="ms",
max_length=max_inp_length,
- image_processor=image_processor
+ image_processor=image_processor,
)
if sampling:
@@ -379,7 +404,7 @@ def chat(
"top_k": 100,
"temperature": 0.7,
"do_sample": True,
- "repetition_penalty": 1.05
+ "repetition_penalty": 1.05,
}
else:
generation_config = {
@@ -388,11 +413,9 @@ def chat(
}
if min_new_tokens > 0:
- generation_config['min_new_tokens'] = min_new_tokens
+ generation_config["min_new_tokens"] = min_new_tokens
- generation_config.update(
- (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
- )
+ generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
inputs.pop("image_sizes")
# with torch.inference_mode():
@@ -403,15 +426,17 @@ def chat(
vision_hidden_states=vision_hidden_states,
stream=stream,
decode_text=True,
- **generation_config
+ **generation_config,
)
if stream:
+
def stream_gen():
for text in res:
for term in self.terminators:
- text = text.replace(term, '')
+ text = text.replace(term, "")
yield text
+
return stream_gen()
else:
diff --git a/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py b/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
index 99d2f0ab49..5c0a147f8c 100644
--- a/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
+++ b/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
@@ -28,22 +28,21 @@
import mindspore as ms
-# import torch.utils.checkpoint
from mindspore import Parameter, Tensor, nn, ops
from mindspore.ops.operations.nn_ops import FlashAttentionScore as FlashAttention
from ...activations import ACT2FN
+from ...mindspore_adapter import recompute_except_output
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
from ...modeling_utils import MSPreTrainedModel
-from ...mindspore_adapter import recompute_except_output
-
# from torch.nn.init import _calculate_fan_in_and_fan_out
logger = logging.get_logger(__name__)
+
class SiglipVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
@@ -231,13 +230,16 @@ def trunc_normal_tf_(
_trunc_normal_(tensor, 0, 1.0, a, b)
tensor.mul_(std).add_(mean)
+
def _calculate_fan_in_and_fan_out(arr):
# 计算fan_in和fan_out。fan_in是 `arr` 中输入单元的数量,fan_out是 `arr` 中输出单元的数量。
shape = arr.shape
dimensions = len(shape)
if dimensions < 2:
- raise ValueError("'fan_in' and 'fan_out' can not be computed for arr with fewer than"
- " 2 dimensions, but got dimensions {}.".format(dimensions))
+ raise ValueError(
+ "'fan_in' and 'fan_out' can not be computed for arr with fewer than"
+ " 2 dimensions, but got dimensions {}.".format(dimensions)
+ )
if dimensions == 2: # Linear
fan_in = shape[1]
fan_out = shape[0]
@@ -251,6 +253,7 @@ def _calculate_fan_in_and_fan_out(arr):
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
+
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in":
@@ -335,7 +338,9 @@ def __init__(self, config: SiglipVisionConfig):
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
- def construct(self, pixel_values: ms.Tensor, patch_attention_mask: ms.Tensor, tgt_sizes: Optional[ms.Tensor]=None) -> ms.Tensor:
+ def construct(
+ self, pixel_values: ms.Tensor, patch_attention_mask: ms.Tensor, tgt_sizes: Optional[ms.Tensor] = None
+ ) -> ms.Tensor:
batch_size = pixel_values.shape[0]
patch_embeds = self.patch_embedding(pixel_values)
@@ -390,7 +395,7 @@ def __init__(self, config):
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
- self.scale = self.head_dim ** -0.5
+ self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Dense(self.embed_dim, self.embed_dim)
@@ -418,8 +423,8 @@ def construct(
k_v_seq_len = key_states.shape[-2]
- query_states = ops.mul(query_states, self.scale ** 0.5)
- key_states = ops.mul(key_states, self.scale ** 0.5)
+ query_states = ops.mul(query_states, self.scale**0.5)
+ key_states = ops.mul(key_states, self.scale**0.5)
attn_weights = ops.matmul(query_states, key_states.swapaxes(2, 3))
@@ -623,9 +628,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
- cu_seqlens_q = ops.arange(
- batch_size + 1, dtype=ms.int32
- ) # There is a memcpy here, that is very bad.
+ cu_seqlens_q = ops.arange(batch_size + 1, dtype=ms.int32) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
@@ -642,22 +645,21 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
+
class SiglipFlashAttention(SiglipAttention):
"""
Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
+
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False # Hack to make sure we don't use a causal mask
dropout_rate = self.dropout if self.training else 0.0
self.flash_attention = FlashAttention(
- scale_value=self.head_dim**-0.5,
- head_num=self.head_dim,
- input_layout="BSH",
- keep_prob=1-dropout_rate
+ scale_value=self.head_dim**-0.5, head_num=self.head_dim, input_layout="BSH", keep_prob=1 - dropout_rate
)
def construct(
@@ -731,9 +733,7 @@ def construct(
value_states = value_states.to(target_dtype)
# implement flash attention
- attn_output = self.flash_attention(
- query_states, key_states, value_states, None, None, None, attention_mask
- )[3]
+ attn_output = self.flash_attention(query_states, key_states, value_states, None, None, None, attention_mask)[3]
# attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
@@ -743,6 +743,7 @@ def construct(
return attn_output, attn_weights
+
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Cell):
def __init__(self, config):
@@ -766,11 +767,7 @@ def __init__(self, config: SiglipVisionConfig):
self.embed_dim = config.hidden_size
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_flash_attention = config._attn_implementation == "flash_attention"
- self.self_attn = (
- SiglipAttention(config)
- if not self._use_flash_attention
- else SiglipFlashAttention(config)
- )
+ self.self_attn = SiglipAttention(config) if not self._use_flash_attention else SiglipFlashAttention(config)
self.layer_norm1 = nn.LayerNorm((self.embed_dim,), epsilon=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
self.layer_norm2 = nn.LayerNorm((self.embed_dim,), epsilon=config.layer_norm_eps)
@@ -980,9 +977,8 @@ def construct(
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
- return BaseModelOutput(
- last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
- )
+ return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions)
+
class SiglipVisionTransformer(SiglipPreTrainedModel):
config_class = SiglipVisionConfig
@@ -1006,11 +1002,9 @@ def __init__(self, config: SiglipVisionConfig):
# recompute
# self.encoder.recompute()
-
def get_input_embeddings(self) -> nn.Cell:
return self.embeddings.patch_embedding
-
def construct(
self,
pixel_values,
@@ -1040,14 +1034,16 @@ def construct(
dtype=ms.bool_,
)
- hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes)
+ hidden_states = self.embeddings(
+ pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes
+ )
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
# The call to `_upad_input` in `_flash_attention_forward` is expensive
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
if not ops.any(~patch_attention_mask):
- attention_mask=None
+ attention_mask = None
else:
attention_mask = (
_prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
diff --git a/mindone/transformers/models/minicpm_v/processing_minicpmv.py b/mindone/transformers/models/minicpm_v/processing_minicpmv.py
index d9da9f7882..db4ad74e9d 100644
--- a/mindone/transformers/models/minicpm_v/processing_minicpmv.py
+++ b/mindone/transformers/models/minicpm_v/processing_minicpmv.py
@@ -20,9 +20,9 @@
from typing import Any, Dict, List, Optional, Union
import numpy as np
-from transformers.utils import TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+from transformers.utils import TensorType
import mindspore as ms
from mindspore import Parameter, Tensor, nn, ops
@@ -30,6 +30,7 @@
from ...processing_utils import ProcessorMixin
from .image_processing_minicpmv import MiniCPMVBatchFeature, MiniCPMVImageProcessor
+
class MiniCPMVProcessor(ProcessorMixin):
r"""
Constructs a MiniCPMV processor which wraps a MiniCPMV image processor and a MiniCPMV tokenizer into a single processor.
@@ -61,12 +62,21 @@ def __call__(
use_image_id: bool = None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
image_processor=None,
- **kwargs
+ **kwargs,
) -> MiniCPMVBatchFeature:
-
if images is not None:
- image_inputs = image_processor.preprocess(images=images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors)
- return self._convert_images_texts_to_inputs(image_inputs, text, max_slice_nums=max_slice_nums, use_image_id=use_image_id, max_length=max_length, image_processor=image_processor, **kwargs)
+ image_inputs = image_processor.preprocess(
+ images=images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors
+ )
+ return self._convert_images_texts_to_inputs(
+ image_inputs,
+ text,
+ max_slice_nums=max_slice_nums,
+ use_image_id=use_image_id,
+ max_length=max_length,
+ image_processor=image_processor,
+ **kwargs,
+ )
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
@@ -96,13 +106,13 @@ def decode(self, *args, **kwargs):
result = result[result != 0]
if result[0] == self.tokenizer.bos_id:
result = result[1:]
- if result[-1] == self.tokenizer.eos_id or (hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id):
+ if result[-1] == self.tokenizer.eos_id or (
+ hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id
+ ):
result = result[:-1]
return self.tokenizer.decode(result, *args[1:], **kwargs).strip()
- def _convert(
- self, input_str, max_inp_length: Optional[int] = None
- ):
+ def _convert(self, input_str, max_inp_length: Optional[int] = None):
if self.version > 2.5 or not getattr(self.tokenizer, "add_bos_token", False):
input_ids = self.tokenizer.encode(input_str)
else:
@@ -140,19 +150,21 @@ def _convert(
return input_ids, image_bounds
def _convert_images_texts_to_inputs(
- self,
- images,
- texts: Union[str, List[str]],
- truncation=None,
- max_length=None,
- max_slice_nums=None,
- use_image_id=None,
- return_tensors=None,
- image_processor=None,
- **kwargs
- ):
+ self,
+ images,
+ texts: Union[str, List[str]],
+ truncation=None,
+ max_length=None,
+ max_slice_nums=None,
+ use_image_id=None,
+ return_tensors=None,
+ image_processor=None,
+ **kwargs,
+ ):
if images is None or not len(images):
- model_inputs = self.tokenizer(texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs)
+ model_inputs = self.tokenizer(
+ texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs
+ )
return MiniCPMVBatchFeature(data={**model_inputs})
pattern = "(./)"
@@ -168,33 +180,32 @@ def _convert_images_texts_to_inputs(
text_chunks = text.split(pattern)
final_text = ""
for i in range(len(image_tags)):
- final_text = final_text + text_chunks[i] + \
- image_processor.get_slice_image_placeholder(
- image_sizes[index][i],
- i,
- max_slice_nums,
- use_image_id
+ final_text = (
+ final_text
+ + text_chunks[i]
+ + image_processor.get_slice_image_placeholder(
+ image_sizes[index][i], i, max_slice_nums, use_image_id
)
+ )
final_text += text_chunks[-1]
input_ids, image_bounds = self._convert(final_text, max_length)
input_ids_list.append(input_ids)
image_bounds_list.append(image_bounds)
- padded_input_ids, padding_lengths = self.pad(
- input_ids_list,
- padding_side="left"
- )
+ padded_input_ids, padding_lengths = self.pad(input_ids_list, padding_side="left")
for i, length in enumerate(padding_lengths):
image_bounds_list[i] = image_bounds_list[i] + length
attention_mask = padded_input_ids.ne(0)
- return MiniCPMVBatchFeature(data={
- "input_ids": padded_input_ids,
- "attention_mask": attention_mask,
- "pixel_values": images,
- "image_sizes": image_sizes,
- "image_bound": image_bounds_list,
- "tgt_sizes": tgt_sizes
- })
+ return MiniCPMVBatchFeature(
+ data={
+ "input_ids": padded_input_ids,
+ "attention_mask": attention_mask,
+ "pixel_values": images,
+ "image_sizes": image_sizes,
+ "image_bound": image_bounds_list,
+ "tgt_sizes": tgt_sizes,
+ }
+ )
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
@@ -203,7 +214,6 @@ def model_input_names(self):
image_processor_input_names = MiniCPMVImageProcessor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
-
def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
items = []
if isinstance(inputs[0], list):
@@ -232,10 +242,7 @@ def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
return ops.stack([item for item in items], axis=0), [0] * batch_size
tensor = ops.zeros((batch_size, max_length), dtype=dtype) + padding_value
else:
- tensor = (
- ops.zeros((batch_size, max_length, shape[-1]), dtype=dtype)
- + padding_value
- )
+ tensor = ops.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
padding_length = []
for i, item in enumerate(items):
diff --git a/mindone/transformers/models/minicpm_v/resampler.py b/mindone/transformers/models/minicpm_v/resampler.py
index 7d0d4222f3..a9ce0728c0 100644
--- a/mindone/transformers/models/minicpm_v/resampler.py
+++ b/mindone/transformers/models/minicpm_v/resampler.py
@@ -56,10 +56,10 @@ def get_1d_sincos_pos_embed_from_grid_new(embed_dim, pos):
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
- omega /= embed_dim / 2.
- omega = 1. / 10000 ** omega # (D/2,)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
- out = np.einsum('hw,d->hwd', pos, omega) # (H, W, D/2), outer product
+ out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
emb_sin = np.sin(out) # (H, W, D/2)
emb_cos = np.cos(out) # (H, W, D/2)
@@ -77,14 +77,14 @@ class Resampler(nn.Cell):
"""
def __init__(
- self,
- num_queries,
- embed_dim,
- num_heads,
- kv_dim=None,
- norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
- adaptive=False,
- max_size=(70, 70),
+ self,
+ num_queries,
+ embed_dim,
+ num_heads,
+ kv_dim=None,
+ norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
+ adaptive=False,
+ max_size=(70, 70),
):
super().__init__()
self.num_queries = num_queries
@@ -105,7 +105,7 @@ def __init__(
self.ln_kv = norm_layer((embed_dim,))
self.ln_post = norm_layer((embed_dim,))
- self.proj = Parameter((embed_dim ** -0.5) * ops.randn(embed_dim, embed_dim))
+ self.proj = Parameter((embed_dim**-0.5) * ops.randn(embed_dim, embed_dim))
self._set_2d_pos_cache(self.max_size)
@@ -125,7 +125,7 @@ def _adjust_pos_cache(self, tgt_sizes):
def _init_weights(self, m):
if isinstance(m, nn.Dense):
- trunc_normal_(m.weight, std=.02)
+ trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Dense) and m.bias is not None:
Zero(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
@@ -150,8 +150,9 @@ def construct(self, x, tgt_sizes=None):
tgt_h, tgt_w = tgt_sizes[i]
shape_0 = tgt_h * tgt_w
pos_embed.append(
- self.pos_embed[:tgt_h, :tgt_w, :].reshape((int(shape_0.asnumpy()), -1)).to(dtype)) # patches * D
- key_padding_mask[i, patch_len[i]:] = True
+ self.pos_embed[:tgt_h, :tgt_w, :].reshape((int(shape_0.asnumpy()), -1)).to(dtype)
+ ) # patches * D
+ key_padding_mask[i, patch_len[i] :] = True
# FIXME how to replace torch.nn.utils.rnn.pad_sequence
# pos_embed = torch.nn.utils.rnn.pad_sequence(
@@ -160,9 +161,11 @@ def construct(self, x, tgt_sizes=None):
max_length_w = max([i.shape[1] for i in pos_embed])
for i in range(len(pos_embed)):
if pos_embed[i].shape[0] < max_length_h or pos_embed[i].shape[1] < max_length_w:
- pos_embed[i] = ops.pad(pos_embed[i], (
- 0, max_length_w - pos_embed[i].shape[1], 0, max_length_h - pos_embed[i].shape[0]),
- value=0.0)
+ pos_embed[i] = ops.pad(
+ pos_embed[i],
+ (0, max_length_w - pos_embed[i].shape[1], 0, max_length_h - pos_embed[i].shape[0]),
+ value=0.0,
+ )
pos_embed = ops.stack(pos_embed)
pos_embed = pos_embed.permute(1, 0, 2)
@@ -175,7 +178,8 @@ def construct(self, x, tgt_sizes=None):
q.unsqueeze(1).tile((1, bs, 1)), # Q * B * D
x + pos_embed, # L * B * D + L * B * D
x,
- key_padding_mask=key_padding_mask)[0]
+ key_padding_mask=key_padding_mask,
+ )[0]
# out: Q * B * D
x = out.permute(1, 0, 2) # B * Q * D
@@ -188,25 +192,38 @@ def _repeat(self, query, N: int):
class MultiheadAttention(nn.MultiheadAttention):
- def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False,
- add_zero_attn=False, kdim=None, vdim=None, batch_first=False, dtype=None):
- super().__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first,
- dtype)
+ def __init__(
+ self,
+ embed_dim,
+ num_heads,
+ dropout=0.0,
+ bias=True,
+ add_bias_kv=False,
+ add_zero_attn=False,
+ kdim=None,
+ vdim=None,
+ batch_first=False,
+ dtype=None,
+ ):
+ super().__init__(
+ embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, dtype
+ )
# rewrite out_proj layer,with nn.Linear
self.out_proj = nn.Dense(embed_dim, embed_dim, has_bias=bias)
def construct(
- self,
- query: Tensor,
- key: Tensor,
- value: Tensor,
- key_padding_mask: Optional[Tensor] = None,
- need_weights: bool = True,
- attn_mask: Optional[Tensor] = None,
- average_attn_weights: bool = True,
- is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
- why_not_fast_path = ''
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ average_attn_weights: bool = True,
+ is_causal: bool = False,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ why_not_fast_path = ""
# if ((attn_mask is not None and torch.is_floating_point(attn_mask))
# or (key_padding_mask is not None) and torch.is_floating_point(key_padding_mask)):
# why_not_fast_path = "floating-point masks are not supported for fast path."
@@ -218,7 +235,7 @@ def construct(
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
- target_type=query.dtype
+ target_type=query.dtype,
)
attn_mask = _canonical_mask(
@@ -238,12 +255,16 @@ def construct(
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
- why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
+ why_not_fast_path = (
+ f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
+ )
elif self.in_proj_weight is None:
why_not_fast_path = "in_proj_weight was None"
elif query.dtype != self.in_proj_weight.dtype:
# this case will fail anyway, but at least they'll get a useful error message.
- why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
+ why_not_fast_path = (
+ f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
+ )
elif self.training:
why_not_fast_path = "training is enabled"
elif (self.num_heads % 2) != 0:
@@ -304,7 +325,8 @@ def construct(
merged_mask,
need_weights,
average_attn_weights,
- mask_type)
+ mask_type,
+ )
# any_nested = query.is_nested or key.is_nested or value.is_nested
# assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
@@ -323,62 +345,84 @@ def construct(
if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = self.multi_head_attention_forward(
- query, key, value, self.embed_dim, self.num_heads,
- self.in_proj_weight, self.in_proj_bias,
- self.bias_k, self.bias_v, self.add_zero_attn,
- self.dropout, self.out_proj.weight, self.out_proj.bias,
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias,
+ self.bias_k,
+ self.bias_v,
+ self.add_zero_attn,
+ self.dropout,
+ self.out_proj.weight,
+ self.out_proj.bias,
training=self.training,
- key_padding_mask=key_padding_mask, need_weights=need_weights,
+ key_padding_mask=key_padding_mask,
+ need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=True,
- q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
+ q_proj_weight=self.q_proj_weight,
+ k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,
- is_causal=is_causal)
+ is_causal=is_causal,
+ )
else:
attn_output, attn_output_weights = self.multi_head_attention_forward(
- query, key, value, self.embed_dim, self.num_heads,
- self.in_proj_weight, self.in_proj_bias,
- self.bias_k, self.bias_v, self.add_zero_attn,
- self.dropout, self.out_proj.weight, self.out_proj.bias,
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias,
+ self.bias_k,
+ self.bias_v,
+ self.add_zero_attn,
+ self.dropout,
+ self.out_proj.weight,
+ self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
average_attn_weights=average_attn_weights,
- is_causal=is_causal)
+ is_causal=is_causal,
+ )
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights
else:
return attn_output, attn_output_weights
def multi_head_attention_forward(
- self,
- query: Tensor,
- key: Tensor,
- value: Tensor,
- embed_dim_to_check: int,
- num_heads: int,
- in_proj_weight: Optional[Tensor],
- in_proj_bias: Optional[Tensor],
- bias_k: Optional[Tensor],
- bias_v: Optional[Tensor],
- add_zero_attn: bool,
- dropout_p: float,
- out_proj_weight: Tensor,
- out_proj_bias: Optional[Tensor],
- training: bool = True,
- key_padding_mask: Optional[Tensor] = None,
- need_weights: bool = True,
- attn_mask: Optional[Tensor] = None,
- use_separate_proj_weight: bool = False,
- q_proj_weight: Optional[Tensor] = None,
- k_proj_weight: Optional[Tensor] = None,
- v_proj_weight: Optional[Tensor] = None,
- static_k: Optional[Tensor] = None,
- static_v: Optional[Tensor] = None,
- average_attn_weights: bool = True,
- is_causal: bool = False,
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ embed_dim_to_check: int,
+ num_heads: int,
+ in_proj_weight: Optional[Tensor],
+ in_proj_bias: Optional[Tensor],
+ bias_k: Optional[Tensor],
+ bias_v: Optional[Tensor],
+ add_zero_attn: bool,
+ dropout_p: float,
+ out_proj_weight: Tensor,
+ out_proj_bias: Optional[Tensor],
+ training: bool = True,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ use_separate_proj_weight: bool = False,
+ q_proj_weight: Optional[Tensor] = None,
+ k_proj_weight: Optional[Tensor] = None,
+ v_proj_weight: Optional[Tensor] = None,
+ static_k: Optional[Tensor] = None,
+ static_v: Optional[Tensor] = None,
+ average_attn_weights: bool = True,
+ is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
# FIXME: logic passed
@@ -435,7 +479,7 @@ def multi_head_attention_forward(
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
- target_type=query.dtype
+ target_type=query.dtype,
)
if is_causal and attn_mask is None:
@@ -466,18 +510,20 @@ def multi_head_attention_forward(
# longer causal.
is_causal = False
- assert embed_dim == embed_dim_to_check, \
- f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
+ assert (
+ embed_dim == embed_dim_to_check
+ ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
if isinstance(embed_dim, ms.Tensor):
# embed_dim can be a tensor when JIT tracing
- head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
+ head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
else:
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
- assert key.shape[:2] == value.shape[:2], \
- f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
+ assert (
+ key.shape[:2] == value.shape[:2]
+ ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
else:
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
@@ -505,13 +551,15 @@ def multi_head_attention_forward(
correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(
- f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
+ f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
+ )
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size:
raise RuntimeError(
- f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
+ f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
+ )
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
@@ -539,19 +587,23 @@ def multi_head_attention_forward(
k = k.view(k.shape[0], bsz * num_heads, head_dim).permute(1, 0, 2)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
- assert static_k.shape[0] == bsz * num_heads, \
- f"expecting static_k.shape[0] of {bsz * num_heads}, but got {static_k.shape[0]}"
- assert static_k.shape[2] == head_dim, \
- f"expecting static_k.shape[2] of {head_dim}, but got {static_k.shape[2]}"
+ assert (
+ static_k.shape[0] == bsz * num_heads
+ ), f"expecting static_k.shape[0] of {bsz * num_heads}, but got {static_k.shape[0]}"
+ assert (
+ static_k.shape[2] == head_dim
+ ), f"expecting static_k.shape[2] of {head_dim}, but got {static_k.shape[2]}"
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).permute(1, 0, 2)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
- assert static_v.shape[0] == bsz * num_heads, \
- f"expecting static_v.shape[0] of {bsz * num_heads}, but got {static_v.shape[0]}"
- assert static_v.shape[2] == head_dim, \
- f"expecting static_v.shape[2] of {head_dim}, but got {static_v.shape[2]}"
+ assert (
+ static_v.shape[0] == bsz * num_heads
+ ), f"expecting static_v.shape[0] of {bsz * num_heads}, but got {static_v.shape[0]}"
+ assert (
+ static_v.shape[2] == head_dim
+ ), f"expecting static_v.shape[2] of {head_dim}, but got {static_v.shape[2]}"
v = static_v
# add zero attention along batch dimension (now first)
@@ -569,10 +621,15 @@ def multi_head_attention_forward(
# merge key padding and attention masks
if key_padding_mask is not None:
- assert key_padding_mask.shape == (bsz, src_len), \
- f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
- key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
- broadcast_to((-1, num_heads, -1, -1)).reshape(bsz * num_heads, 1, src_len)
+ assert key_padding_mask.shape == (
+ bsz,
+ src_len,
+ ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
+ key_padding_mask = (
+ key_padding_mask.view(bsz, 1, 1, src_len)
+ .broadcast_to((-1, num_heads, -1, -1))
+ .reshape(bsz * num_heads, 1, src_len)
+ )
if attn_mask is None:
attn_mask = key_padding_mask
else:
@@ -641,8 +698,14 @@ def multi_head_attention_forward(
return attn_output, None
-def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
- key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], num_heads: int):
+def _mha_shape_check(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ key_padding_mask: Optional[Tensor],
+ attn_mask: Optional[Tensor],
+ num_heads: int,
+):
# Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
# and returns if the input is batched or not.
# Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.
@@ -651,58 +714,65 @@ def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
if query.dim() == 3:
# Batched Inputs
is_batched = True
- assert key.dim() == 3 and value.dim() == 3, \
- ("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
- f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
+ assert key.dim() == 3 and value.dim() == 3, (
+ "For batched (3-D) `query`, expected `key` and `value` to be 3-D"
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
+ )
if key_padding_mask is not None:
- assert key_padding_mask.dim() == 2, \
- ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
- f" but found {key_padding_mask.dim()}-D tensor instead")
+ assert key_padding_mask.dim() == 2, (
+ "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
+ f" but found {key_padding_mask.dim()}-D tensor instead"
+ )
if attn_mask is not None:
- assert attn_mask.dim() in (2, 3), \
- ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
- f" but found {attn_mask.dim()}-D tensor instead")
+ assert attn_mask.dim() in (2, 3), (
+ "For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
+ f" but found {attn_mask.dim()}-D tensor instead"
+ )
elif query.dim() == 2:
# Unbatched Inputs
is_batched = False
- assert key.dim() == 2 and value.dim() == 2, \
- ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
- f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
+ assert key.dim() == 2 and value.dim() == 2, (
+ "For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
+ )
if key_padding_mask is not None:
- assert key_padding_mask.dim() == 1, \
- ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
- f" but found {key_padding_mask.dim()}-D tensor instead")
+ assert key_padding_mask.dim() == 1, (
+ "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
+ f" but found {key_padding_mask.dim()}-D tensor instead"
+ )
if attn_mask is not None:
- assert attn_mask.dim() in (2, 3), \
- ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
- f" but found {attn_mask.dim()}-D tensor instead")
+ assert attn_mask.dim() in (2, 3), (
+ "For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
+ f" but found {attn_mask.dim()}-D tensor instead"
+ )
if attn_mask.dim() == 3:
expected_shape = (num_heads, query.shape[0], key.shape[0])
- assert attn_mask.shape == expected_shape, \
- (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
+ assert (
+ attn_mask.shape == expected_shape
+ ), f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}"
else:
raise AssertionError(
- f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
+ f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor"
+ )
return is_batched
def _canonical_mask(
- mask: Optional[Tensor],
- mask_name: str,
- other_type: Optional,
- other_name: str,
- target_type: None,
- check_other: bool = True,
+ mask: Optional[Tensor],
+ mask_name: str,
+ other_type: Optional,
+ other_name: str,
+ target_type: None,
+ check_other: bool = True,
) -> Optional[Tensor]:
if mask is not None:
_mask_dtype = mask.dtype
_mask_is_float = ops.is_floating_point(mask)
if _mask_dtype != ms.bool_ and not _mask_is_float:
- raise AssertionError(
- f"only bool and floating types of {mask_name} are supported")
+ raise AssertionError(f"only bool and floating types of {mask_name} are supported")
if check_other and other_type is not None:
if _mask_dtype != other_type:
warnings.warn(
@@ -710,10 +780,7 @@ def _canonical_mask(
"is deprecated. Use same type for both instead."
)
if not _mask_is_float:
- mask = (
- ops.zeros_like(mask, dtype=target_type)
- .masked_fill(mask, float("-inf"))
- )
+ mask = ops.zeros_like(mask, dtype=target_type).masked_fill(mask, float("-inf"))
return mask
@@ -726,11 +793,11 @@ def _none_or_dtype(input: Optional[Tensor]) -> Optional:
def _in_projection_packed(
- q: Tensor,
- k: Tensor,
- v: Tensor,
- w: Tensor,
- b: Optional[Tensor] = None,
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ w: Tensor,
+ b: Optional[Tensor] = None,
) -> List[Tensor]:
r"""
Performs the in-projection step of the attention operation, using packed weights.
@@ -786,15 +853,15 @@ def _in_projection_packed(
def _in_projection(
- q: Tensor,
- k: Tensor,
- v: Tensor,
- w_q: Tensor,
- w_k: Tensor,
- w_v: Tensor,
- b_q: Optional[Tensor] = None,
- b_k: Optional[Tensor] = None,
- b_v: Optional[Tensor] = None,
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ w_q: Tensor,
+ w_k: Tensor,
+ w_v: Tensor,
+ b_q: Optional[Tensor] = None,
+ b_k: Optional[Tensor] = None,
+ b_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
r"""
Performs the in-projection step of the attention operation. This is simply
diff --git a/mindone/transformers/models/minicpm_v/tokenization_minicpmv_fast.py b/mindone/transformers/models/minicpm_v/tokenization_minicpmv_fast.py
index f6d84bd25c..110ebcaa76 100644
--- a/mindone/transformers/models/minicpm_v/tokenization_minicpmv_fast.py
+++ b/mindone/transformers/models/minicpm_v/tokenization_minicpmv_fast.py
@@ -57,7 +57,7 @@ def im_id_end_id(self):
@property
def newline_id(self):
- return self.convert_tokens_to_ids('\n')
+ return self.convert_tokens_to_ids("\n")
@staticmethod
def escape(text: str) -> str:
@@ -67,4 +67,5 @@ def escape(text: str) -> str:
def unescape(text: str) -> str:
return text
+
AutoTokenizer.register("MiniCPMVTokenizerFast", MiniCPMVTokenizerFast)
diff --git a/mindone/transformers/models/qwen2/__init__.py b/mindone/transformers/models/qwen2/__init__.py
index be2d5916fd..aa7e109cf0 100644
--- a/mindone/transformers/models/qwen2/__init__.py
+++ b/mindone/transformers/models/qwen2/__init__.py
@@ -21,7 +21,6 @@
}
-
_import_structure["tokenization_qwen2_fast"] = ["Qwen2TokenizerFast"]
_import_structure["modeling_qwen2"] = [
diff --git a/mindone/transformers/models/qwen2/modeling_qwen2.py b/mindone/transformers/models/qwen2/modeling_qwen2.py
index e1eb9df146..f5feb36cbe 100644
--- a/mindone/transformers/models/qwen2/modeling_qwen2.py
+++ b/mindone/transformers/models/qwen2/modeling_qwen2.py
@@ -19,18 +19,19 @@
# limitations under the License.
"""Mindspore Qwen2 model."""
-import math
from typing import List, Optional, Tuple, Union
import numpy as np
+from transformers import logging
import mindspore as ms
-from mindspore import nn, ops, Tensor, Parameter
+from mindspore import Parameter, Tensor, nn, ops
from mindspore.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -40,11 +41,6 @@
from ...modeling_utils import MSPreTrainedModel
from .configuration_qwen2 import Qwen2Config
-from transformers import logging
-
-from ...modeling_flash_attention_utils import _flash_attention_forward
-
-
logger = logging.get_logger(__name__)
@@ -70,7 +66,6 @@ def dtype_to_min(dtype):
raise ValueError(f"Only support get minimum value of (float16, ), but got {dtype}")
-
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: ms.Tensor,
@@ -126,9 +121,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
causal_mask = causal_mask.masked_fill(padding_mask, min_dtype)
else:
causal_mask = ops.cat(
- [ops.narrow(causal_mask, -1, 0, mask_length).masked_fill(padding_mask, min_dtype),
- ops.narrow(causal_mask, -1, mask_length, causal_mask.shape[-1] - mask_length)],
- axis=-1
+ [
+ ops.narrow(causal_mask, -1, 0, mask_length).masked_fill(padding_mask, min_dtype),
+ ops.narrow(causal_mask, -1, mask_length, causal_mask.shape[-1] - mask_length),
+ ],
+ axis=-1,
)
return causal_mask
@@ -167,9 +164,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
self.inv_freq = inv_freq
# Build here to make `torch.jit.trace` work.
- self._set_cos_sin_cache(
- seq_len=max_position_embeddings, device=None, dtype=ms.float32
- )
+ self._set_cos_sin_cache(seq_len=max_position_embeddings, device=None, dtype=ms.float32)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
@@ -300,7 +295,7 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
base=self.rope_theta,
)
- self.scale = self.head_dim ** -0.5
+ self.scale = self.head_dim**-0.5
def construct(
self,
@@ -342,8 +337,8 @@ def construct(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
- query_states = ops.mul(query_states, self.scale ** 0.5)
- key_states = ops.mul(key_states, self.scale ** 0.5)
+ query_states = ops.mul(query_states, self.scale**0.5)
+ key_states = ops.mul(key_states, self.scale**0.5)
attn_weights = ops.matmul(query_states, key_states.swapaxes(2, 3))
@@ -859,7 +854,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.embed_tokens = value
-
def construct(
self,
input_ids: ms.Tensor = None,
@@ -907,9 +901,7 @@ def construct(
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- cache_position = ops.arange(
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]
- )
+ cache_position = ops.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
@@ -1036,11 +1028,7 @@ def _update_causal_mask(
batch_size=input_tensor.shape[0],
)
- if (
- self.config._attn_implementation == "sdpa"
- and attention_mask is not None
- and not output_attentions
- ):
+ if self.config._attn_implementation == "sdpa" and attention_mask is not None and not output_attentions:
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
@@ -1079,7 +1067,6 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.model
-
def construct(
self,
input_ids: ms.Tensor = None,
@@ -1191,7 +1178,7 @@ def prepare_inputs_for_generation(
# input_ids = input_ids[:, :cache_position.shape[0]]
if inputs_embeds is not None: # Exception 1
if 0 not in input_ids.shape:
- input_ids = input_ids[:, -cache_position.shape[0]:]
+ input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = ops.index_select(input_ids, -1, cache_position)
@@ -1245,7 +1232,6 @@ def prepare_inputs_for_generation(
return model_inputs
-
class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1262,7 +1248,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.embed_tokens = value
-
def construct(
self,
input_ids: ms.Tensor = None,
@@ -1353,7 +1338,6 @@ def construct(
)
-
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2
class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
def __init__(self, config):
@@ -1378,7 +1362,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.embed_tokens = value
-
def construct(
self,
input_ids: Optional[ms.Tensor] = None,
@@ -1429,4 +1412,4 @@ def construct(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- )
\ No newline at end of file
+ )
diff --git a/mindone/transformers/models/qwen2/tokenization_qwen2.py b/mindone/transformers/models/qwen2/tokenization_qwen2.py
index c5cff300a2..b13046fbbf 100644
--- a/mindone/transformers/models/qwen2/tokenization_qwen2.py
+++ b/mindone/transformers/models/qwen2/tokenization_qwen2.py
@@ -49,9 +49,7 @@ def bytes_to_unicode():
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
tables between utf-8 bytes and unicode strings.
"""
- bs = (
- list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
- )
+ bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
From aa9e3fc787d9bb6fca73eb7dc69906f45b5e8817 Mon Sep 17 00:00:00 2001
From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com>
Date: Wed, 25 Dec 2024 15:02:54 +0800
Subject: [PATCH 6/9] feat(minicpm-v): Support MiniCPM-V Training pipeline
---
examples/minicpm_v/finetune/dataset.py | 34 +-----
examples/minicpm_v/finetune/finetune.py | 98 +--------------
.../transformers/image_processing_utils.py | 2 +-
mindone/transformers/image_transforms.py | 13 +-
mindone/transformers/image_utils.py | 3 +
.../minicpm_v/image_processing_minicpmv.py | 5 +-
.../models/minicpm_v/modeling_minicpmv.py | 13 +-
.../models/minicpm_v/modeling_navit_siglip.py | 114 ++----------------
.../models/minicpm_v/processing_minicpmv.py | 4 +-
.../models/minicpm_v/resampler.py | 26 ++--
.../models/qwen2/modeling_qwen2.py | 22 ++--
11 files changed, 60 insertions(+), 274 deletions(-)
diff --git a/examples/minicpm_v/finetune/dataset.py b/examples/minicpm_v/finetune/dataset.py
index eb82b9056b..775f563986 100644
--- a/examples/minicpm_v/finetune/dataset.py
+++ b/examples/minicpm_v/finetune/dataset.py
@@ -1,35 +1,19 @@
import copy
-import json
import logging
import math
-import os
import random
import re
-import sys
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional
+from typing import Dict
import numpy as np
-from datasets import load_dataset
from PIL import Image
-from transformers import AutoTokenizer
-
-import mindspore as ms
-from mindspore import ops
-
-# from torch.nn.utils.rnn import pad_sequence
-from mindspore.dataset import Dataset
-
-mindone_lib_path = os.path.abspath(os.path.abspath("../../../"))
-sys.path.insert(0, mindone_lib_path)
-
-import logging
-
-from mindone.transformers.models.minicpm_v2_6.processing_minicpmv import MiniCPMVProcessor
logger = logging.getLogger(__name__)
-llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
+llama3_chat_template = (
+ "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'"
+ "+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
+)
class SupervisedDataset:
@@ -69,7 +53,7 @@ def __getitem__(self, idx, retry_count=3):
if isinstance(self.raw_data[idx]["image"], str):
images_dict = {"": Image.open(self.raw_data[idx]["image"]).convert("RGB")}
elif isinstance(self.raw_data[idx]["image"], Dict):
- ### for multi-images input, the template for every image is , such as ,
+ # for multi-images input, the template for every image is , such as ,
images_dict = {
img_name: Image.open(img_path).convert("RGB")
for img_name, img_path in self.raw_data[idx]["image"].items()
@@ -111,11 +95,6 @@ def __getitem__(self, idx, retry_count=3):
# If max retries reached, return a blank or default item
logger.warning("Max retries reached. Returning a blank entry.")
return None
-
- # except:
- # logger.error(f"data fetch error")
- # # return self.__getitem__(random.randint(0, len(self)))
- # return (ret["input_ids"], ret["position_ids"], ret["labels"], np.ones_like(ret["input_ids"], dtype=np.bool_), ret["pixel_values"], ret["tgt_sizes"], ret["image_bound"])
return ret
@@ -327,7 +306,6 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
raw_msg += prefix + message
assert set([i["role"] for i in chat]) & set(["assistant"])
- ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False)
input_ids = np.array(input_ids)
diff --git a/examples/minicpm_v/finetune/finetune.py b/examples/minicpm_v/finetune/finetune.py
index a40502d30c..f6465f6d3c 100644
--- a/examples/minicpm_v/finetune/finetune.py
+++ b/examples/minicpm_v/finetune/finetune.py
@@ -1,38 +1,17 @@
-import glob
import json
-import logging
import os
import sys
from dataclasses import dataclass, field
-from functools import partial
-from types import MethodType
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple
import numpy as np
-
-import mindspore as ms
-from mindspore import Parameter, Tensor, context, dataset, nn, ops
-from mindspore.communication.management import get_group_size, get_rank, init
-from mindspore.train.amp import AMP_BLACK_LIST, _auto_black_list
-
-# init()
-# rank, rank_size, parallel_mode = get_rank(), get_group_size(), context.ParallelMode.DATA_PARALLEL
-# context.set_auto_parallel_context(
-# device_num=rank_size, parallel_mode=parallel_mode, gradients_mean=True
-# )
-
-rank, rank_size = 0, 1
-
-ms.set_context(
- mode=ms.context.PYNATIVE_MODE, pynative_synchronize=True, mempool_block_size="59GB", max_device_memory="59GB"
-)
-
import transformers
from transformers import HfArgumentParser
+import mindspore as ms
+from mindspore import nn
from mindspore.dataset import transforms, vision
-
-# from accelerate.utils import DistributedType
+from mindspore.train.amp import AMP_BLACK_LIST, _auto_black_list
mindone_lib_path = os.path.abspath(os.path.abspath("../../../"))
sys.path.insert(0, mindone_lib_path)
@@ -45,14 +24,6 @@
from mindone.transformers.trainer import Trainer
from mindone.transformers.training_args import TrainingArguments
-# from transformers.integrations import deepspeed
-
-
-# from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
-
-# ms.set_context(mode=ms.context.PYNATIVE_MODE, pynative_synchronize=True)
-# ms.set_context(mode=ms.context.PYNATIVE_MODE)
-
@dataclass
class ModelArguments:
@@ -65,25 +36,6 @@ class DataArguments:
eval_data_path: str = field(default=None, metadata={"help": "Path to the evaluation data."})
-# @dataclass
-# class TrainingArguments(TrainingArguments):
-# cache_dir: Optional[str] = field(default=None)
-# optim: str = field(default="adamw_mindspore")
-# model_max_length: int = field(
-# default=2048,
-# metadata={
-# "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
-# },
-# )
-# tune_vision: Optional[bool] = field(default=True)
-# tune_llm: Optional[bool] = field(default=True)
-# llm_type: str = field(default="minicpm")
-# use_lora: Optional[bool] = field(default=False)
-# max_slice_nums: Optional[int] = field(default=9)
-# distributed: Optional[bool] = field(default=False)
-# amp_level: Optional[str] = field(default="O0")
-
-
@dataclass
class LoraArguments:
lora_r: int = 64
@@ -244,7 +196,7 @@ def build_transform():
def get_parameter_number(model):
- trainable_params, all_param = 0, 0
+ trainable_params = 0
# for param in model.parameters():
# num_params = param.numel()
# # if using DS Zero 3 and the weights are initialized empty
@@ -290,13 +242,6 @@ def train():
# data_args.rank, data_args.rank_size, parallel_mode = 0, 1, None
local_rank = training_args.local_rank
- world_size = int(os.environ.get("WORLD_SIZE", 1))
- ddp = world_size != 1
- device_map = None
- if lora_args.q_lora:
- device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
- if len(training_args.fsdp) > 0:
- logging.warning("FSDP or ZeRO3 are not incompatible with QLoRA.")
model = MiniCPMV_v2_6.from_pretrained(
model_args.model_name_or_path,
@@ -332,39 +277,6 @@ def train():
for param in model.llm.trainable_params():
param.requires_grad = False
- if training_args.use_lora:
- if training_args.use_lora and training_args.tune_llm:
- raise ValueError("The model cannot simultaneously adjust LLM parameters and apply LoRA.")
-
- rank0_print("Currently using LoRA for fine-tuning the MiniCPM-V model.")
- for name, param in model.llm.named_parameters():
- param.requires_grad = False
- modules_to_save = ["embed_tokens", "resampler"]
- if training_args.tune_vision:
- modules_to_save.append("vpm")
- lora_config = LoraConfig(
- r=lora_args.lora_r,
- lora_alpha=lora_args.lora_alpha,
- target_modules=lora_args.lora_target_modules,
- lora_dropout=lora_args.lora_dropout,
- bias=lora_args.lora_bias,
- layers_to_transform=lora_args.lora_layers_to_transform,
- modules_to_save=modules_to_save,
- )
- if not hasattr(model, "get_input_embeddings"):
-
- def get_input_embeddings(self):
- return self.llm.get_input_embeddings()
-
- model.get_input_embeddings = MethodType(get_input_embeddings, model)
- if lora_args.q_lora:
- model = prepare_model_for_kbit_training(
- model, use_gradient_checkpointing=training_args.gradient_checkpointing
- )
- model = get_peft_model(model, lora_config)
- if training_args.gradient_checkpointing:
- model.enable_input_require_grads()
-
rank0_print(get_parameter_number(model))
llm_type = training_args.llm_type
diff --git a/mindone/transformers/image_processing_utils.py b/mindone/transformers/image_processing_utils.py
index 194c99bf9a..170dd2a8ad 100644
--- a/mindone/transformers/image_processing_utils.py
+++ b/mindone/transformers/image_processing_utils.py
@@ -538,7 +538,7 @@ class BaseImageProcessor(ImageProcessingMixin):
__call__(self, images, **kwargs) -> BatchFeature: Preprocess an image or a batch of images.
preprocess(self, images, **kwargs) -> BatchFeature: Abstract method to be implemented by concrete image processors.
rescale(self, image, scale, data_format=None, input_data_format=None, **kwargs) -> np.ndarray: Rescale an image by a scale factor.
- normalize(self, image, mean, std, data_format=None, input_data_format=None, **kwargs) -> np.ndarray: Normalize an image using mean and standard deviation.
+ normalize(self, image, mean, std, data_format=None, input_data_format=None, **kwargs) -> np.ndarray: Normalize an image.
center_crop(self, image, size, data_format=None, input_data_format=None, **kwargs) -> np.ndarray: Center crop an image to a specified size.
"""
diff --git a/mindone/transformers/image_transforms.py b/mindone/transformers/image_transforms.py
index ee7ed99897..adc81c69c1 100644
--- a/mindone/transformers/image_transforms.py
+++ b/mindone/transformers/image_transforms.py
@@ -1,19 +1,12 @@
import warnings
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, Optional, Tuple, Union
import numpy as np
import PIL
import mindspore
-from mindspore import ops
-
-from .image_utils import (
- ChannelDimension,
- ImageInput,
- get_channel_dimension_axis,
- get_image_size,
- infer_channel_dimension_format,
-)
+
+from .image_utils import ChannelDimension, get_channel_dimension_axis, get_image_size, infer_channel_dimension_format
def to_channel_dimension_format(
diff --git a/mindone/transformers/image_utils.py b/mindone/transformers/image_utils.py
index 87858e2e60..152394461b 100644
--- a/mindone/transformers/image_utils.py
+++ b/mindone/transformers/image_utils.py
@@ -1,6 +1,9 @@
from typing import List, Optional, Tuple, Union
import numpy as np
+import PIL
+
+import mindspore
from .utils.generic import ExplicitEnum
diff --git a/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py b/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
index 1aad29b744..96d60720c1 100644
--- a/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
+++ b/mindone/transformers/models/minicpm_v/image_processing_minicpmv.py
@@ -9,18 +9,15 @@
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import (
ChannelDimension,
- ImageInput,
infer_channel_dimension_format,
- is_batched,
is_torch_tensor,
- make_list_of_images,
to_numpy_array,
valid_images,
)
from transformers.utils import TensorType, is_torch_device, is_torch_dtype, requires_backends
import mindspore as ms
-from mindspore import Parameter, Tensor, nn, ops
+from mindspore import ops
from ...image_processing_utils import BaseImageProcessor, BatchFeature
diff --git a/mindone/transformers/models/minicpm_v/modeling_minicpmv.py b/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
index 71cd628779..c8c27b0e5f 100644
--- a/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
+++ b/mindone/transformers/models/minicpm_v/modeling_minicpmv.py
@@ -2,13 +2,12 @@
import math
from copy import deepcopy
from threading import Thread
-from typing import List, Optional
from PIL import Image
from transformers import TextIteratorStreamer
import mindspore as ms
-from mindspore import Parameter, Tensor, _no_grad, nn, ops
+from mindspore import Tensor, ops
from ..qwen2 import Qwen2ForCausalLM, Qwen2PreTrainedModel
from .configuration_minicpm import MiniCPMVConfig
@@ -17,8 +16,6 @@
from .processing_minicpmv import MiniCPMVProcessor
from .resampler import Resampler
-# from .tokenization_minicpmv_fast import MiniCPMVTokenizerFast
-
class MiniCPMVPreTrainedModel(Qwen2PreTrainedModel):
config_class = MiniCPMVConfig
@@ -82,7 +79,6 @@ def get_decoder(self):
def get_vllm_embedding(self, data):
if "vision_hidden_states" not in data:
dtype = self.llm.model.embed_tokens.embedding_table.dtype
- device = None
tgt_sizes = data["tgt_sizes"]
pixel_values_list = data["pixel_values"]
vision_hidden_states = []
@@ -202,10 +198,9 @@ def construct(self, data, **kwargs):
if position_ids.dtype != ms.int64:
position_ids = position_ids.long()
- with _no_grad():
- return self.llm(
- input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, labels=data["labels"], **kwargs
- )
+ return self.llm(
+ input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, labels=data["labels"], **kwargs
+ )
def _decode(self, inputs_embeds, tokenizer, attention_mask, decode_text=False, **kwargs):
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
diff --git a/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py b/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
index 5c0a147f8c..cf88c5da87 100644
--- a/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
+++ b/mindone/transformers/models/minicpm_v/modeling_navit_siglip.py
@@ -20,26 +20,20 @@
import os
import warnings
from dataclasses import dataclass
-from typing import Any, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
-import numpy as np
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
import mindspore as ms
-
-from mindspore import Parameter, Tensor, nn, ops
+from mindspore import nn, ops
from mindspore.ops.operations.nn_ops import FlashAttentionScore as FlashAttention
from ...activations import ACT2FN
-from ...mindspore_adapter import recompute_except_output
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
from ...modeling_utils import MSPreTrainedModel
-# from torch.nn.init import _calculate_fan_in_and_fan_out
-
-
logger = logging.get_logger(__name__)
@@ -174,12 +168,12 @@ def norm_cdf(x):
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
- l = norm_cdf((a - mean) / std)
- u = norm_cdf((b - mean) / std)
+ low = norm_cdf((a - mean) / std)
+ up = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
- tensor.uniform_(2 * l - 1, 2 * u - 1)
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
@@ -507,7 +501,8 @@ def construct(
# cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
# key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim].
+ # We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.swapaxes(1, 2)
key_states = key_states.swapaxes(1, 2)
@@ -553,98 +548,6 @@ def construct(
return attn_output, attn_weights
- def _flash_attention_forward(
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
- ):
- """
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
- first unpad the input, then computes the attention scores and pad the final attention scores.
- Args:
- query_states (`torch.Tensor`):
- Input query states to be passed to Flash Attention API
- key_states (`torch.Tensor`):
- Input key states to be passed to Flash Attention API
- value_states (`torch.Tensor`):
- Input value states to be passed to Flash Attention API
- attention_mask (`torch.Tensor`):
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
- position of padding tokens and 1 for the position of non-padding tokens.
- dropout (`int`, *optional*):
- Attention dropout
- softmax_scale (`float`, *optional*):
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
- """
-
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
- causal = self.is_causal and query_length != 1
-
- # Contains at least one padding token in the sequence
- if attention_mask is not None:
- batch_size = query_states.shape[0]
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
- query_states, key_states, value_states, attention_mask, query_length
- )
-
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
- attn_output_unpad = flash_attn_varlen_func(
- query_states,
- key_states,
- value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
- dropout_p=dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- )
-
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
- else:
- attn_output = flash_attn_func(
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
- )
-
- return attn_output
-
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
- key_layer = index_first_axis(
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
- )
- value_layer = index_first_axis(
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
- )
- if query_length == kv_seq_len:
- query_layer = index_first_axis(
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
- )
- cu_seqlens_q = cu_seqlens_k
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
- indices_q = indices_k
- elif query_length == 1:
- max_seqlen_in_batch_q = 1
- cu_seqlens_q = ops.arange(batch_size + 1, dtype=ms.int32) # There is a memcpy here, that is very bad.
- indices_q = cu_seqlens_q[:-1]
- query_layer = query_layer.squeeze(1)
- else:
- # The -q_len: slice assumes left padding.
- attention_mask = attention_mask[:, -query_length:]
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
-
- return (
- query_layer,
- key_layer,
- value_layer,
- indices_q,
- (cu_seqlens_q, cu_seqlens_k),
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
- )
-
class SiglipFlashAttention(SiglipAttention):
"""
@@ -697,7 +600,8 @@ def construct(
# cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
# key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim].
+ # We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
# query_states = query_states.swapaxes(1, 2)
# key_states = key_states.swapaxes(1, 2)
diff --git a/mindone/transformers/models/minicpm_v/processing_minicpmv.py b/mindone/transformers/models/minicpm_v/processing_minicpmv.py
index db4ad74e9d..a78d4e0f23 100644
--- a/mindone/transformers/models/minicpm_v/processing_minicpmv.py
+++ b/mindone/transformers/models/minicpm_v/processing_minicpmv.py
@@ -17,7 +17,7 @@
"""
import re
-from typing import Any, Dict, List, Optional, Union
+from typing import List, Optional, Union
import numpy as np
from transformers.image_utils import ImageInput
@@ -25,7 +25,7 @@
from transformers.utils import TensorType
import mindspore as ms
-from mindspore import Parameter, Tensor, nn, ops
+from mindspore import Tensor, ops
from ...processing_utils import ProcessorMixin
from .image_processing_minicpmv import MiniCPMVBatchFeature, MiniCPMVImageProcessor
diff --git a/mindone/transformers/models/minicpm_v/resampler.py b/mindone/transformers/models/minicpm_v/resampler.py
index a9ce0728c0..b8fed5887e 100644
--- a/mindone/transformers/models/minicpm_v/resampler.py
+++ b/mindone/transformers/models/minicpm_v/resampler.py
@@ -10,11 +10,7 @@
from mindspore import Parameter, Tensor, nn, ops
from mindspore.common.initializer import One
from mindspore.common.initializer import TruncatedNormal as trunc_normal_
-from mindspore.common.initializer import XavierNormal as xavier_normal_
-from mindspore.common.initializer import XavierUniform as xavier_uniform_
-from mindspore.common.initializer import Zero, initializer
-from mindspore.mint.nn.functional import *
-from mindspore.nn.layer.activation import *
+from mindspore.common.initializer import Zero
def get_2d_sincos_pos_embed(embed_dim, image_size):
@@ -286,15 +282,15 @@ def construct(
# why_not_fast_path = "autocast is enabled"
if not why_not_fast_path:
- tensor_args = (
- query,
- key,
- value,
- self.in_proj_weight,
- self.in_proj_bias,
- self.out_proj.weight,
- self.out_proj.bias,
- )
+ # tensor_args = (
+ # query,
+ # key,
+ # value,
+ # self.in_proj_weight,
+ # self.in_proj_bias,
+ # self.out_proj.weight,
+ # self.out_proj.bias,
+ # )
# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
# FIXME logic is passed.
@@ -424,7 +420,7 @@ def multi_head_attention_forward(
average_attn_weights: bool = True,
is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
- tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
+ # tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
# FIXME: logic passed
# if has_torch_function(tens_ops):
# return handle_torch_function(
diff --git a/mindone/transformers/models/qwen2/modeling_qwen2.py b/mindone/transformers/models/qwen2/modeling_qwen2.py
index f5feb36cbe..929bfa432c 100644
--- a/mindone/transformers/models/qwen2/modeling_qwen2.py
+++ b/mindone/transformers/models/qwen2/modeling_qwen2.py
@@ -25,11 +25,11 @@
from transformers import logging
import mindspore as ms
-from mindspore import Parameter, Tensor, nn, ops
+from mindspore import Parameter, nn, ops
from mindspore.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
-from ...cache_utils import Cache, StaticCache
+from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
@@ -87,7 +87,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
- The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
@@ -388,7 +389,9 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement,
+ # that was made default for flash_attn>=2.1.
+ # This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = False
@@ -546,8 +549,10 @@ def construct(
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
- "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
+ "Falling back to the manual attention implementation, but specifying the manual implementation "
+ "will be required from Transformers version v5.0.0 onwards. "
+ 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
@@ -1189,7 +1194,10 @@ def prepare_inputs_for_generation(
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
- # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`,
+ # as otherwise the input `position_ids` would have various stride during the decoding.
+ # Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case,
+ # `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
From fdf25812e51f232acc16b93d1786c0a83df77853 Mon Sep 17 00:00:00 2001
From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com>
Date: Fri, 27 Dec 2024 14:42:05 +0800
Subject: [PATCH 7/9] feat(minicpm-v): Support MiniCPM-V Training pipeline
---
mindone/transformers/feature_extraction_utils.py | 5 -----
mindone/transformers/image_processing_utils.py | 5 -----
2 files changed, 10 deletions(-)
diff --git a/mindone/transformers/feature_extraction_utils.py b/mindone/transformers/feature_extraction_utils.py
index 4c8be548c5..08b9c3d2bf 100644
--- a/mindone/transformers/feature_extraction_utils.py
+++ b/mindone/transformers/feature_extraction_utils.py
@@ -596,9 +596,4 @@ def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"):
if not isinstance(auto_class, str):
auto_class = auto_class.__name__
- import mindnlp.transformers.models.auto as auto_module
-
- if not hasattr(auto_module, auto_class):
- raise ValueError(f"{auto_class} is not a valid auto class.")
-
cls._auto_class = auto_class
diff --git a/mindone/transformers/image_processing_utils.py b/mindone/transformers/image_processing_utils.py
index 170dd2a8ad..a968d90100 100644
--- a/mindone/transformers/image_processing_utils.py
+++ b/mindone/transformers/image_processing_utils.py
@@ -492,11 +492,6 @@ def register_for_auto_class(cls, auto_class="AutoImageProcessor"):
if not isinstance(auto_class, str):
auto_class = auto_class.__name__
- import mindnlp.transformers.models.auto as auto_module
-
- if not hasattr(auto_module, auto_class):
- raise ValueError(f"{auto_class} is not a valid auto class.")
-
cls._auto_class = auto_class
def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
From f0f908070a887aa5028aaa2a982aec20e5dbda0e Mon Sep 17 00:00:00 2001
From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com>
Date: Fri, 27 Dec 2024 15:00:17 +0800
Subject: [PATCH 8/9] feat(minicpm-v): Support MiniCPM-V Training pipeline
---
mindone/transformers/processing_utils.py | 23 +++++++++++++++--------
1 file changed, 15 insertions(+), 8 deletions(-)
diff --git a/mindone/transformers/processing_utils.py b/mindone/transformers/processing_utils.py
index adc6cc1222..4fbf59bcce 100644
--- a/mindone/transformers/processing_utils.py
+++ b/mindone/transformers/processing_utils.py
@@ -18,8 +18,11 @@
Processing saving/loading class for common processors.
"""
+import importlib
import os
+import sys
import warnings
+from pathlib import Path
from typing import Optional, Union
import transformers
@@ -89,10 +92,12 @@ def __init__(self, *args, **kwargs):
class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name)
if isinstance(class_name, tuple):
proper_class = tuple(getattr(transformers_module, n) for n in class_name if n is not None)
- elif class_name == "MiniCPMVImageProcessor":
- from mindone.transformers import MiniCPMVImageProcessor
-
- proper_class = MiniCPMVImageProcessor
+ elif "ImageProcessor" in class_name:
+ sub_path = os.path.abspath(os.path.dirname(__file__))
+ sub_path = str(Path(sub_path).parent)
+ sys.path.insert(0, sub_path)
+ module_name = importlib.import_module("mindone.transformers")
+ proper_class = getattr(module_name, class_name)
else:
proper_class = getattr(transformers_module, class_name)
@@ -270,10 +275,12 @@ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
attribute_class = classes[1]
else:
attribute_class = classes[0]
- elif class_name == "MiniCPMVImageProcessor":
- from mindone.transformers import MiniCPMVImageProcessor
-
- attribute_class = MiniCPMVImageProcessor
+ elif "ImageProcessor" in class_name:
+ sub_path = os.path.abspath(os.path.dirname(__file__))
+ sub_path = str(Path(sub_path).parent)
+ sys.path.insert(0, sub_path)
+ module_name = importlib.import_module("mindone.transformers")
+ attribute_class = getattr(module_name, class_name)
else:
attribute_class = getattr(transformers_module, class_name)
From d7c7b45cd531564bc0a25e216c93290ad72ffc29 Mon Sep 17 00:00:00 2001
From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com>
Date: Fri, 27 Dec 2024 15:13:23 +0800
Subject: [PATCH 9/9] feat(minicpm-v): Support MiniCPM-V Training pipeline
---
examples/minicpm_v/finetune/finetune.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/minicpm_v/finetune/finetune.py b/examples/minicpm_v/finetune/finetune.py
index f6465f6d3c..b726ece00d 100644
--- a/examples/minicpm_v/finetune/finetune.py
+++ b/examples/minicpm_v/finetune/finetune.py
@@ -20,7 +20,7 @@
from transformers import AutoTokenizer
from mindone.transformers.mindspore_adapter import MindSporeArguments
-from mindone.transformers.models.minicpm_v2_6 import MiniCPMV_v2_6
+from mindone.transformers.models.minicpm_v import MiniCPMV_v2_6
from mindone.transformers.trainer import Trainer
from mindone.transformers.training_args import TrainingArguments