|
10 | 10 | #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
11 | 11 |
|
12 | 12 | #include "../PassDetail.h"
|
| 13 | +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" |
13 | 14 | #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
14 | 15 | #include "mlir/Dialect/Math/IR/Math.h"
|
15 | 16 | #include "mlir/Dialect/Tensor/IR/Tensor.h"
|
@@ -1378,11 +1379,30 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
1378 | 1379 | }
|
1379 | 1380 | Type elementType = payloadArgs[0].getType();
|
1380 | 1381 | Value constZero =
|
1381 |
| - b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0)); |
| 1382 | + b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType)); |
1382 | 1383 | Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
|
1383 | 1384 | payloadArgs[0], constZero);
|
1384 | 1385 | return b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
|
1385 | 1386 | }
|
| 1387 | + if (auto lrelu = dyn_cast<AtenLeakyReluOp>(op)) { |
| 1388 | + if (!lrelu.getType() |
| 1389 | + .cast<ValueTensorType>() |
| 1390 | + .getDtype() |
| 1391 | + .isa<mlir::FloatType>()) { |
| 1392 | + lrelu.emitError("unimplemented: non-floating point dtype"); |
| 1393 | + return nullptr; |
| 1394 | + } |
| 1395 | + Type elementType = payloadArgs[0].getType(); |
| 1396 | + Value constZero = |
| 1397 | + b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType)); |
| 1398 | + Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, |
| 1399 | + payloadArgs[0], constZero); |
| 1400 | + Value positivePart = b.create<SelectOp>(loc, pred, payloadArgs[0], constZero); |
| 1401 | + Value negativePart = b.create<SelectOp>(loc, pred, constZero, payloadArgs[0]); |
| 1402 | + Value scale = convertScalarToDtype(b, loc, operands[1], elementType); |
| 1403 | + Value scaledNegativePart = b.create<arith::MulFOp>(loc, negativePart, scale); |
| 1404 | + return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart); |
| 1405 | + } |
1386 | 1406 | if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
|
1387 | 1407 | if (!gelu.getType()
|
1388 | 1408 | .cast<ValueTensorType>()
|
@@ -1812,7 +1832,7 @@ struct ConvertElementwiseOp : ConversionPattern {
|
1812 | 1832 | LogicalResult
|
1813 | 1833 | matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
1814 | 1834 | ConversionPatternRewriter &rewriter) const override {
|
1815 |
| - if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, |
| 1835 | + if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp, |
1816 | 1836 | AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
|
1817 | 1837 | AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
|
1818 | 1838 | AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
|
@@ -2969,7 +2989,7 @@ class ConvertTorchToLinalg
|
2969 | 2989 | target.addIllegalOp<AtenBatchNormOp>();
|
2970 | 2990 | patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
|
2971 | 2991 | target.addIllegalOp<
|
2972 |
| - AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp, |
| 2992 | + AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp, |
2973 | 2993 | AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp,
|
2974 | 2994 | AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
2975 | 2995 | AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
|
|
0 commit comments