@@ -117,7 +117,7 @@ def GlobalAverageMaxPooling2D(previous_layer, name=None):
         keras.layers.GlobalMaxPooling2D(name=name + '_max')(previous_layer)
     ])

-def FitChannelCountTo(last_tensor, next_channel_count, channel_axis=3):
+def FitChannelCountTo(last_tensor, next_channel_count, has_interleaving=False, channel_axis=3):
     prev_layer_channel_count = keras.backend.int_shape(last_tensor)[channel_axis]
     full_copies = next_channel_count // prev_layer_channel_count
     extra_channels = next_channel_count % prev_layer_channel_count
@@ -126,9 +126,15 @@ def FitChannelCountTo(last_tensor, next_channel_count, channel_axis=3):
         if copy_cnt == 0:
             output_copies.append(last_tensor)
         else:
-            output_copies.append(InterleaveChannels(step_size=((copy_cnt + 1) % prev_layer_channel_count))(last_tensor))
+            if has_interleaving:
+                output_copies.append(InterleaveChannels(step_size=((copy_cnt + 1) % prev_layer_channel_count))(last_tensor))
+            else:
+                output_copies.append(last_tensor)
     if (extra_channels > 0):
-        extra_tensor = InterleaveChannels(step_size=((full_copies + 1) % prev_layer_channel_count))(last_tensor)
+        if has_interleaving:
+            extra_tensor = InterleaveChannels(step_size=((full_copies + 1) % prev_layer_channel_count))(last_tensor)
+        else:
+            extra_tensor = last_tensor
         output_copies.append(CopyChannels(0, extra_channels)(extra_tensor))
     last_tensor = keras.layers.Concatenate(axis=channel_axis)(output_copies)
     return last_tensor
@@ -142,30 +148,33 @@ def EnforceEvenChannelCount(last_tensor, channel_axis=3):
             channel_axis=channel_axis)
     return last_tensor

-def BinaryConvLayers(last_tensor, name, shape=(3, 3), conv_count=1, has_batch_norm=True, activation='relu', channel_axis=3):
+def BinaryConvLayers(last_tensor, name, shape=(3, 3), conv_count=1, has_batch_norm=True, has_interleaving=False, activation='relu', channel_axis=3):
     last_tensor = EnforceEvenChannelCount(last_tensor)
     prev_layer_channel_count = keras.backend.int_shape(last_tensor)[channel_axis]
     for conv_cnt in range(conv_count):
         input_tensor = last_tensor
-        last_tensor_interleaved = InterleaveChannels(step_size=2, name=name + "_i_" + str(conv_cnt))(last_tensor)
+        if has_interleaving:
+            last_tensor_interleaved = InterleaveChannels(step_size=2, name=name + "_i_" + str(conv_cnt))(last_tensor)
+        else:
+            last_tensor_interleaved = last_tensor
         x1 = keras.layers.Conv2D(prev_layer_channel_count // 2, shape, padding='same', activation=None, name=name + "_a_" + str(conv_cnt), groups=prev_layer_channel_count // 2)(last_tensor)
         x2 = keras.layers.Conv2D(prev_layer_channel_count // 2, shape, padding='same', activation=None, name=name + "_b_" + str(conv_cnt), groups=prev_layer_channel_count // 2)(last_tensor_interleaved)
         last_tensor = keras.layers.Concatenate(axis=channel_axis, name=name + "_conc_" + str(conv_cnt))([x1, x2])
         if has_batch_norm: last_tensor = keras.layers.BatchNormalization(axis=channel_axis, name=name + "_batch_" + str(conv_cnt))(last_tensor)
-        if activation is not None: last_tensor = keras.layers.Activation(activation=activation, name=name + "_relu_" + str(conv_cnt))(last_tensor)
+        if activation is not None: last_tensor = keras.layers.Activation(activation=activation, name=name + "_act_" + str(conv_cnt))(last_tensor)
         last_tensor = keras.layers.add([input_tensor, last_tensor], name=name + '_add' + str(conv_cnt))
         if has_batch_norm: last_tensor = keras.layers.BatchNormalization(axis=channel_axis)(last_tensor)
     return last_tensor

-def BinaryPointwiseConvLayers(last_tensor, name, conv_count=1, has_batch_norm=True, activation='relu', channel_axis=3):
-    return BinaryConvLayers(last_tensor, name, shape=(1, 1), conv_count=conv_count, has_batch_norm=has_batch_norm, activation=activation, channel_axis=channel_axis)
+def BinaryPointwiseConvLayers(last_tensor, name, conv_count=1, has_batch_norm=True, has_interleaving=False, activation='relu', channel_axis=3):
+    return BinaryConvLayers(last_tensor, name, shape=(1, 1), conv_count=conv_count, has_batch_norm=has_batch_norm, has_interleaving=has_interleaving, activation=activation, channel_axis=channel_axis)

 def BinaryCompressionLayer(last_tensor, name, has_batch_norm=True, activation='relu', channel_axis=3):
     last_tensor = EnforceEvenChannelCount(last_tensor)
     prev_layer_channel_count = keras.backend.int_shape(last_tensor)[channel_axis]
     last_tensor = keras.layers.Conv2D(prev_layer_channel_count // 2, (1, 1), padding='same', activation=None, name=name + "_conv", groups=prev_layer_channel_count // 2)(last_tensor)
     if has_batch_norm: last_tensor = keras.layers.BatchNormalization(axis=channel_axis, name=name + "_batch")(last_tensor)
-    if activation is not None: last_tensor = keras.layers.Activation(activation=activation, name=name + "_relu")(last_tensor)
+    if activation is not None: last_tensor = keras.layers.Activation(activation=activation, name=name + "_act")(last_tensor)
     return last_tensor

 def BinaryCompression(last_tensor, name, target_channel_count, has_batch_norm=True, activation='relu', channel_axis=3):
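
For reference, a minimal usage sketch of the new has_interleaving flag on FitChannelCountTo (the input shape and channel counts below are hypothetical and chosen only for illustration; FitChannelCountTo, InterleaveChannels and CopyChannels are the helpers defined in this module):

import keras

# Hypothetical 3-channel input.
inputs = keras.layers.Input(shape=(32, 32, 3))

# Default behaviour: the 3 input channels are grown to 8 by plain copies
# plus a CopyChannels slice for the remainder.
x = FitChannelCountTo(inputs, next_channel_count=8)

# With the new flag, the extra copies are taken from interleaved channels instead.
y = FitChannelCountTo(inputs, next_channel_count=8, has_interleaving=True)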