diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py b/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py index 6447ca2fc5..6c524dd214 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py @@ -96,7 +96,7 @@ class PaliGemmaBackbone(Backbone): } # Pretrained PaliGemma decoder. - model = keras_hub.models.PaliGemmaBackbone.from_preset("pali_gemma_mix_224") + model = keras_hub.models.PaliGemmaBackbone.from_preset("pali_gemma_3b_mix_224") model(input_data) # Randomly initialized PaliGemma decoder with custom config. diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_presets.py b/keras_hub/src/models/pali_gemma/pali_gemma_presets.py index ffcf3ecafd..abf4ec0d7c 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_presets.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_presets.py @@ -1,5 +1,7 @@ """PaliGemma model preset configurations.""" +import re + # Metadata for loading pretrained model weights. backbone_presets = { "pali_gemma_3b_mix_224": { @@ -53,7 +55,7 @@ "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_896/3", }, # PaliGemma2 - "pali_gemma2_3b_ft_docci_448": { + "pali_gemma2_ft_docci_3b_448": { "metadata": { "description": ( "3 billion parameter, image size 448, 27-layer for " @@ -66,9 +68,10 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_3b_ft_docci_448/1", + # TODO: Rename `pali_gemma_2_ft_docci_3b_448` to `pali_gemma2_ft_docci_3b_448` + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma_2_ft_docci_3b_448/1", }, - "pali_gemma2_10b_ft_docci_448": { + "pali_gemma2_ft_docci_10b_448": { "metadata": { "description": ( "10 billion parameter, 27-layer for SigLIP-So400m vision " @@ -81,7 +84,7 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_ft_docci_448/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_ft_docci_10b_448/1", }, "pali_gemma2_pt_3b_224": { "metadata": { @@ -219,3 +222,34 @@ "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_896/1", }, } + +# Ensure compatibility with the official naming convention. +# pali_gemma2_[3|10|28b]_[variant]_[image_size] +compatible_preset_names = [] +for preset_name in backbone_presets.keys(): + if re.match(r"pali_gemma2_(.+)_(.+)_(.+)_(.+)", preset_name): + # Ex: pali_gemma2_ft_docci_3b_448 -> pali_gemma2_3b_ft_docci_448 + compatible_preset_names.append( + ( + re.sub( + r"pali_gemma2_(.+)_(.+)_(.+)_(.+)", + r"pali_gemma2_\3_\1_\2_\4", + preset_name, + ), + preset_name, + ) + ) + elif re.match(r"pali_gemma2_(.+)_(.+)_(.+)", preset_name): + # Ex: pali_gemma2_pt_3b_224 -> pali_gemma2_3b_pt_224 + compatible_preset_names.append( + ( + re.sub( + r"pali_gemma2_(.+)_(.+)_(.+)", + r"pali_gemma2_\2_\1_\3", + preset_name, + ), + preset_name, + ) + ) +for compatible_preset_name, preset_name in compatible_preset_names: + backbone_presets[compatible_preset_name] = backbone_presets[preset_name]