Skip to content

Commit 1db3cc9

Browse files
committed
local desc
1 parent c68f188 commit 1db3cc9

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/ops/math.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,13 @@ end
176176
# TODO Clean this up
177177
for reduction in [:sum, :prod, :min, :max, :all, :any, :mean]
178178
@eval @op function $(Symbol("reduce_", reduction))(n::AbstractTensor; axis=nothing, keep_dims=false, name=nothing)
179+
local desc
180+
shape = get_shape(n)
181+
179182
if name == nothing
180183
name = $(capitalize(reduction))
181184
end
182185

183-
shape = get_shape(n)
184186
if axis == nothing && shape.rank_unknown
185187
n = Tensor(n) # TODO: rewrite this
186188
desc_rank = tf.with_op_name(nothing, "Rank") do
@@ -197,14 +199,13 @@ for reduction in [:sum, :prod, :min, :max, :all, :any, :mean]
197199
add_input(desc_range, delta)
198200
Tensor(Operation(desc_range), 1)
199201
end
200-
desc = tf.with_op_name(nothing, name) do
202+
tf.with_op_name(nothing, name) do
201203
desc = NodeDescription($(capitalize(reduction)))
202204
add_input(desc, n)
203205
add_input(desc, range)
204-
desc
205206
end
206207
else
207-
desc = tf.with_op_name(nothing, name) do
208+
tf.with_op_name(nothing, name) do
208209
if axis == nothing
209210
axis = 1:length(shape.dims)
210211
end
@@ -213,7 +214,6 @@ for reduction in [:sum, :prod, :min, :max, :all, :any, :mean]
213214
add_input(desc, Tensor(n))
214215
add_input(desc, reduction_indices)
215216
desc["keep_dims"] = keep_dims
216-
desc
217217
end
218218
end
219219
Tensor(Operation(desc), 1)

0 commit comments

Comments
 (0)