Add top 3 HF Presets for Mobilenet #2105

Merged
merged 16 commits into keras-team:master on Mar 7, 2025

Conversation

pkgoogle (Collaborator)

Adds these presets:

    "mobilenetv3_small_100": "timm/mobilenetv3_small_100.lamb_in1k",
    "mobilenetv3_large_100.ra_in1k": "timm/mobilenetv3_large_100.ra_in1k",
    "mobilenetv3_large_100.miil_in21k_ft_in1k": "timm/mobilenetv3_large_100.miil_in21k_ft_in1k"

We can adjust the nomenclature, but this currently works.
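
(As a usage sketch, matching the "hf://" from_preset pattern that appears later in this thread; the handles are the timm names above:)

    import keras_hub

    # Load one of the new MobileNetV3 presets straight from Hugging Face
    # via the timm converter path.
    classifier = keras_hub.models.ImageClassifier.from_preset(
        "hf://timm/mobilenetv3_small_100.lamb_in1k"
    )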

@pkgoogle pkgoogle added the kokoro:force-run Runs Tests on GPU label Feb 19, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Feb 19, 2025
@pkgoogle pkgoogle added the kokoro:force-run Runs Tests on GPU label Feb 20, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Feb 20, 2025
@divyashreepathihalli divyashreepathihalli added the kokoro:force-run Runs Tests on GPU label Feb 21, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Feb 21, 2025
@divyashreepathihalli (Collaborator)

Looks good! Can you also please upload the presets to Kaggle?

@pkgoogle pkgoogle added the kokoro:force-run Runs Tests on GPU label Feb 21, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Feb 21, 2025
@pkgoogle pkgoogle added the kokoro:force-run Runs Tests on GPU label Feb 21, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Feb 21, 2025
@pkgoogle (Collaborator, Author)

The remaining test failure appears to be a timeout/network issue.

@pkgoogle pkgoogle added the kokoro:force-run Runs Tests on GPU label Feb 24, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Feb 24, 2025
@mattdangerw (Member)

I think these test failures are legitimate! Can you take a look? Exception encountered: __init__() missing 2 required positional arguments: 'depthwise_stride' and 'depthwise_residual' -> these are not just timeouts.

@mattdangerw mattdangerw self-requested a review February 24, 2025 20:03
@pkgoogle (Collaborator, Author)

No worries, my original comment was about the previous CI run; this one is indeed legitimate (it currently seems to be picking up an old task.json from Kaggle and is thus missing the new args needed to support these latest presets).

@pkgoogle pkgoogle added the kokoro:force-run Runs Tests on GPU label Feb 24, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Feb 24, 2025
@divyashreepathihalli (Collaborator)

The tests are failing because of #2108.

@mattdangerw mattdangerw added the kokoro:force-run Runs Tests on GPU label Feb 25, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Feb 25, 2025
@@ -622,6 +622,9 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
kwargs["preprocessor"] = self.load_preprocessor(
cls.preprocessor_cls,
)
if "num_features" not in kwargs and "num_features" in self.config:
Member
Sorry I missed this part; we should remove these lines. This is trying to take num_features from the backbone config and proxy it to the classifier? Why are we doing this?

We should remove these lines for sure, but I'm not sure what to suggest instead, as I don't really understand the purpose yet.

@pkgoogle (Collaborator, Author), Feb 25, 2025

I think it's actually related to the current Keras saving issue. Basically, the new presets have a different configuration for the task classifier head, essentially a different intermediate hidden dim.

So this needs to be defined when creating the Keras model from the timm config. This appeared to be the cleanest way to handle it -- another option I considered was passing in the keyword like this:

keras_model = keras_hub.models.ImageClassifier.from_preset(
    "hf://" + timm_name, num_features=num_features
)

But then the user has to know to do this. Is there a better place to handle loading arbitrary task configuration for different presets?

Just to add more data, num_features is actually not a backbone config, it's a model config from timm: https://huggingface.co/timm/mobilenetv3_large_100.miil_in21k/blob/main/config.json.

Member

Got it, thanks for the explainer!

We definitely don't want this in the base PresetLoader: it would apply to all models from all sources, and num_features is a very generic name, so that's a bug waiting to happen. We could add this to the TimmPresetLoader instead, as it sounds like this is timm specific.

But I wonder if there's more to figure out here... ResNet has this same num_features field in its timm config, but we aren't handling it that I can see. Does this change break ResNet? Do we want an update to the base ImageClassifier class to handle this type of post-pooling convolution? Let's make sure we are doing this in a way that is consistent across our image classification models.

@pkgoogle (Collaborator, Author)

Yeah, I can explain why this isn't popping up anywhere else right now: as far as I know, no other model changes the classification head's configuration beyond the default, so the default is always picked up. It's actually not read dynamically from the config currently (I think), just applied from the default. The TimmPresetLoader solution will work for timm models, and probably for this PR, but I think there's currently no clear way to change image classification heads from the config. Maybe there is a different path for non-timm models?

I think updating the base class won't solve the issue, as every classifier head is a little different. We seem to be handling this fine via subclassing, though we do have a weird inheritance pattern (calling Task.__init__ directly; as far as I can tell, that's because of how we define the functional model and the order of calling the super constructors).
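
(A rough sketch of the subclassing pattern being described, not the repo's actual code; the constructor signature, pooling choice, and omitted activation are assumptions, while the layer names classifier_conv and predictions come from the error log later in this thread:)

    import keras
    from keras_hub.src.models.image_classifier import ImageClassifier
    from keras_hub.src.models.task import Task

    class MobileNetImageClassifier(ImageClassifier):
        def __init__(self, backbone, num_classes, num_features=1024, **kwargs):
            # Build the functional graph first: pool -> conv -> dense head.
            inputs = backbone.input
            x = backbone(inputs)
            x = keras.layers.GlobalAveragePooling2D(keepdims=True)(x)
            # Post-pooling conv whose width is the timm num_features value.
            x = keras.layers.Conv2D(num_features, 1, name="classifier_conv")(x)
            x = keras.layers.Flatten()(x)
            outputs = keras.layers.Dense(num_classes, name="predictions")(x)
            # Call Task.__init__ directly, skipping the intermediate super
            # constructors, so the functional model is defined up front.
            Task.__init__(self, inputs=inputs, outputs=outputs, **kwargs)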

@pkgoogle (Collaborator, Author), Feb 26, 2025

So the current CI will still fail for the same reason as before, but I did make this update.

Current issue:

_________________________________________________________________ MobileNetImageClassifierTest.test_all_presets __________________________________________________________________

self = <keras_hub.src.models.mobilenet.mobilenet_image_classifier_test.MobileNetImageClassifierTest testMethod=test_all_presets>

    @pytest.mark.large
    def test_all_presets(self):
        for preset in MobileNetImageClassifier.presets:
>           self.run_preset_test(
                cls=MobileNetImageClassifier,
                preset=preset,
                input_data=self.images,
                expected_output_shape=(2, 1000),
            )

keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py:90: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
keras_hub/src/tests/test_case.py:647: in run_preset_test
    instance = cls.from_preset(preset, **init_kwargs)
keras_hub/src/models/task.py:198: in from_preset
    return loader.load_task(cls, load_weights, load_task_weights, **kwargs)
keras_hub/src/utils/preset_utils.py:696: in load_task
    task.load_task_weights(task_weights)
keras_hub/src/models/task.py:208: in load_task_weights
    keras.saving.load_weights(
../../miniforge3/envs/keras_migration/lib/python3.10/site-packages/keras/src/saving/saving_api.py:251: in load_weights
    saving_lib.load_weights_only(
../../miniforge3/envs/keras_migration/lib/python3.10/site-packages/keras/src/saving/saving_lib.py:611: in load_weights_only
    _raise_loading_failure(error_msgs, warn_only=skip_mismatch)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

error_msgs = {6045933600: (<Dense name=predictions, built=True>, ValueError("Layer 'predictions' expected 2 variables, but received...Layer 'classifier_conv' expected 2 variables, but received 0 variables during loading. Expected: ['kernel', 'bias']"))}
warn_only = False

    def _raise_loading_failure(error_msgs, warn_only=False):
        first_key = list(error_msgs.keys())[0]
        ex_saveable, ex_error = error_msgs[first_key]
        msg = (
            f"A total of {len(error_msgs)} objects could not "
            "be loaded. Example error message for "
            f"object {ex_saveable}:\n\n"
            f"{ex_error}\n\n"
            "List of objects that could not be loaded:\n"
            f"{[x[0] for x in error_msgs.values()]}"
        )
        if warn_only:
            warnings.warn(msg)
        else:
>           raise ValueError(msg)
E           ValueError: A total of 2 objects could not be loaded. Example error message for object <Conv2D name=classifier_conv, built=True>:
E           
E           Layer 'classifier_conv' expected 2 variables, but received 0 variables during loading. Expected: ['kernel', 'bias']
E           
E           List of objects that could not be loaded:
E           [<Conv2D name=classifier_conv, built=True>, <Dense name=predictions, built=True>]

It seems like Keras doesn't save the task weights correctly, and it fails to find the right weights on loading... any idea how to dig into this? I also want to note: this only happens for the "large" presets; the "small" ones still work fine.

@mattdangerw (Member), Feb 28, 2025

I'm not sure, but none of the saved task.json files have a num_features key. So here's what I'd guess:

  • We have not resaved the entire classifier with these changes and uploaded it: classifier.save_to_preset(path).
  • Because earlier versions of this model did not have a num_features key, we load a task with the default num_features=1024; it expects some post-pooling conv weights, but they are not there, because these files were saved before this compat-breaking change.

But before we get too far on this: is there a reason we don't have this same conv for, say, ResNet-50? I see num_features: 2048 in the timm config for that model. Basically, I'm wondering if this is something we should fix for all our image classifier models, in which case we should add support for it in the base ImageClassifier. If not, we can add special casing to mobilenet.
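
(A hedged sketch of the resave-and-upload flow from the first bullet; the Kaggle handle below is hypothetical, and keras_hub.upload_preset is assumed as the upload entry point:)

    import keras_hub

    # Rebuild the classifier so the saved task.json picks up the new
    # num_features key, then resave the whole task and upload it.
    classifier = keras_hub.models.ImageClassifier.from_preset(
        "hf://timm/mobilenetv3_large_100.ra_in1k"
    )
    classifier.save_to_preset("./mobilenetv3_large_100_ra_in1k")

    # Hypothetical handle; the real org/model path may differ.
    keras_hub.upload_preset(
        "kaggle://keras/mobilenetv3/keras/mobilenetv3_large_100_ra_in1k",
        "./mobilenetv3_large_100_ra_in1k",
    )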

@pkgoogle (Collaborator, Author)

Hmm... I'm not sure, but I can answer that question at least. Digging into the timm implementation, https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/classifier.py#L47 is how they define the classifier head that ResNet uses; num_features is usually the input features to the pooling layer, which is implicit in Keras. For the MobileNet classifier it's used multiple times (https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/mobilenetv3.py#L120); the way I'm currently using it best matches num_pooled_chs = self.num_features * self.global_pool.feat_mult(). The feature multiplier is usually just 1: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/adaptive_avgmax_pool.py#L23. So one possible solution is to just grab it from the backbone instead of the config; that would match the rest of the implementations better, though I do sort of feel like there should be some way to configure the head from the config.

@pkgoogle (Collaborator, Author)

Actually, I can't get it from the backbone: MobileNetV3 uses num_features differently (self.head_hidden_size = num_features # features of conv_head, pre_logits output). The hidden size is grabbed from the argument (which matches the config), not from the redefined class variable self.num_features, which is indeed the input channels for the pooling layer. So I think I will see if I can solve the saving/loading issue for now.
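
(Putting illustrative numbers on this, assumed from the timm configs rather than verified here: mobilenetv3_small_100 ships num_features=1024, which happens to match the default num_features=1024 mentioned above, while the large variants ship 1280. That would line up with only the "large" presets failing:)

    # Illustrative values (assumed from the timm configs, not verified here).
    small_num_features = 1024    # mobilenetv3_small_100 conv_head width
    large_num_features = 1280    # mobilenetv3_large_100 conv_head width
    default_num_features = 1024  # keras-hub default per the discussion above

    # Small presets match the default, so they load fine; large presets
    # need num_features forwarded from the config.
    assert small_num_features == default_num_features
    assert large_num_features != default_num_features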

@pkgoogle (Collaborator, Author)

Yup, the smaller issue was an outdated task.json; all presets are working now, at least locally. Thanks!

@pkgoogle pkgoogle added the kokoro:force-run Runs Tests on GPU label Feb 28, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Feb 28, 2025
@mattdangerw (Member)

Sounds good! Let's go with this. It does look like timm allows the final projection to be a conv instead of a dense; we could consider adding something like that at some point. But since mobilenet is pool -> conv -> dense, that seems worth keeping as its own subclass. Thanks!

@@ -51,6 +51,9 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
)
# Support loading the classification head for classifier models.
kwargs["num_classes"] = self.config["num_classes"]
if "num_features" in self.config:
Member

Oops, sorry, one more issue: this will forward num_features to all classes, even classifiers that don't support the kwarg. Maybe switch this to if "num_features" in self.config and "mobilenet" in self.config["architecture"]:
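
(In context, the guarded version of the hunk above would presumably read:)

    # Support loading the classification head for classifier models.
    kwargs["num_classes"] = self.config["num_classes"]
    # num_features only applies to the MobileNet classifier head, so gate
    # on the timm architecture string before forwarding the kwarg.
    if (
        "num_features" in self.config
        and "mobilenet" in self.config["architecture"]
    ):
        kwargs["num_features"] = self.config["num_features"]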

@pkgoogle (Collaborator, Author)

done

@pkgoogle pkgoogle added the kokoro:force-run Runs Tests on GPU label Mar 5, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Mar 5, 2025
@mattdangerw mattdangerw merged commit 5b677e3 into keras-team:master Mar 7, 2025
10 checks passed