
Add RF-DETR #36895

Draft · sbucaille wants to merge 3 commits into main
Conversation

sbucaille
Contributor

What does this PR do?

Implements RF-DETR

Fixes #36879

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@qubvel


@sbucaille
Contributor Author

Just a small message to present the architecture and what it looks like from the 🤗 transformers point of view:
[architecture diagram]

RF-DETR is based on LW-DETR and DeformableDETR. LW-DETR is itself based on DETR, but replaces the CNN encoder (e.g. ResNet) with a ViT and adds a MultiScaleProjector to link the encoder and the decoder. RF-DETR then changes the LW-DETR encoder from a plain ViT to DinoV2WithRegisters with a "window" mechanism, and replaces the classical DETR decoder with a DeformableDETR decoder.

There are basically two things to write:

  • The RFDetrMultiScaleProjector, which was originally implemented in LW-DETR (which RF-DETR is based on) but is not present in the library (see the sketch after this list).
  • The RFDetrBackbone, with the underlying classes built on top of DinoV2WithRegisters.
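
To make the first point concrete, here is a minimal, self-contained sketch of what a multi-scale projector does; this is not the LW-DETR/RF-DETR implementation, and the class name, channel sizes, and scale factors are made up for illustration. It takes the single-scale feature map coming out of the ViT backbone and produces the feature pyramid a DeformableDETR-style decoder expects.

import torch
from torch import nn


class MultiScaleProjectorSketch(nn.Module):
    """Illustrative sketch only: project a single-scale ViT feature map to several scales."""

    def __init__(self, hidden_size=384, out_channels=256, scale_factors=(2.0, 1.0, 0.5)):
        super().__init__()
        projections = []
        for scale in scale_factors:
            if scale > 1.0:
                # upsample with a transposed convolution
                projections.append(nn.ConvTranspose2d(hidden_size, out_channels, kernel_size=2, stride=2))
            elif scale == 1.0:
                # keep the resolution, only change the channel count
                projections.append(nn.Conv2d(hidden_size, out_channels, kernel_size=1))
            else:
                # downsample with a strided convolution
                projections.append(nn.Conv2d(hidden_size, out_channels, kernel_size=3, stride=2, padding=1))
        self.projections = nn.ModuleList(projections)

    def forward(self, features: torch.Tensor):
        # features: (batch, hidden_size, H, W), e.g. the ViT patch tokens reshaped to a 2D grid
        return [projection(features) for projection in self.projections]


feature_map = torch.randn(1, 384, 40, 40)
pyramid = MultiScaleProjectorSketch()(feature_map)
print([tuple(level.shape) for level in pyramid])  # [(1, 256, 80, 80), (1, 256, 40, 40), (1, 256, 20, 20)]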

One difficulty I can see in advance is the following:
Am I right in saying that there should be only one XXPreTrainedModel class per modeling file?
In our case, we will need to create an RFDetrBackbone class, which requires a PreTrainedModel to be considered as an AutoBackbone, presumably inheriting from DinoV2WithRegistersPreTrainedModel. We will also need to create RFDetrModel and RFDetrForObjectDetection, which both require a PreTrainedModel and will likely inherit from DeformableDetrPreTrainedModel.
If yes, then I need to "merge" both classes, but _supports_flash_attn_2 is not the same for both: DinoV2 supports it, but DeformableDetr does not.

I noticed your PR about refactoring attention in ViTs; is there any plan to add FlashAttention to other models such as Detr, RTDetr, etc.?
I guess for now I'll just set _supports_flash_attn_2 to False.
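
For illustration, a minimal sketch of what the merged base class could look like under that assumption; the class name RFDetrPreTrainedModel and the attribute values are placeholders, not a final design.

from transformers.modeling_utils import PreTrainedModel


class RFDetrPreTrainedModel(PreTrainedModel):
    # Hypothetical merged base class shared by RFDetrBackbone, RFDetrModel and
    # RFDetrForObjectDetection; names and values are illustrative only.
    base_model_prefix = "model"
    main_input_name = "pixel_values"
    # DeformableDetr does not support FlashAttention-2, so the merged class takes
    # the more conservative setting for now.
    _supports_flash_attn_2 = False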

Let me know what you guys think

@qubvel
Member

qubvel commented Mar 22, 2025

Hi @sbucaille, thanks for the detailed write-up!

Am I right in saying that there should be only one XXPreTrainedModel class per modeling file? In our case, we will need to create an RFDetrBackbone class, which requires a PreTrainedModel to be considered as an AutoBackbone, presumably inheriting from DinoV2WithRegistersPreTrainedModel. We will also need to create RFDetrModel and RFDetrForObjectDetection, which both require a PreTrainedModel and will likely inherit from DeformableDetrPreTrainedModel.

We can add a DinoV2WithRegistersBackbone class directly into the dino_v2_with_registers model; would that work?

I noticed your #36545 about refactoring attention in ViTs; is there any plan to add FlashAttention to other models such as Detr, RTDetr, etc.?

Not at the moment; from my experiments it was not required for detr-based models and did not give any speedup. However, it might be more relevant for a transformer-based encoder. Let's keep it simple initially and set it to False, as you suggested.

@sbucaille
Contributor Author

We can't use the DinoV2WithRegistersBackbone as there is the "window" mechanism in the middle of the forward methods; here is an example (not final):

from typing import Optional, Tuple, Union

import torch

from transformers.models.dinov2_with_registers.modeling_dinov2_with_registers import Dinov2WithRegistersLayer


class RFDetrBackboneLayer(Dinov2WithRegistersLayer):
    def __init__(self, config):
        super().__init__(config)

        self.num_windows = config.num_windows

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        run_full_attention: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        assert head_mask is None, "head_mask is not supported for windowed attention"
        assert not output_attentions, "output_attentions is not supported for windowed attention"
        shortcut = hidden_states
        if run_full_attention:
            # reshape x to remove windows
            B, HW, C = hidden_states.shape
            num_windows_squared = self.num_windows**2
            hidden_states = hidden_states.view(B // num_windows_squared, num_windows_squared * HW, C)

        self_attention_outputs = self.attention(
            self.norm1(hidden_states),  # in Dinov2WithRegisters, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]

        if run_full_attention:
            # reshape x to add windows back
            B, HW, C = hidden_states.shape
            num_windows_squared = self.num_windows**2
            # hidden_states = hidden_states.view(B * num_windows_squared, HW // num_windows_squared, C)
            attention_output = attention_output.view(B * num_windows_squared, HW // num_windows_squared, C)

        attention_output = self.layer_scale1(attention_output)
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = self.drop_path(attention_output) + shortcut

        # in Dinov2WithRegisters, layernorm is also applied after self-attention
        layer_output = self.norm2(hidden_states)
        layer_output = self.mlp(layer_output)
        layer_output = self.layer_scale2(layer_output)

        # second residual connection
        layer_output = self.drop_path(layer_output) + hidden_states

        outputs = (layer_output,) + outputs

        return outputs

That's why I think we necessarily need a custom Backbone class for that 🤔

@qubvel
Member

qubvel commented Mar 24, 2025

Hmm, am I correct that this part was added?

if run_full_attention:
    # reshape x to add windows back
    B, HW, C = hidden_states.shape
    num_windows_squared = self.num_windows**2
    # hidden_states = hidden_states.view(B * num_windows_squared, HW // num_windows_squared, C)
    attention_output = attention_output.view(B * num_windows_squared, HW // num_windows_squared, C)

It looks like it is a reshape-only operation; we can return attention_output as is and reshape all layers' outputs later, right?
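
Something along these lines, as a rough sketch; attribute names like num_windows and full_attention_layer_ids are placeholders, and I haven't checked it against the actual shapes:

import torch
from torch import nn


def encoder_forward(
    layers: nn.ModuleList,
    hidden_states: torch.Tensor,
    num_windows: int,
    full_attention_layer_ids: set,
) -> torch.Tensor:
    # Sketch: keep the layer untouched and do the window merge/split around it.
    num_windows_squared = num_windows**2
    for i, layer in enumerate(layers):
        run_full_attention = i in full_attention_layer_ids
        if run_full_attention:
            # merge the windows into one long sequence before the layer
            batch, seq_len, channels = hidden_states.shape
            hidden_states = hidden_states.view(batch // num_windows_squared, num_windows_squared * seq_len, channels)
        hidden_states = layer(hidden_states)[0]
        if run_full_attention:
            # split the sequence back into windows after the layer
            batch, seq_len, channels = hidden_states.shape
            hidden_states = hidden_states.view(batch * num_windows_squared, seq_len // num_windows_squared, channels)
    return hidden_states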

@sbucaille
Contributor Author

You are right, but it is not the only example. I'll stick to my original plan until I have something running with actual results, then take care of refactoring this part later; I'll ping you when it's ready.

@sbucaille
Contributor Author

Hey @qubvel, in the end I made the modeling files follow the rt_detr folder structure, with modeling_rf_detr_dinov2_with_registers.py playing the role of modeling_rt_detr_resnet.py (where the backbone is defined) and modeling_rf_detr.py playing the role of modeling_rt_detr.py (where the encoder/decoder is defined, usable with any backbone).
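
Roughly, the layout would look like this (folder name assumed, configuration files omitted):

src/transformers/models/rf_detr/
    modeling_rf_detr_dinov2_with_registers.py   # backbone, analogous to modeling_rt_detr_resnet.py
    modeling_rf_detr.py                         # encoder/decoder, analogous to modeling_rt_detr.py
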
What do you think? I'll continue tonight.

Also, I had issues with the modular mechanism: utils/modular_model_converter.py always enforced the name RfDetr... instead of RFDetr..., whereas I wanted to follow the naming convention used for RTDetr. So I ended up using the Copied from mechanism.

@qubvel
Member

qubvel commented Mar 25, 2025

Hey, let's use the RfDetr name + modular, it's ok! RfDetr is the correct naming format, while RTDetr is an exception made before modular was introduced.

@sbucaille
Contributor Author

Ok, sorry, I mixed up the problems I had. I didn't have a problem with the capital letters of Rf or RF, but rather a problem with the prefix RfDetr used in a modular file together with DeformableDetr.
Using:

class RfDetrModel(DeformableDetrModel):
    pass

generates a bunch of Rf... classes like RfConvEncoder instead of RfDetrConvEncoder. I suppose RfDetr and DeformableDetr sharing Detr in their names is what makes the modular script fail; I naturally don't hit this case when using RfDetrDinov2WithRegisters(Dinov2WithRegisters). The way I can avoid this is by forcing the naming of these classes and overriding the __init__ method like this:

class RfDetrConvEncoder(DeformableDetrConvEncoder):
    pass

class RfDetrModel(DeformableDetrModel):
    def __init__(self, config: RfDetrConfig):
        super().__init__(config)

        backbone = RfDetrConvEncoder(config)
        ...

But the problem also appears for the ModelOutputs, so I'm forced to rewrite the whole forward methods for many classes, which makes using modular a bit useless in my opinion...
So for modeling_rf_detr_dinov2_with_registers.py I can keep a modular file, but not for modeling_rf_detr.py; I think I'll need to use the Copied from mechanism.
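
Concretely, the fallback for modeling_rf_detr.py would look something like this sketch; the target path points at the existing DeformableDetr class, and the body (elided here) would be kept in sync with make fix-copies:

from torch import nn


# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrConvEncoder with DeformableDetr->RfDetr
class RfDetrConvEncoder(nn.Module):
    # body copied verbatim from DeformableDetrConvEncoder
    ...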

Should I open an issue? Maybe @ArthurZucker has some insights on this problem?

@qubvel
Member

qubvel commented Mar 25, 2025

cc @Cyrilvallez re: modular, you faced something similar
