
Implement RoFormerV2 #2145

Open · wants to merge 13 commits into base: master
Conversation


@pass-lin pass-lin commented Mar 16, 2025

From issue #2118:
I've provided an implementation of RoFormer V2 in KerasHub. It's worth noting that the original model was implemented natively in Keras 2.3.1 + TF 1.15, so I reused the original author's implementation as much as possible.

KerasHub currently lacks Chinese encoder-only models that support Flash Attention and are not limited to a fixed sequence length. We believe RoFormer V2 can help fill this gap.

Since this is a native Keras model and HF doesn't provide an implementation, I've additionally provided weight files:
roformerV2-small
roformerV2-base
roformerV2-large

The accuracy difference between my implementation and the original is below 1e-5. Note that the LN layer in the original is quite special, so I substituted an RMS norm layer initialized to ones; the two are equivalent.
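The claimed equivalence can be sketched numerically. This is an illustrative sketch in NumPy, not the PR's actual layer; it assumes the original model's special LN amounts to rescaling by the reciprocal root-mean-square, with no mean subtraction and no bias, which is exactly what an RMS norm with its scale initialized to ones computes.

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    # RMSNorm: divide by the root-mean-square over the last axis,
    # then multiply by a learned scale `gamma`; no mean subtraction, no bias.
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    return x / rms * gamma

x = np.random.randn(2, 8).astype("float32")
# With gamma initialized to ones, the layer starts as a pure RMS rescaling,
# so each output row has root-mean-square ~1.
out = rms_norm(x, np.ones(8, dtype="float32"))
```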

Finally, I ran the following scripts to check that my implementation is complete; they all run smoothly. However, I don't quite understand how to convert them into test cases.

import os

os.environ['KERAS_BACKEND'] = 'torch'
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import keras_hub
import keras
features = ["The quick brown fox jumped.", "I forgot my homework."]
labels = [0, 3]
model_path = "modelscope://q935499957/roformerV2_small_zh-Keras"
# Pretrained classifier.
classifier = keras_hub.models.RorformerV2Classifier.from_preset(
    model_path,
    num_classes=4,
)
classifier.predict(x=features, batch_size=2)
classifier.fit(x=features, y=labels, batch_size=2)


# Re-compile (e.g., with a new learning rate).
classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
    jit_compile=False,
)
# Access backbone programmatically (e.g., to change `trainable`).
classifier.backbone.trainable = False
# Fit again.
classifier.fit(x=features, y=labels, batch_size=2)
import numpy as np
features = {
    "token_ids": np.ones(shape=(2, 12), dtype="int32"),
    "segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2),
    "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2),
}
labels = [0, 3]

# Pretrained classifier without preprocessing.
classifier = keras_hub.models.BertClassifier.from_preset(
    model_path,
    num_classes=4,
    preprocessor=None,
)
classifier.fit(x=features, y=labels, batch_size=2)

features = ["The quick brown fox jumped.", "I forgot my homework."]

# Pretrained language model.
masked_lm = keras_hub.models.RoformerV2MaskedLM.from_preset(
    model_path,
)
masked_lm.fit(x=features, batch_size=2)

# Re-compile (e.g., with a new learning rate).
masked_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
    jit_compile=False,
)
# Access backbone programmatically (e.g., to change `trainable`).
masked_lm.backbone.trainable = False
# Fit again.
masked_lm.fit(x=features, batch_size=2)

tokenizer = keras_hub.tokenizers.RoformerV2Tokenizer.from_preset(
    model_path,
)
test_string = "中文文本"

print(test_string==tokenizer.detokenize(tokenizer(test_string)).replace(" ",""))

@pass-lin (Contributor, Author)

Also, I am the author of this PR.
I found that there was an error message in the last PR that I didn't include, so this time I've submitted it together with the PR.

@mattdangerw mattdangerw self-requested a review March 18, 2025 20:58
@mattdangerw (Member) left a comment

Thanks @pass-lin ! Nice contribution.

Please do add some testing for this; you can use BERT as an example, and update those tests as needed until pytest passes. You can skip anything related to testing from_preset (these are usually marked with a "large" marker on the tests); we will need the weights uploaded to Kaggle before we can add those. That can be a follow-up, I think.

self.max_wavelength = max_wavelength
self.output_dim = output_dim

def call(self, tensors: list):

Generally, remove type annotations; we aren't using them, at least not yet!

return tensor[indices]

def sinusoidal_embeddings(self, pos, dim, base=10000):
assert dim % 2 == 0

Prefer a ValueError with a readable message over asserts.
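The suggestion can be sketched as follows; the function name and message are illustrative, mirroring the `dim % 2 == 0` check in `sinusoidal_embeddings` above:

```python
def check_even_dim(dim):
    # Validate instead of `assert dim % 2 == 0`: asserts are stripped
    # under `python -O` and give users an unreadable failure.
    if dim % 2 != 0:
        raise ValueError(
            "`dim` must be even for sinusoidal embeddings (half the "
            "dimensions hold sin terms, half cos). "
            f"Received: dim={dim}."
        )

check_even_dim(64)  # passes silently
```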


@keras_hub_export("keras_hub.models.RorformerV2Backbone")
class RoformerV2Backbone(Backbone):
"""A BERT encoder network.

Update from bert!

x = self.embeddings_dropout(x)
for transformer_layer in self.transformer_layers:
x = transformer_layer(x, attention_mask=padding_mask_input)
# Construct the two BERT outputs. The pooled output is a dense layer on

Probably just find-and-replace bert -> roformer (I'll stop commenting on this one).

return y

def get_config(self):
config = {

We've switched to another style for the most part: config = super().get_config(), then config.update({...}), then return config.
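A minimal pure-Python sketch of the preferred pattern (the class names and fields are illustrative stand-ins, not the PR's actual layers):

```python
class BaseLayer:
    # Stand-in for a Keras base class whose get_config returns parent fields.
    def get_config(self):
        return {"name": "roformer_v2_encoder"}

class EncoderLayer(BaseLayer):
    def __init__(self, max_wavelength=10000):
        self.max_wavelength = max_wavelength

    def get_config(self):
        # Preferred style: start from the parent config, then update it
        # with this layer's own constructor arguments.
        config = super().get_config()
        config.update({"max_wavelength": self.max_wavelength})
        return config
```

This way parent fields are never dropped when the base class adds new config keys.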

@keras_hub_export(
[
"keras_hub.models.RoformerV2TextClassifierPreprocessor",
"keras_hub.models.RoformerV2Preprocessor",

We can ditch this second export name; we only had it in BERT for backwards compat.

@keras_hub_export(
[
"keras_hub.models.RorformerV2TextClassifier",
"keras_hub.models.RorformerV2Classifier",

Same here: ditch this name; we only needed it for backwards compat.

@pass-lin (Contributor, Author)

@mattdangerw
I need your help to start the tests. I'm not sure whether my tests are correct; I can't run pytest locally.

@mattdangerw (Member)

@pass-lin approved running the tests! Though finding a way to run some CPU testing locally will probably speed things up greatly, even if you use our CI for GPU tests. What's your local development machine?

@pass-lin (Contributor, Author)

> @pass-lin approved running the tests! Though finding a way to run some CPU testing locally will probably speed things up greatly, even if you use our CI for GPU tests. What's your local development machine?

I've already been able to run and pass the tests locally. But my model relies on RMS layer norm, which was newly added in Keras 3.9. The Keras version used in CI is older, so the test fails there.

@pass-lin (Contributor, Author)

> @pass-lin approved running the tests! Though finding a way to run some CPU testing locally will probably speed things up greatly, even if you use our CI for GPU tests. What's your local development machine?

> I've already been able to run and pass the tests locally. But my model relies on RMS layer norm, which was newly added in Keras 3.9. The Keras version used in CI is older, so the test fails there.

I've made some changes, removing the dependency on RMS norm and directly copying an implementation of Llama's norm.

@pass-lin (Contributor, Author)

I'm not quite sure how to run JAX tests locally. My current method modifies the Keras source code to make the default backend JAX; in that case I can pass the tests locally.
@mattdangerw

@mattdangerw (Member)

mattdangerw commented Mar 21, 2025

@pass-lin two ways...

Either go CPU-only: pip install -r requirements.txt (I'd recommend doing this in a conda/pyenv/virtualenv), export KERAS_BACKEND=jax, and run pytest locally. For the largest tests (saved models, anything hitting the network), you will need pytest --run_large.

Or do the same with GPU support: pip install -r requirements-jax-cuda.txt.

@pass-lin
Copy link
Contributor Author

I accidentally included an extra max_sequence_length parameter in the test script. It exists in BERT, but not in RoFormer; I've deleted it. This seems to be the reason the JAX CI test failed.
But strangely, TF and Torch work fine in CI, and so does JAX locally.
@mattdangerw

@pass-lin (Contributor, Author)

@mattdangerw
Could you please help me start the online test?

@mattdangerw mattdangerw added the kokoro:force-run Runs Tests on GPU label Mar 24, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Mar 24, 2025
@pass-lin (Contributor, Author)

pass-lin commented Mar 25, 2025

I might need some help. The JAX error is due to a version issue; do I need to implement a native attention mechanism for Keras 3.5 and below?
Also, I can't reproduce the Torch error locally: the test passes on my machine with the torch backend on GPU. I don't understand the source of this error. Could you please give me some help?
@mattdangerw

@pass-lin (Contributor, Author)

@mattdangerw
I saw this error mentioned in issue #138317. Here is minimal code to reproduce a similar error.

import torch
from torch import Tensor, nn


class SelfAttention(nn.Module):
    def __init__(
        self,
        num_attention_heads: int = 12,
        hidden_size: int = 768,
        attention_probs_dropout_prob: float = 0.1,
    ):
        super().__init__()

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads

        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

        self.dropout_prob = attention_probs_dropout_prob

    def transpose_for_scores(self, x: Tensor) -> Tensor:
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        return x.view(new_x_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attention_mask,
            dropout_p=self.dropout_prob if self.training else 0.0,
            is_causal=False,
        )
        return attn_output


def test_attention():
    device = torch.device("cuda")
    num_attention_heads = 8
    hidden_size = 512
    attention_probs_dropout_prob = 0.0
    model = SelfAttention(
        num_attention_heads=num_attention_heads,
        hidden_size=hidden_size,
        attention_probs_dropout_prob=attention_probs_dropout_prob,
    ).to(device)

    model = torch.compile(model)
    batch_size = 8
    length = 1
    inputs_embeds = torch.randn(batch_size, length, hidden_size, device=device)
    attention_mask = torch.ones(batch_size, 1, length, length, device=device)
    attn_output = model(hidden_states=inputs_embeds, attention_mask=attention_mask)[0]
    loss = attn_output.mean()
    loss.backward()


test_attention()

It seems to be a problem caused by torch.compile. So is it possible that the bug in our case is also caused by torch.compile?

@pass-lin (Contributor, Author)

pass-lin commented Mar 26, 2025

My environment (pip list):

Package                      Version      Editable project location
---------------------------- ------------ -------------------------
absl-py                      2.2.0
array_record                 0.5.1
astor                        0.8.1
astunparse                   1.6.3
build                        1.2.2.post1
certifi                      2025.1.31
charset-normalizer           3.4.1
click                        8.1.8
coverage                     7.7.1
dm-tree                      0.1.8
etils                        1.5.2
exceptiongroup               1.2.2
filelock                     3.18.0
flatbuffers                  25.2.10
fsspec                       2025.3.0
gast                         0.6.0
google-pasta                 0.2.0
grpcio                       1.71.0
h5py                         3.13.0
huggingface-hub              0.29.3
idna                         3.10
importlib_metadata           8.6.1
importlib_resources          6.5.2
iniconfig                    2.1.0
jax                          0.4.30
jaxlib                       0.4.30
Jinja2                       3.1.6
joblib                       1.4.2
kagglehub                    0.3.10
keras                        3.9.0
keras-hub                    0.20.0.dev0  /home/amax/keras-hub
libclang                     18.1.1
Markdown                     3.7
markdown-it-py               3.0.0
MarkupSafe                   3.0.2
mdurl                        0.1.2
ml_dtypes                    0.5.1
mpmath                       1.3.0
namex                        0.0.8
networkx                     3.2.1
nltk                         3.9.1
numpy                        2.0.2
nvidia-cublas-cu12           12.6.4.1
nvidia-cuda-cupti-cu12       12.6.80
nvidia-cuda-nvrtc-cu12       12.6.77
nvidia-cuda-runtime-cu12     12.6.77
nvidia-cudnn-cu12            9.5.1.17
nvidia-cufft-cu12            11.3.0.4
nvidia-curand-cu12           10.3.7.77
nvidia-cusolver-cu12         11.7.1.2
nvidia-cusparse-cu12         12.5.4.2
nvidia-cusparselt-cu12       0.6.3
nvidia-nccl-cu12             2.21.5
nvidia-nvjitlink-cu12        12.6.85
nvidia-nvtx-cu12             12.6.77
opt_einsum                   3.4.0
optree                       0.14.1
packaging                    24.2
pillow                       11.1.0
pip                          25.0.1
pluggy                       1.5.0
promise                      2.3
protobuf                     3.20.3
psutil                       7.0.0
Pygments                     2.19.1
pyproject_hooks              1.2.0
pytest                       8.3.5
pytest-cov                   6.0.0
PyYAML                       6.0.2
regex                        2024.11.6
requests                     2.32.3
rich                         13.9.4
rouge_score                  0.1.2
ruff                         0.11.2
safetensors                  0.5.3
scipy                        1.13.1
sentencepiece                0.2.0
setuptools                   78.1.0
six                          1.17.0
sympy                        1.13.1
tensorboard                  2.18.0
tensorboard-data-server      0.7.2
tensorflow                   2.18.1
tensorflow_cpu               2.18.1
tensorflow-datasets          4.9.3
tensorflow-io-gcs-filesystem 0.37.1
tensorflow-metadata          1.16.1
tensorflow-text              2.18.1
termcolor                    2.5.0
toml                         0.10.2
tomli                        2.2.1
torch                        2.6.0+cu126
torchvision                  0.21.0+cu126
tqdm                         4.67.1
triton                       3.2.0
typing_extensions            4.13.0
urllib3                      2.3.0
Werkzeug                     3.1.3
wheel                        0.45.1
wrapt                        1.17.2
zipp                         3.21.0

Python version is 3.9.
I can't reproduce the Torch test failure in this environment. The GPU I'm using locally is an A100.

@mattdangerw

@mattdangerw (Member)

Sorry for the delay! Looking now.

@mattdangerw mattdangerw added the kokoro:force-run Runs Tests on GPU label Mar 28, 2025
@kokoro-team kokoro-team removed kokoro:force-run Runs Tests on GPU labels Mar 28, 2025
@mattdangerw (Member)

I can't repro this locally on a 3090 either. Maybe for now, to unblock, let's just annotate the failing tests with...

if keras.config.backend() == "torch":
    import torch

    if torch.cuda.device_count():
        self.skipTest("Failing on GPU on CI")

@mattdangerw (Member) left a comment

A couple more comments as we try out #2145 (comment)

@@ -0,0 +1,212 @@
import keras

We should not capitalize any path names. Class names can stay as is, but move this to models/roformer_v2/roformer_v2_backbone.py (and so on).


self.token_embedding = token_embedding
self.activation = keras.activations.get(activation)
assert token_embedding

Replace the assert with a ValueError, or remove it.
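For this case the check guards a constructor argument rather than a computed value. A hypothetical sketch of validating it with a readable error (the class name is illustrative; only the `token_embedding` argument mirrors the reviewed code):

```python
class MaskedLMHead:
    # Hypothetical stand-in for the reviewed layer: validate the constructor
    # argument up front instead of `assert token_embedding`.
    def __init__(self, token_embedding=None):
        if token_embedding is None:
            raise ValueError(
                "`token_embedding` must be provided so the head can reuse "
                "the backbone's embedding weights. "
                f"Received: token_embedding={token_embedding}."
            )
        self.token_embedding = token_embedding
```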



@keras_hub_export("keras_hub.models.RoformerV2MaskedLMPreprocessor")
class RoformerV2MaskedLMPreprocessor(BertMaskedLMPreprocessor):

Don't subclass BERT here; subclass MaskedLMPreprocessor instead.

@@ -0,0 +1,122 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier
from keras_hub.src.models.roformerV2 import (

@keras.saving.register_keras_serializable(package="keras_hub")
class RoformerAttention(keras.layers.Layer):
"""
MultiHeadAttention by roformerV2

All docstrings should either look like

"""Single line docstring that fits line limit."""

or

"""Short one line description.

More context (like a link).
"""
