Commit 6ddeb1a

[torch] Add support for aten.selu (#2640)
Add the `aten.selu` operation to the `torch` dialect.
1 parent: 42392bc
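
For context, `aten::selu` applies the SELU activation and is the operator behind `torch.nn.functional.selu`. A quick eager-mode sanity check (plain PyTorch; not part of this patch):

import torch

x = torch.randn(5, 3)
# torch.ops.aten.selu is the raw ATen entry point exercised by the e2e test below.
assert torch.allclose(torch.ops.aten.selu(x), torch.nn.functional.selu(x))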

8 files changed: +157 -0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td (+45)

@@ -346,6 +346,51 @@ def Torch_AtenLog_Op : Torch_Op<"aten.log_", [
   }];
 }
 
+def Torch_AtenSeluOp : Torch_Op<"aten.selu", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::selu : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenSeluOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenSeluOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
+def Torch_AtenSelu_Op : Torch_Op<"aten.selu_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::selu_ : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self
+  );
+  let results = (outs
+    Torch_NonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenSelu_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenSelu_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenSigmoidOp : Torch_Op<"aten.sigmoid", [
     AllowsTypeRefinement,
     HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp (+26)

@@ -6746,6 +6746,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.gather\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg2) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"

@@ -10434,6 +10438,28 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.selu\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int11 = torch.constant.int 11\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %3 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp (+50)

@@ -1937,6 +1937,55 @@ class DecomposeAtenEluOp : public OpRewritePattern<AtenEluOp> {
 };
 } // namespace
 
+// Selu = scale * (max(0, x) + min(0, alpha * (exp(x) - 1)))
+namespace {
+class DecomposeAtenSeluOp : public OpRewritePattern<AtenSeluOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSeluOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.getSelf();
+    auto resType = op.getType().cast<BaseTensorType>();
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
+
+    // Define λ and α.
+    double scale = 1.0507009873554804934193349852946;
+    double alpha = 1.6732632423543772848170429916717;
+
+    // Create constants for λ and α.
+    Value scaleVal = rewriter.create<Torch::ConstantFloatOp>(
+        loc, rewriter.getF64FloatAttr(scale));
+    Value alphaVal = rewriter.create<Torch::ConstantFloatOp>(
+        loc, rewriter.getF64FloatAttr(alpha));
+
+    // Create a zero tensor for comparison.
+    Value constantZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
+
+    // Calculate the positive and negative parts.
+    Value constantOne =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+    Value positiveOutput =
+        rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
+    Value minZeroX =
+        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
+    Value expInput = rewriter.create<AtenExpOp>(loc, resType, minZeroX);
+    Value expInputMinusOne = rewriter.create<AtenSubScalarOp>(
+        loc, resType, expInput, constantOne, constantOne);
+    Value negativeOutput = rewriter.create<AtenMulScalarOp>(
+        loc, resType, expInputMinusOne, alphaVal);
+
+    // Sum the two parts, then multiply the result by λ.
+    Value seluOutput = rewriter.create<AtenAddTensorOp>(
+        loc, resType, positiveOutput, negativeOutput, constantOne);
+    seluOutput =
+        rewriter.create<AtenMulScalarOp>(loc, resType, seluOutput, scaleVal);
+
+    // Replace the original operation.
+    rewriter.replaceOp(op, seluOutput);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
 public:

@@ -6460,6 +6509,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenRandnLikeOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEluOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenSeluOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
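
The rewrite above can be checked numerically against eager PyTorch. A minimal sketch mirroring the decomposition (the constants are the standard SELU λ and α; the `selu_decomposed` helper is illustrative, not part of the patch):

import torch

def selu_decomposed(x: torch.Tensor) -> torch.Tensor:
    # Mirrors DecomposeAtenSeluOp:
    #   scale * (max(0, x) + alpha * (exp(min(0, x)) - 1))
    scale = 1.0507009873554804934193349852946
    alpha = 1.6732632423543772848170429916717
    zero = torch.zeros_like(x)
    positive = torch.maximum(zero, x)
    negative = alpha * (torch.exp(torch.minimum(zero, x)) - 1.0)
    return scale * (positive + negative)

x = torch.randn(100)
assert torch.allclose(selu_decomposed(x), torch.nn.functional.selu(x), atol=1e-6)

Note that the pattern applies exp to min(0, x) rather than clamping alpha * (exp(x) - 1) afterwards; the two are equivalent because exp is monotonic, and the former avoids evaluating exp on large positive inputs.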

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp (+1)

@@ -437,6 +437,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenRelu6Op>();
   target.addIllegalOp<AtenEluOp>();
   target.addIllegalOp<AtenGluOp>();
+  target.addIllegalOp<AtenSeluOp>();
   target.addIllegalOp<AtenHardswishOp>();
   target.addIllegalOp<AtenSoftplusOp>();
   target.addIllegalOp<AtenSiluOp>();

projects/pt1/e2e_testing/xfail_sets.py (+2)

@@ -486,6 +486,7 @@
     "ElementwiseLeakyReluModule_basic",
     "ElementwiseEluModule_basic",
     "ElementwiseEluNonDefaultModule_basic",
+    "ElementwiseSeluModule_basic",
     "ElementwiseLogModule_basic",
     "ElementwiseNegModule_basic",
     "ElementwiseRsqrtModule_basic",

@@ -1115,6 +1116,7 @@
     "ElementwiseRemainderScalarModule_Int_basic",
     "ElementwiseRemainderScalarModule_Int_basic",
     "ElementwiseRsqrtModule_basic",
+    "ElementwiseSeluModule_basic",
     "ElementwiseSigmoidModule_basic",
     "ElementwiseSignModule_basic",
     "ElementwiseSqrtIntModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py (+11)

@@ -373,6 +373,9 @@ def aten〇elu〡shape(self: List[int], alpha: float = 1, scale: float = 1, inpu
 def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇selu〡shape(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇gather〡shape(self: List[int], dim: int, index: List[int], sparse_grad: bool = False) -> List[int]:
     return upstream_shape_functions.unary(index)
 

@@ -3066,6 +3069,14 @@ def aten〇elu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, float
     assert not is_integer_dtype(self_dtype)
     return self_dtype
 
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}))
+def aten〇selu〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert self_dtype != torch.bool
+    assert not is_integer_dtype(self_dtype)
+    return self_dtype
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
     _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
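
The new dtype function encodes that SELU accepts only floating-point inputs, matching eager PyTorch, where the underlying kernel is registered for float dtypes only. A small illustration (assumed behavior; the exact error text varies by build):

import torch

torch.ops.aten.selu(torch.rand(2, 2))  # fine: float32
try:
    torch.ops.aten.selu(torch.ones(2, 2, dtype=torch.int32))
except RuntimeError as err:
    # Integer and bool inputs are rejected, as the asserts above encode.
    print(err)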

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py (+1)

@@ -262,6 +262,7 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::relu6 : (Tensor) -> (Tensor)",
         "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
         "aten::log : (Tensor) -> (Tensor)",
+        "aten::selu : (Tensor) -> (Tensor)",
         "aten::sigmoid : (Tensor) -> (Tensor)",
         "aten::sign : (Tensor) -> (Tensor)",
         "aten::sgn : (Tensor) -> (Tensor)",

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py (+21)

@@ -564,6 +564,27 @@ def ElementwiseGeluModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseSeluModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.selu(x)
+
+@register_test_case(module_factory=lambda: ElementwiseSeluModule())
+def ElementwiseSeluModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 3, low=-1, high=1))
+
+
+# ==============================================================================
+
+
 class ElementwiseSigmoidModule(torch.nn.Module):
 
     def __init__(self):
