|
21 | 21 | #include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
|
22 | 22 | #include "torch-mlir/Conversion/Utils/Utils.h"
|
23 | 23 | #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
| 24 | +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" |
24 | 25 | #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
25 | 26 | #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
26 | 27 | #include "llvm/ADT/APSInt.h"
|
27 | 28 | #include <numeric>
|
| 29 | +#include <string> |
28 | 30 | #include <type_traits>
|
29 | 31 |
|
30 | 32 | using namespace mlir;
|
@@ -3564,6 +3566,68 @@ class ConvertAtenPolarOp : public OpConversionPattern<AtenPolarOp> {
|
3564 | 3566 | };
|
3565 | 3567 | } // namespace
|
3566 | 3568 |
|
| 3569 | +namespace { |
| 3570 | +class ConvertSymConstrainRangeOp |
| 3571 | + : public OpConversionPattern<AtenSymConstrainRangeOp> { |
| 3572 | +public: |
| 3573 | + using OpConversionPattern::OpConversionPattern; |
| 3574 | + LogicalResult |
| 3575 | + matchAndRewrite(AtenSymConstrainRangeOp op, OpAdaptor adaptor, |
| 3576 | + ConversionPatternRewriter &rewriter) const override { |
| 3577 | + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) |
| 3578 | + return failure(); |
| 3579 | + |
| 3580 | + auto loc = op.getLoc(); |
| 3581 | + auto min = op.getMin(); |
| 3582 | + auto max = op.getMax(); |
| 3583 | + |
| 3584 | + int64_t minValue = std::numeric_limits<int64_t>::min(); |
| 3585 | + int64_t maxValue = std::numeric_limits<int64_t>::max(); |
| 3586 | + |
| 3587 | + Type operandType = getTypeConverter()->convertType(op.getSize().getType()); |
| 3588 | + |
| 3589 | + if (!isa<Torch::NoneType>(min.getType())) |
| 3590 | + if (!matchPattern(min, m_TorchConstantInt(&minValue))) |
| 3591 | + return rewriter.notifyMatchFailure( |
| 3592 | + op, "Expected min value to be constant integer"); |
| 3593 | + |
| 3594 | + if (!isa<Torch::NoneType>(max.getType())) |
| 3595 | + if (!matchPattern(max, m_TorchConstantInt(&maxValue))) |
| 3596 | + return rewriter.notifyMatchFailure( |
| 3597 | + op, "Expected max value to be constant integer"); |
| 3598 | + |
| 3599 | + if (maxValue < minValue) { |
| 3600 | + std::string errorMsg = |
| 3601 | + "Max must be greater than or equal to min, got min = " + |
| 3602 | + std::to_string(minValue) + ", max = " + std::to_string(maxValue); |
| 3603 | + return op.emitError(errorMsg); |
| 3604 | + } |
| 3605 | + |
| 3606 | + min = getConstant(rewriter, loc, minValue, operandType); |
| 3607 | + max = getConstant(rewriter, loc, maxValue, operandType); |
| 3608 | + |
| 3609 | + // Check min <= size <= max |
| 3610 | + |
| 3611 | + // FIXME:: Skip the below checks if constraint ops are already inserted as |
| 3612 | + // part of symbol expr evaluation |
| 3613 | + auto checkMin = rewriter.create<arith::CmpIOp>( |
| 3614 | + loc, arith::CmpIPredicate::sle, min, adaptor.getSize()); |
| 3615 | + auto checkMax = rewriter.create<arith::CmpIOp>( |
| 3616 | + loc, arith::CmpIPredicate::sle, adaptor.getSize(), max); |
| 3617 | + auto compareVal = rewriter.create<arith::AndIOp>(loc, checkMin, checkMax); |
| 3618 | + |
| 3619 | + std::string assertMessage = "Size constraint failed. Expected range: [" + |
| 3620 | + std::to_string(minValue) + ", " + |
| 3621 | + std::to_string(maxValue) + "]"; |
| 3622 | + rewriter.create<cf::AssertOp>(loc, compareVal, |
| 3623 | + rewriter.getStringAttr(assertMessage)); |
| 3624 | + |
| 3625 | + rewriter.eraseOp(op); |
| 3626 | + return success(); |
| 3627 | + } |
| 3628 | +}; |
| 3629 | +} // namespace |
| 3630 | + |
3567 | 3631 | void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
3568 | 3632 | TypeConverter &typeConverter, RewritePatternSet &patterns,
|
3569 | 3633 | ConversionTarget &target) {
|
@@ -3626,4 +3690,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
3626 | 3690 | patterns.add<ConvertAtenLinalgDetOp>(typeConverter, context);
|
3627 | 3691 | target.addIllegalOp<AtenPolarOp>();
|
3628 | 3692 | patterns.add<ConvertAtenPolarOp>(typeConverter, context);
|
| 3693 | + target.addIllegalOp<AtenSymConstrainRangeOp>(); |
| 3694 | + patterns.add<ConvertSymConstrainRangeOp>(typeConverter, context); |
3629 | 3695 | }
|
0 commit comments