@@ -1937,6 +1937,55 @@ class DecomposeAtenEluOp : public OpRewritePattern<AtenEluOp> {
1937
1937
};
1938
1938
} // namespace
1939
1939
1940
+ // Selu = scale * (max(0,x) + min(0,alpha * (exp(x) − 1)))
1941
+ namespace {
1942
+ class DecomposeAtenSeluOp : public OpRewritePattern <AtenSeluOp> {
1943
+ public:
1944
+ using OpRewritePattern::OpRewritePattern;
1945
+ LogicalResult matchAndRewrite (AtenSeluOp op,
1946
+ PatternRewriter &rewriter) const override {
1947
+ Location loc = op.getLoc ();
1948
+ Value input = op.getSelf ();
1949
+ auto resType = op.getType ().cast <BaseTensorType>();
1950
+ if (!resType.hasDtype ()) {
1951
+ return rewriter.notifyMatchFailure (op, " result should have dtype" );
1952
+ }
1953
+
1954
+ // Define λ and α
1955
+ double scale = 1.0507009873554804934193349852946 ;
1956
+ double alpha = 1.6732632423543772848170429916717 ;
1957
+
1958
+ // Create constants for λ and α
1959
+ Value scaleVal = rewriter.create <Torch::ConstantFloatOp>(loc, rewriter.getF64FloatAttr (scale));
1960
+ Value alphaVal = rewriter.create <Torch::ConstantFloatOp>(loc, rewriter.getF64FloatAttr (alpha));
1961
+
1962
+ // Create zero tensor for comparison
1963
+ Value constantZero =
1964
+ rewriter.create <ConstantIntOp>(loc, rewriter.getI64IntegerAttr (0 ));
1965
+ Value zeroTensor = createRank0Tensor (rewriter, loc, resType, constantZero);
1966
+
1967
+ // Calculate positive and negative parts
1968
+ Value constantOne =
1969
+ rewriter.create <ConstantFloatOp>(loc, rewriter.getF64FloatAttr (1.0 ));
1970
+ Value positiveOutput = rewriter.create <AtenMaximumOp>(loc, resType, zeroTensor, input);
1971
+ Value minZeroX =
1972
+ rewriter.create <AtenMinimumOp>(loc, resType, zeroTensor, input);
1973
+ Value expInput = rewriter.create <AtenExpOp>(loc, resType, minZeroX);
1974
+ Value expInputMinusOne = rewriter.create <AtenSubScalarOp>(loc, resType, expInput, constantOne, constantOne);
1975
+ Value negativeOutput = rewriter.create <AtenMulScalarOp>(loc, resType, expInputMinusOne, alphaVal);
1976
+
1977
+ // Multiply the result by λ
1978
+ Value seluOutput = rewriter.create <AtenAddTensorOp>(
1979
+ loc, resType, positiveOutput, negativeOutput, constantOne);
1980
+ seluOutput = rewriter.create <AtenMulScalarOp>(loc, resType, seluOutput, scaleVal);
1981
+
1982
+ // Replace the original operation
1983
+ rewriter.replaceOp (op, seluOutput);
1984
+ return success ();
1985
+ }
1986
+ };
1987
+ } // namespace
1988
+
1940
1989
namespace {
1941
1990
class DecomposeAtenTOp : public OpRewritePattern <AtenTOp> {
1942
1991
public:
@@ -6460,6 +6509,7 @@ class DecomposeComplexOpsPass
6460
6509
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnLikeOp>(patterns);
6461
6510
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
6462
6511
addPatternIfTargetOpIsIllegal<DecomposeAtenEluOp>(patterns);
6512
+ addPatternIfTargetOpIsIllegal<DecomposeAtenSeluOp>(patterns);
6463
6513
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
6464
6514
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
6465
6515
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
0 commit comments