Skip to content

Commit 5b677e3

Browse files
authored
Add top 3 HF Presets for Mobilenet (#2105)
* test and preset fixes * add version_number to kaggle_handle * add top3 HF presets * fix the new arguments * fix new arguments to other test * update/add mobilenet presets * update model nomenclature * define __init__.py for mobilenet, further fix nomenclature * remove extra line * update expected output (batch size mismatch) on test * classifier definition slight refactor * include more specific condition
1 parent c9b2ee1 commit 5b677e3

9 files changed

+202
-50
lines changed

Diff for: keras_hub/src/models/mobilenet/mobilenet_backbone.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ class DepthwiseConvBlock(keras.layers.Layer):
142142
signal into before re-exciting back out. If (>1) technically, it's an
143143
excite & squeeze layer. If this doesn't exist there is no
144144
SqueezeExcite layer.
145+
residual: bool, default False. True if we want a residual connection. If
146+
False, there is no residual connection.
145147
name: str, name of the layer
146148
dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
147149
to use for the model's computations and weights.
@@ -161,6 +163,7 @@ def __init__(
161163
kernel_size=3,
162164
stride=2,
163165
squeeze_excite_ratio=None,
166+
residual=False,
164167
name=None,
165168
dtype=None,
166169
**kwargs,
@@ -171,6 +174,7 @@ def __init__(
171174
self.kernel_size = kernel_size
172175
self.stride = stride
173176
self.squeeze_excite_ratio = squeeze_excite_ratio
177+
self.residual = residual
174178
self.name = name
175179

176180
channel_axis = (
@@ -256,11 +260,15 @@ def call(self, inputs):
256260
x = self.batch_normalization1(x)
257261
x = self.activation1(x)
258262

259-
if self.se_layer:
263+
if self.squeeze_excite_ratio:
260264
x = self.se_layer(x)
261265

262266
x = self.conv2(x)
263267
x = self.batch_normalization2(x)
268+
269+
if self.residual:
270+
x = x + inputs
271+
264272
return x
265273

266274
def get_config(self):
@@ -272,6 +280,7 @@ def get_config(self):
272280
"kernel_size": self.kernel_size,
273281
"stride": self.stride,
274282
"squeeze_excite_ratio": self.squeeze_excite_ratio,
283+
"residual": self.residual,
275284
"name": self.name,
276285
}
277286
)
@@ -675,6 +684,8 @@ def __init__(
675684
stackwise_padding,
676685
output_num_filters,
677686
depthwise_filters,
687+
depthwise_stride,
688+
depthwise_residual,
678689
last_layer_filter,
679690
squeeze_and_excite=None,
680691
image_shape=(None, None, 3),
@@ -722,7 +733,9 @@ def __init__(
722733
x = DepthwiseConvBlock(
723734
input_num_filters,
724735
depthwise_filters,
736+
stride=depthwise_stride,
725737
squeeze_excite_ratio=squeeze_and_excite,
738+
residual=depthwise_residual,
726739
name="block_0",
727740
dtype=dtype,
728741
)(x)
@@ -768,6 +781,8 @@ def __init__(
768781
self.input_num_filters = input_num_filters
769782
self.output_num_filters = output_num_filters
770783
self.depthwise_filters = depthwise_filters
784+
self.depthwise_stride = depthwise_stride
785+
self.depthwise_residual = depthwise_residual
771786
self.last_layer_filter = last_layer_filter
772787
self.squeeze_and_excite = squeeze_and_excite
773788
self.input_activation = input_activation
@@ -790,6 +805,8 @@ def get_config(self):
790805
"input_num_filters": self.input_num_filters,
791806
"output_num_filters": self.output_num_filters,
792807
"depthwise_filters": self.depthwise_filters,
808+
"depthwise_stride": self.depthwise_stride,
809+
"depthwise_residual": self.depthwise_residual,
793810
"last_layer_filter": self.last_layer_filter,
794811
"squeeze_and_excite": self.squeeze_and_excite,
795812
"input_activation": self.input_activation,

Diff for: keras_hub/src/models/mobilenet/mobilenet_backbone_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ def setUp(self):
5353
"input_num_filters": 16,
5454
"image_shape": (32, 32, 3),
5555
"depthwise_filters": 8,
56+
"depthwise_stride": 2,
57+
"depthwise_residual": False,
5658
"squeeze_and_excite": 0.5,
5759
"last_layer_filter": 288,
5860
}

Diff for: keras_hub/src/models/mobilenet/mobilenet_image_classifier.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __init__(
1818
self,
1919
backbone,
2020
num_classes,
21+
num_features=1024,
2122
preprocessor=None,
2223
head_dtype=None,
2324
**kwargs,
@@ -33,7 +34,7 @@ def __init__(
3334
)
3435

3536
self.output_conv = keras.layers.Conv2D(
36-
filters=1024,
37+
filters=num_features,
3738
kernel_size=(1, 1),
3839
strides=(1, 1),
3940
use_bias=True,
@@ -69,13 +70,15 @@ def __init__(
6970

7071
# === Config ===
7172
self.num_classes = num_classes
73+
self.num_features = num_features
7274

7375
def get_config(self):
7476
# Skip ImageClassifier
7577
config = Task.get_config(self)
7678
config.update(
7779
{
7880
"num_classes": self.num_classes,
81+
"num_features": self.num_features,
7982
}
8083
)
8184
return config

Diff for: keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def setUp(self):
5454
input_num_filters=16,
5555
image_shape=(32, 32, 3),
5656
depthwise_filters=8,
57+
depthwise_stride=2,
58+
depthwise_residual=False,
5759
squeeze_and_excite=0.5,
5860
last_layer_filter=288,
5961
)

Diff for: keras_hub/src/models/mobilenet/mobilenet_presets.py

+38-2
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,48 @@
44
"mobilenet_v3_small_050_imagenet": {
55
"metadata": {
66
"description": (
7-
"Small MobileNet V3 model pre-trained on the ImageNet 1k "
8-
"dataset at a 224x224 resolution."
7+
"Small Mobilenet V3 model pre-trained on the ImageNet 1k "
8+
"dataset at a 224x224 resolution. Has half channel multiplier."
99
),
1010
"params": 278784,
1111
"path": "mobilenetv3",
1212
},
1313
"kaggle_handle": "kaggle://keras/mobilenetv3/keras/mobilenet_v3_small_050_imagenet/1",
1414
},
15+
"mobilenet_v3_small_100_imagenet": {
16+
"metadata": {
17+
"description": (
18+
"Small Mobilenet V3 model pre-trained on the ImageNet 1k "
19+
"dataset at a 224x224 resolution. Has baseline channel "
20+
"multiplier."
21+
),
22+
"params": 939120,
23+
"path": "mobilenetv3",
24+
},
25+
"kaggle_handle": "kaggle://keras/mobilenetv3/keras/mobilenet_v3_small_100_imagenet/1",
26+
},
27+
"mobilenet_v3_large_100_imagenet": {
28+
"metadata": {
29+
"description": (
30+
"Large Mobilenet V3 model pre-trained on the ImageNet 1k "
31+
"dataset at a 224x224 resolution. Has baseline channel "
32+
"multiplier."
33+
),
34+
"params": 2996352,
35+
"path": "mobilenetv3",
36+
},
37+
"kaggle_handle": "kaggle://keras/mobilenetv3/keras/mobilenet_v3_large_100_imagenet/1",
38+
},
39+
"mobilenet_v3_large_100_imagenet_21k": {
40+
"metadata": {
41+
"description": (
42+
"Large Mobilenet V3 model pre-trained on the ImageNet 21k "
43+
"dataset at a 224x224 resolution. Has baseline channel "
44+
"multiplier."
45+
),
46+
"params": 2996352,
47+
"path": "mobilenetv3",
48+
},
49+
"kaggle_handle": "kaggle://keras/mobilenetv3/keras/mobilenet_v3_large_100_imagenet_21k/1",
50+
},
1551
}

Diff for: keras_hub/src/utils/preset_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
622622
kwargs["preprocessor"] = self.load_preprocessor(
623623
cls.preprocessor_cls,
624624
)
625+
625626
return cls(**kwargs)
626627

627628
def load_preprocessor(

Diff for: keras_hub/src/utils/timm/convert_mobilenet.py

+120-44
Original file line numberDiff line numberDiff line change
@@ -8,64 +8,135 @@
88
def convert_backbone_config(timm_config):
99
timm_architecture = timm_config["architecture"]
1010

11-
if "mobilenetv3_" in timm_architecture:
12-
input_activation = "hard_swish"
13-
output_activation = "hard_swish"
14-
else:
15-
input_activation = "relu6"
16-
output_activation = "relu6"
17-
18-
if timm_architecture == "mobilenetv3_small_050":
19-
stackwise_num_blocks = [2, 3, 2, 3]
20-
stackwise_expansion = [
11+
kwargs = {
12+
"stackwise_num_blocks": [2, 3, 2, 3],
13+
"stackwise_expansion": [
2114
[40, 56],
2215
[64, 144, 144],
2316
[72, 72],
2417
[144, 288, 288],
25-
]
26-
stackwise_num_filters = [[16, 16], [24, 24, 24], [24, 24], [48, 48, 48]]
27-
stackwise_kernel_size = [[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]]
28-
stackwise_num_strides = [[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]]
29-
stackwise_se_ratio = [
18+
],
19+
"stackwise_num_filters": [
20+
[16, 16],
21+
[24, 24, 24],
22+
[24, 24],
23+
[48, 48, 48],
24+
],
25+
"stackwise_kernel_size": [[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]],
26+
"stackwise_num_strides": [[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]],
27+
"stackwise_se_ratio": [
3028
[None, None],
3129
[0.25, 0.25, 0.25],
3230
[0.25, 0.25],
3331
[0.25, 0.25, 0.25],
34-
]
35-
stackwise_activation = [
32+
],
33+
"stackwise_activation": [
3634
["relu", "relu"],
3735
["hard_swish", "hard_swish", "hard_swish"],
3836
["hard_swish", "hard_swish"],
3937
["hard_swish", "hard_swish", "hard_swish"],
40-
]
41-
stackwise_padding = [[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]]
42-
output_num_filters = 1024
43-
input_num_filters = 16
44-
depthwise_filters = 8
45-
squeeze_and_excite = 0.5
46-
last_layer_filter = 288
38+
],
39+
"stackwise_padding": [[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]],
40+
"output_num_filters": 1024,
41+
"input_num_filters": 16,
42+
"depthwise_filters": 8,
43+
"depthwise_stride": 2,
44+
"depthwise_residual": False,
45+
"squeeze_and_excite": 0.5,
46+
"last_layer_filter": 288,
47+
"input_activation": "relu6",
48+
"output_activation": "relu6",
49+
}
50+
51+
if "mobilenetv3_" in timm_architecture:
52+
kwargs["input_activation"] = "hard_swish"
53+
kwargs["output_activation"] = "hard_swish"
54+
55+
if timm_architecture == "mobilenetv3_small_050":
56+
pass
57+
elif timm_architecture == "mobilenetv3_small_100":
58+
modified_kwargs = {
59+
"stackwise_expansion": [
60+
[72, 88],
61+
[96, 240, 240],
62+
[120, 144],
63+
[288, 576, 576],
64+
],
65+
"stackwise_num_filters": [
66+
[24, 24],
67+
[40, 40, 40],
68+
[48, 48],
69+
[96, 96, 96],
70+
],
71+
"depthwise_filters": 16,
72+
"last_layer_filter": 576,
73+
}
74+
kwargs.update(modified_kwargs)
75+
elif timm_architecture.startswith("mobilenetv3_large_100"):
76+
modified_kwargs = {
77+
"stackwise_num_blocks": [2, 3, 4, 2, 3],
78+
"stackwise_expansion": [
79+
[64, 72],
80+
[72, 120, 120],
81+
[240, 200, 184, 184],
82+
[480, 672],
83+
[672, 960, 960],
84+
],
85+
"stackwise_num_filters": [
86+
[24, 24],
87+
[40, 40, 40],
88+
[80, 80, 80, 80],
89+
[112, 112],
90+
[160, 160, 160],
91+
],
92+
"stackwise_kernel_size": [
93+
[3, 3],
94+
[5, 5, 5],
95+
[3, 3, 3, 3],
96+
[3, 3],
97+
[5, 5, 5],
98+
],
99+
"stackwise_num_strides": [
100+
[2, 1],
101+
[2, 1, 1],
102+
[2, 1, 1, 1],
103+
[1, 1],
104+
[2, 1, 1],
105+
],
106+
"stackwise_se_ratio": [
107+
[None, None],
108+
[0.25, 0.25, 0.25],
109+
[None, None, None, None],
110+
[0.25, 0.25],
111+
[0.25, 0.25, 0.25],
112+
],
113+
"stackwise_activation": [
114+
["relu", "relu"],
115+
["relu", "relu", "relu"],
116+
["hard_swish", "hard_swish", "hard_swish", "hard_swish"],
117+
["hard_swish", "hard_swish"],
118+
["hard_swish", "hard_swish", "hard_swish"],
119+
],
120+
"stackwise_padding": [
121+
[1, 1],
122+
[2, 2, 2],
123+
[1, 1, 1, 1],
124+
[1, 1],
125+
[2, 2, 2],
126+
],
127+
"depthwise_filters": 16,
128+
"depthwise_stride": 1,
129+
"depthwise_residual": True,
130+
"squeeze_and_excite": None,
131+
"last_layer_filter": 960,
132+
}
133+
kwargs.update(modified_kwargs)
47134
else:
48135
raise ValueError(
49136
f"Currently, the architecture {timm_architecture} is not supported."
50137
)
51138

52-
return dict(
53-
input_num_filters=input_num_filters,
54-
input_activation=input_activation,
55-
depthwise_filters=depthwise_filters,
56-
squeeze_and_excite=squeeze_and_excite,
57-
stackwise_num_blocks=stackwise_num_blocks,
58-
stackwise_expansion=stackwise_expansion,
59-
stackwise_num_filters=stackwise_num_filters,
60-
stackwise_kernel_size=stackwise_kernel_size,
61-
stackwise_num_strides=stackwise_num_strides,
62-
stackwise_se_ratio=stackwise_se_ratio,
63-
stackwise_activation=stackwise_activation,
64-
stackwise_padding=stackwise_padding,
65-
output_num_filters=output_num_filters,
66-
output_activation=output_activation,
67-
last_layer_filter=last_layer_filter,
68-
)
139+
return kwargs
69140

70141

71142
def convert_weights(backbone, loader, timm_config):
@@ -120,9 +191,14 @@ def port_batch_normalization(keras_layer, hf_weight_prefix):
120191
port_conv2d(stem_block.conv1, f"{hf_name}.conv_dw")
121192
port_batch_normalization(stem_block.batch_normalization1, f"{hf_name}.bn1")
122193

123-
stem_se_block = stem_block.se_layer
124-
port_conv2d(stem_se_block.conv_reduce, f"{hf_name}.se.conv_reduce", True)
125-
port_conv2d(stem_se_block.conv_expand, f"{hf_name}.se.conv_expand", True)
194+
if stem_block.squeeze_excite_ratio:
195+
stem_se_block = stem_block.se_layer
196+
port_conv2d(
197+
stem_se_block.conv_reduce, f"{hf_name}.se.conv_reduce", True
198+
)
199+
port_conv2d(
200+
stem_se_block.conv_expand, f"{hf_name}.se.conv_expand", True
201+
)
126202

127203
port_conv2d(stem_block.conv2, f"{hf_name}.conv_pw")
128204
port_batch_normalization(stem_block.batch_normalization2, f"{hf_name}.bn2")

0 commit comments

Comments
 (0)