
Commit 7616d28

ljfitz authored and Prashant Kumar committed
Add leakyrelu support
1 parent f461a7e

5 files changed: +73 -4

e2e_testing/torchscript/elementwise.py

+18
@@ -196,6 +196,24 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 2) - 0.5)
 
 # ==============================================================================
+class ElementwiseLeakyReluModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.leaky_relu(x, negative_slope=0.1)
+
+
+@register_test_case(module_factory=lambda: ElementwiseLeakyReluModule())
+def ElementwiseLeakyReluModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 2) - 0.5)
+
+# ==============================================================================
 
 
 class ElementwiseGeluModule(torch.nn.Module):
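
Note: the new test exercises the standard leaky_relu semantics: positive inputs pass through unchanged, negative inputs are scaled by negative_slope. A minimal sketch of the behavior under test, in plain PyTorch outside the e2e harness:

import torch

x = torch.rand(4, 2) - 0.5
expected = torch.where(x > 0, x, 0.1 * x)
actual = torch.ops.aten.leaky_relu(x, negative_slope=0.1)
assert torch.allclose(actual, expected)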

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

+30
@@ -72,6 +72,36 @@ def Torch_AtenRelu_Op : Torch_Op<"aten.relu_", [
   let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
 }
 
+def Torch_AtenLeakyReluOp : Torch_Op<"aten.leaky_relu", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::leaky_relu : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$negative_slope
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $negative_slope attr-dict `:` type($self) `,` type($negative_slope) `->` type($result)";
+}
+
+def Torch_AtenLeakyRelu_Op : Torch_Op<"aten.leaky_relu_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::leaky_relu_ : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$negative_slope
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $negative_slope attr-dict `:` type($self) `,` type($negative_slope) `->` type($result)";
+}
+
 def Torch_AtenLogOp : Torch_Op<"aten.log", [
   AllowsTypeRefinement,
   HasValueSemantics
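
Note: these two definitions give the Torch dialect a value-semantic aten.leaky_relu plus its trailing-underscore in-place variant, each taking the input tensor and a negative_slope scalar. Per the declared assemblyFormat, the value-semantic op should print roughly as

%1 = torch.aten.leaky_relu %0, %slope : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>

(an illustrative form assuming a float-typed slope; this line is not taken from the diff).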

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

+23-3
@@ -10,6 +10,7 @@
 #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
 
 #include "../PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -1378,11 +1379,30 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     }
     Type elementType = payloadArgs[0].getType();
     Value constZero =
-        b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
+        b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
     Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                          payloadArgs[0], constZero);
     return b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
   }
+  if (auto lrelu = dyn_cast<AtenLeakyReluOp>(op)) {
+    if (!lrelu.getType()
+             .cast<ValueTensorType>()
+             .getDtype()
+             .isa<mlir::FloatType>()) {
+      lrelu.emitError("unimplemented: non-floating point dtype");
+      return nullptr;
+    }
+    Type elementType = payloadArgs[0].getType();
+    Value constZero =
+        b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
+    Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
+                                         payloadArgs[0], constZero);
+    Value positivePart = b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
+    Value negativePart = b.create<SelectOp>(loc, pred, constZero, payloadArgs[0]);
+    Value scale = convertScalarToDtype(b, loc, operands[1], elementType);
+    Value scaledNegativePart = b.create<arith::MulFOp>(loc, negativePart, scale);
+    return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
+  }
   if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
     if (!gelu.getType()
              .cast<ValueTensorType>()
@@ -1812,7 +1832,7 @@ struct ConvertElementwiseOp : ConversionPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp,
+    if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
             AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
             AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
             AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
@@ -2969,7 +2989,7 @@ class ConvertTorchToLinalg
     target.addIllegalOp<AtenBatchNormOp>();
     patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
     target.addIllegalOp<
-        AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
+        AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
         AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp,
         AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
         AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
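
Note: the payload computes leaky_relu branch-free. It selects a positive part select(x > 0, x, 0) and a negative part select(x > 0, 0, x), converts negative_slope to the element type, scales the negative part, and adds the two. Because the comparison uses the unordered UGT predicate, a NaN input lands in the positive part and propagates unscaled. A small sketch of the same decomposition in PyTorch (illustrative only; the helper name is made up):

import torch

def leaky_relu_decomposed(x, negative_slope):
    # Mirrors the linalg payload: positive part + scaled negative part.
    zero = torch.zeros_like(x)
    positive_part = torch.where(x > 0, x, zero)
    negative_part = torch.where(x > 0, zero, x)
    return positive_part + negative_slope * negative_part

x = torch.rand(4, 2) - 0.5
assert torch.allclose(leaky_relu_decomposed(x, 0.1),
                      torch.nn.functional.leaky_relu(x, 0.1))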

lib/Dialect/Torch/Transforms/RefineTypes.cpp

+1-1
@@ -289,7 +289,7 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
      return visitAtenAdaptiveAvgPool2dOp(avgPool2d, operands);
     } else if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp,
                    AtenDivScalarOp, AtenFmodScalarOp, AtenFloorDivideScalarOp,
-                   AtenPowTensorScalarOp, AtenRsubScalarOp>(op)) {
+                   AtenPowTensorScalarOp, AtenRsubScalarOp, AtenLeakyReluOp>(op)) {
       return visitBinaryTensorScalarOp(op, operands);
     } else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
                    AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
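
Note: for type refinement, aten.leaky_relu is routed through visitBinaryTensorScalarOp, so it is treated like the other (Tensor, Scalar) -> (Tensor) elementwise ops in this list: the result keeps the input tensor's shape, with the dtype derived from the tensor/scalar pair. A quick illustrative check of the corresponding eager-mode behavior:

import torch

x = torch.rand(4, 2)  # float32
y = torch.ops.aten.leaky_relu(x, 0.1)
assert y.shape == x.shape and y.dtype == x.dtype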

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

+1
@@ -439,6 +439,7 @@ def emit_with_mutating_variants(key, **kwargs):
     for key in [
             "aten::tanh : (Tensor) -> (Tensor)",
             "aten::relu : (Tensor) -> (Tensor)",
+            "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
             "aten::log : (Tensor) -> (Tensor)",
             "aten::sigmoid : (Tensor) -> (Tensor)",
             "aten::sin : (Tensor) -> (Tensor)",
