Skip to content

Commit dbc1ded

Browse files
committed
add leakyrelu def
1 parent 51ee029 commit dbc1ded

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

ext/ForwardDiffExt.jl

+12
Original file line numberDiff line numberDiff line change
@@ -377,3 +377,15 @@ for f in (:vmapt, :vmapnt, :vmapntt)
377377
end
378378
end
379379
end
380+
381+
if Base.ifelse !== IfElse.ifelse
382+
@inline function NNlib.leakyrelu(x::AbstractSIMD)
383+
fx = float(x)
384+
NNlib.leakyrelu(fx, convert(typeof(fx), NNlib.leakyrelu_a))
385+
end
386+
@inline function NNlib.leakyrelu(x::AbstractSIMD, a)
387+
fx = float(x)
388+
ax = convert(typeof(fx), a * x)
389+
ifelse(x > 0, fx, ax) # max(a*x, x) is 3x slower
390+
end
391+
end

0 commit comments

Comments
 (0)