@@ -176,11 +176,13 @@ end
176
176
# TODO Clean this up
177
177
for reduction in [:sum , :prod , :min , :max , :all , :any , :mean ]
178
178
@eval @op function $ (Symbol (" reduce_" , reduction))(n:: AbstractTensor ; axis= nothing , keep_dims= false , name= nothing )
179
+ local desc
180
+ shape = get_shape (n)
181
+
179
182
if name == nothing
180
183
name = $ (capitalize (reduction))
181
184
end
182
185
183
- shape = get_shape (n)
184
186
if axis == nothing && shape. rank_unknown
185
187
n = Tensor (n) # TODO : rewrite this
186
188
desc_rank = tf. with_op_name (nothing , " Rank" ) do
@@ -197,14 +199,13 @@ for reduction in [:sum, :prod, :min, :max, :all, :any, :mean]
197
199
add_input (desc_range, delta)
198
200
Tensor (Operation (desc_range), 1 )
199
201
end
200
- desc = tf. with_op_name (nothing , name) do
202
+ tf. with_op_name (nothing , name) do
201
203
desc = NodeDescription ($ (capitalize (reduction)))
202
204
add_input (desc, n)
203
205
add_input (desc, range)
204
- desc
205
206
end
206
207
else
207
- desc = tf. with_op_name (nothing , name) do
208
+ tf. with_op_name (nothing , name) do
208
209
if axis == nothing
209
210
axis = 1 : length (shape. dims)
210
211
end
@@ -213,7 +214,6 @@ for reduction in [:sum, :prod, :min, :max, :all, :any, :mean]
213
214
add_input (desc, Tensor (n))
214
215
add_input (desc, reduction_indices)
215
216
desc[" keep_dims" ] = keep_dims
216
- desc
217
217
end
218
218
end
219
219
Tensor (Operation (desc), 1 )
0 commit comments