Implement RoFormerV2 #2145
base: master
Conversation
I am the author of this PR.
Thanks @pass-lin! Nice contribution.
Please do add some testing for this; you can use bert as an example. Update those tests as needed until pytest passes. You can skip anything related to testing from_preset (these are usually marked with a "large" marker on the tests); we will need the weights uploaded to Kaggle before we can add those. That can be a follow-up, I think.
self.max_wavelength = max_wavelength
self.output_dim = output_dim

def call(self, tensors: list):
Generally, remove type annotations; we aren't using them, at least not yet!
return tensor[indices]

def sinusoidal_embeddings(self, pos, dim, base=10000):
    assert dim % 2 == 0
Prefer ValueError with a readable message over asserts.
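For example, something like this (a sketch; the exact message text is illustrative):

```python
if dim % 2 != 0:
    raise ValueError(
        "`dim` must be even for sinusoidal embeddings. "
        f"Received: dim={dim}"
    )
```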
@keras_hub_export("keras_hub.models.RorformerV2Backbone")
class RoformerV2Backbone(Backbone):
    """A BERT encoder network.
Update this docstring from bert!
x = self.embeddings_dropout(x)
for transformer_layer in self.transformer_layers:
    x = transformer_layer(x, attention_mask=padding_mask_input)
# Construct the two BERT outputs. The pooled output is a dense layer on
Probably just find-and-replace bert -> roformer (will stop commenting on this one).
return y

def get_config(self):
    config = {
We've switched to another style for the most part: `config = super().get_config()`, then `config.update({...})`, then return.
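That is, something like this (a sketch; the config keys shown are illustrative, not the PR's exact fields):

```python
def get_config(self):
    config = super().get_config()
    config.update(
        {
            "vocabulary_size": self.vocabulary_size,
            "num_layers": self.num_layers,
            "hidden_dim": self.hidden_dim,
        }
    )
    return config
```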
@keras_hub_export(
    [
        "keras_hub.models.RoformerV2TextClassifierPreprocessor",
        "keras_hub.models.RoformerV2Preprocessor",
We can ditch this second export name; we only had it in bert for backwards compat.
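In other words, just the single export name (a sketch; the base-class name here is assumed for illustration):

```python
@keras_hub_export("keras_hub.models.RoformerV2TextClassifierPreprocessor")
class RoformerV2TextClassifierPreprocessor(TextClassifierPreprocessor):
    ...
```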
@keras_hub_export(
    [
        "keras_hub.models.RorformerV2TextClassifier",
        "keras_hub.models.RorformerV2Classifier",
Same here, ditch this name; we only needed it for backwards compat.
@mattdangerw
@pass-lin approved running the tests! Though finding a way to run some CPU testing locally will probably speed things up greatly, even if you use our CI for GPU tests. What's your local development machine?
I've already been able to run and pass the tests locally. But my model relies on RMS layer norm, which was newly added in Keras 3.9. The Keras version used for the online tests is older, so the tests fail there.
I've made some changes, removing the dependency on RMSNorm and directly copying an implementation of Llama norm.
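For reference, a minimal sketch of an RMS-style norm layer of this kind (the epsilon value and weight name are assumptions, not the PR's exact code):

```python
import keras
from keras import ops


class RMSNorm(keras.layers.Layer):
    """A llama-style RMS norm sketch."""

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # Scale weights initialized to ones, matching the substitution
        # described in this thread.
        self.scale = self.add_weight(
            name="scale",
            shape=(input_shape[-1],),
            initializer="ones",
        )

    def call(self, x):
        # Normalize by the root mean square over the feature axis;
        # unlike LayerNorm, there is no mean-centering.
        variance = ops.mean(ops.square(x), axis=-1, keepdims=True)
        return x * ops.rsqrt(variance + self.epsilon) * self.scale
```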
I'm not quite sure how to run the JAX tests locally. My current method modifies the Keras source code to make the default backend JAX. That way I can pass the tests locally.
@pass-lin two ways... Either go CPU only (pip install), or do the same with GPU support.
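Separately, a sketch of selecting the JAX backend without editing the Keras source, using the standard Keras 3 `KERAS_BACKEND` environment variable:

```python
import os

# Must be set before keras is imported for the first time.
os.environ["KERAS_BACKEND"] = "jax"

import keras  # noqa: E402

print(keras.backend.backend())  # prints "jax"
```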
I accidentally passed an extra max_sequence_length parameter in the test script. That parameter exists in BERT, but not in RoFormer; I've deleted it. This seems to have been the reason the JAX online tests failed.
@mattdangerw
I might need some help. The error on JAX is due to a version issue. Do I need to implement a native attention mechanism for Keras 3.5 and below?
@mattdangerw

import torch
from torch import Tensor, nn
class SelfAttention(nn.Module):
def __init__(
self,
num_attention_heads: int = 12,
hidden_size: int = 768,
attention_probs_dropout_prob: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_size = hidden_size // num_attention_heads
self.query = nn.Linear(hidden_size, hidden_size)
self.key = nn.Linear(hidden_size, hidden_size)
self.value = nn.Linear(hidden_size, hidden_size)
self.dropout_prob = attention_probs_dropout_prob
def transpose_for_scores(self, x: Tensor) -> Tensor:
new_x_shape = x.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
return x.view(new_x_shape).permute(0, 2, 1, 3)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_layer = self.transpose_for_scores(self.query(hidden_states))
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
is_causal=False,
)
return attn_output
def test_attention():
device = torch.device("cuda")
num_attention_heads = 8
hidden_size = 512
attention_probs_dropout_prob = 0.0
model = SelfAttention(
num_attention_heads=num_attention_heads,
hidden_size=hidden_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
).to(device)
model = torch.compile(model)
batch_size = 8
length = 1
inputs_embeds = torch.randn(batch_size, length, hidden_size, device=device)
attention_mask = torch.ones(batch_size, 1, length, length, device=device)
attn_output = model(hidden_states=inputs_embeds, attention_mask=attention_mask)[0]
loss = attn_output.mean()
loss.backward()
test_attention()

It seems to be a problem caused by torch.compile. So is it possible that the bug in our case is also caused by torch.compile?
Sorry for the delay! Looking now.
I can't repro this locally on a 3090 either. Maybe for now, to unblock, let's just annotate the failing tests with...
Couple more comments as we try out #2145 (comment)
@@ -0,0 +1,212 @@
import keras
We should not capitalize any path names. Class names can stay as is, but move this to models/roformer_v2/roformer_v2_backbone.py (and so on).
self.token_embedding = token_embedding
self.activation = keras.activations.get(activation)
assert token_embedding
Replace the assert with a ValueError or remove it.
@keras_hub_export("keras_hub.models.RoformerV2MaskedLMPreprocessor")
class RoformerV2MaskedLMPreprocessor(BertMaskedLMPreprocessor):
Don't subclass bert here, subclass MaskedLMPreprocessor instead.
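That is, something along these lines (a sketch; the base-class import path is an assumption based on how other models in the repo are structured):

```python
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor

# RoformerV2Backbone / RoformerV2Tokenizer imports omitted for brevity.


@keras_hub_export("keras_hub.models.RoformerV2MaskedLMPreprocessor")
class RoformerV2MaskedLMPreprocessor(MaskedLMPreprocessor):
    backbone_cls = RoformerV2Backbone
    tokenizer_cls = RoformerV2Tokenizer
```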
@@ -0,0 +1,122 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier
from keras_hub.src.models.roformerV2 import (
Just leave this like the rest of the imports and add # noqa: E501 to the line.
https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py#L7
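I.e., a single long import with the lint suppression on the same line (a sketch; the path assumes the lowercase roformer_v2 rename suggested above):

```python
from keras_hub.src.models.roformer_v2.roformer_v2_backbone import RoformerV2Backbone  # noqa: E501
```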
@keras.saving.register_keras_serializable(package="keras_hub")
class RoformerAttention(keras.layers.Layer):
    """
    MultiHeadAttention by roformerV2
All docstrings should either look like
"""Single line docstring that fits line limit."""
or
"""Short one line description.
More context (like a link).
"""
From issue #2118:
I've provided an implementation of RoFormer V2 in KerasHub. It's worth noting that the original model was implemented natively in Keras 2.3.1 + TF 1.15, so I tried to reuse the original author's implementation as much as possible.
KerasHub currently lacks Chinese models that support flash attention and are not limited by an encoder-only sequence length. We believe RoFormer V2 can go some way toward filling this gap.
Since this is a native Keras model and HF doesn't provide an implementation, I've additionally provided weight files:
roformerV2-small
roformerV2-base
roformerV2-large
The numerical difference between my implementation and the original is kept below 1e-5. The LN layer in the original version is quite unusual, so I substituted an RMS layer initialized to ones; the two are equivalent.
Finally, I ran the following scripts to check that my implementation is complete, and they all run smoothly.
But I don't quite understand how to convert them into test cases.
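Following the suggestion above to use bert as an example, a backbone test might look something like this (a sketch; the constructor kwargs and the run_backbone_test helper are assumptions based on the existing bert tests):

```python
import keras

from keras_hub.src.models.roformer_v2.roformer_v2_backbone import (
    RoformerV2Backbone,
)
from keras_hub.src.tests.test_case import TestCase


class RoformerV2BackboneTest(TestCase):
    def setUp(self):
        # A tiny config so the test runs quickly on CPU.
        self.init_kwargs = {
            "vocabulary_size": 10,
            "num_layers": 2,
            "num_heads": 2,
            "hidden_dim": 2,
            "intermediate_dim": 4,
        }
        self.input_data = {
            "token_ids": keras.ops.ones((2, 5), dtype="int32"),
            "padding_mask": keras.ops.ones((2, 5), dtype="int32"),
        }

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=RoformerV2Backbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, 5, 2),
        )
```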