
Commit 1cecdb7

Interleaving is now optional in BinaryConvLayers.
1 parent b82fd13 commit 1cecdb7

File tree

1 file changed (+18 −9 lines)


cai/layers.py

+18 −9
@@ -117,7 +117,7 @@ def GlobalAverageMaxPooling2D(previous_layer, name=None):
         keras.layers.GlobalMaxPooling2D(name=name+'_max')(previous_layer)
     ])
 
-def FitChannelCountTo(last_tensor, next_channel_count, channel_axis=3):
+def FitChannelCountTo(last_tensor, next_channel_count, has_interleaving=False, channel_axis=3):
     prev_layer_channel_count = keras.backend.int_shape(last_tensor)[channel_axis]
     full_copies = next_channel_count // prev_layer_channel_count
     extra_channels = next_channel_count % prev_layer_channel_count
@@ -126,9 +126,15 @@ def FitChannelCountTo(last_tensor, next_channel_count, channel_axis=3):
         if copy_cnt == 0:
             output_copies.append( last_tensor )
         else:
-            output_copies.append( InterleaveChannels(step_size=((copy_cnt+1) % prev_layer_channel_count))(last_tensor) )
+            if has_interleaving:
+                output_copies.append( InterleaveChannels(step_size=((copy_cnt+1) % prev_layer_channel_count))(last_tensor) )
+            else:
+                output_copies.append( last_tensor )
     if (extra_channels > 0):
-        extra_tensor = InterleaveChannels(step_size=((full_copies+1) % prev_layer_channel_count))(last_tensor)
+        if has_interleaving:
+            extra_tensor = InterleaveChannels(step_size=((full_copies+1) % prev_layer_channel_count))(last_tensor)
+        else:
+            extra_tensor = last_tensor
         output_copies.append( CopyChannels(0,extra_channels)(extra_tensor) )
     last_tensor = keras.layers.Concatenate(axis=channel_axis)( output_copies )
     return last_tensor
@@ -142,30 +148,33 @@ def EnforceEvenChannelCount(last_tensor, channel_axis=3):
             channel_axis=channel_axis)
     return last_tensor
 
-def BinaryConvLayers(last_tensor, name, shape=(3, 3), conv_count=1, has_batch_norm=True, activation='relu', channel_axis=3):
+def BinaryConvLayers(last_tensor, name, shape=(3, 3), conv_count=1, has_batch_norm=True, has_interleaving=False, activation='relu', channel_axis=3):
     last_tensor = EnforceEvenChannelCount(last_tensor)
     prev_layer_channel_count = keras.backend.int_shape(last_tensor)[channel_axis]
     for conv_cnt in range(conv_count):
         input_tensor = last_tensor
-        last_tensor_interleaved = InterleaveChannels(step_size=2, name=name+"_i_"+str(conv_cnt))(last_tensor)
+        if has_interleaving:
+            last_tensor_interleaved = InterleaveChannels(step_size=2, name=name+"_i_"+str(conv_cnt))(last_tensor)
+        else:
+            last_tensor_interleaved = last_tensor
         x1 = keras.layers.Conv2D(prev_layer_channel_count//2, shape, padding='same', activation=None, name=name+"_a_"+str(conv_cnt), groups=prev_layer_channel_count//2)(last_tensor)
         x2 = keras.layers.Conv2D(prev_layer_channel_count//2, shape, padding='same', activation=None, name=name+"_b_"+str(conv_cnt), groups=prev_layer_channel_count//2)(last_tensor_interleaved)
         last_tensor = keras.layers.Concatenate(axis=channel_axis, name=name+"_conc_"+str(conv_cnt))([x1,x2])
         if has_batch_norm: last_tensor = keras.layers.BatchNormalization(axis=channel_axis, name=name+"_batch_"+str(conv_cnt))(last_tensor)
-        if activation is not None: last_tensor = keras.layers.Activation(activation=activation, name=name+"_relu_"+str(conv_cnt))(last_tensor)
+        if activation is not None: last_tensor = keras.layers.Activation(activation=activation, name=name+"_act_"+str(conv_cnt))(last_tensor)
         last_tensor = keras.layers.add([input_tensor, last_tensor], name=name+'_add'+str(conv_cnt))
         if has_batch_norm: last_tensor = keras.layers.BatchNormalization(axis=channel_axis)(last_tensor)
     return last_tensor
 
-def BinaryPointwiseConvLayers(last_tensor, name, conv_count=1, has_batch_norm=True, activation='relu', channel_axis=3):
-    return BinaryConvLayers(last_tensor, name, shape=(1, 1), conv_count=conv_count, has_batch_norm=has_batch_norm, activation=activation, channel_axis=channel_axis)
+def BinaryPointwiseConvLayers(last_tensor, name, conv_count=1, has_batch_norm=True, has_interleaving=False, activation='relu', channel_axis=3):
+    return BinaryConvLayers(last_tensor, name, shape=(1, 1), conv_count=conv_count, has_batch_norm=has_batch_norm, has_interleaving=has_interleaving, activation=activation, channel_axis=channel_axis)
 
 def BinaryCompressionLayer(last_tensor, name, has_batch_norm=True, activation='relu', channel_axis=3):
     last_tensor = EnforceEvenChannelCount(last_tensor)
     prev_layer_channel_count = keras.backend.int_shape(last_tensor)[channel_axis]
     last_tensor = keras.layers.Conv2D(prev_layer_channel_count//2, (1, 1), padding='same', activation=None, name=name+"_conv", groups=prev_layer_channel_count//2)(last_tensor)
     if has_batch_norm: last_tensor = keras.layers.BatchNormalization(axis=channel_axis, name=name+"_batch")(last_tensor)
-    if activation is not None: last_tensor = keras.layers.Activation(activation=activation, name=name+"_relu")(last_tensor)
+    if activation is not None: last_tensor = keras.layers.Activation(activation=activation, name=name+"_act")(last_tensor)
     return last_tensor
 
 def BinaryCompression(last_tensor, name, target_channel_count, has_batch_norm=True, activation='relu', channel_axis=3):
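Below is a minimal usage sketch of the new has_interleaving flag introduced by this commit. It assumes cai.layers is importable, that keras is provided by TensorFlow, and that grouped Conv2D is supported by the installed version; the input shape, channel counts, and layer names are illustrative only and are not part of the commit.

# Hypothetical usage sketch; names and shapes are illustrative, not from the commit.
from tensorflow import keras
import cai.layers

# Assumed 32x32 feature map with 16 channels.
inputs = keras.layers.Input(shape=(32, 32, 16))

# Default behaviour after this commit: no channel interleaving inside the block.
x = cai.layers.BinaryConvLayers(inputs, name='plain_block', shape=(3, 3), conv_count=2)

# Opt back into the previous behaviour with the new flag.
y = cai.layers.BinaryConvLayers(inputs, name='interleaved_block', shape=(3, 3), conv_count=2, has_interleaving=True)

# FitChannelCountTo gained the same flag: grow 16 channels to 48 by copying,
# optionally interleaving the copied channels.
z = cai.layers.FitChannelCountTo(y, next_channel_count=48, has_interleaving=True)

model = keras.models.Model(inputs=inputs, outputs=z)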
