Commit fd65a66
[torch-mlir] Support lowering of aten constraint ops (#3943)
1. aten::sym_constrain_range
2. aten::sym_constrain_range_for_size
3. aten::_assert_scalar
1 parent 25aa0c6

8 files changed: +370 −1 lines
include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td (+71)
@@ -17771,6 +17771,77 @@ def Torch_Aten_MakePerTensorQuantizedTensorOp : Torch_Op<"aten._make_per_tensor_
   }];
 }
 
+def Torch_AtenSymConstrainRangeOp : Torch_Op<"aten.sym_constrain_range", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::sym_constrain_range : (Scalar, int?, int?) -> ()`";
+  let arguments = (ins
+    AnyTorchScalarType:$size,
+    AnyTorchOptionalIntType:$min,
+    AnyTorchOptionalIntType:$max
+  );
+  let results = (outs
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenSymConstrainRangeOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 0);
+    }
+    void AtenSymConstrainRangeOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 0);
+    }
+  }];
+}
+
+def Torch_AtenSymConstrainRangeForSizeOp : Torch_Op<"aten.sym_constrain_range_for_size", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::sym_constrain_range_for_size : (Scalar, int?, int?) -> ()`";
+  let arguments = (ins
+    AnyTorchScalarType:$size,
+    AnyTorchOptionalIntType:$min,
+    AnyTorchOptionalIntType:$max
+  );
+  let results = (outs
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenSymConstrainRangeForSizeOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 0);
+    }
+    void AtenSymConstrainRangeForSizeOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 0);
+    }
+  }];
+}
+
+def Torch_Aten_AssertScalarOp : Torch_Op<"aten._assert_scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_assert_scalar : (Scalar, str) -> ()`";
+  let arguments = (ins
+    AnyTorchScalarType:$self,
+    Torch_StringType:$assert_msg
+  );
+  let results = (outs
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_AssertScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 0);
+    }
+    void Aten_AssertScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 0);
+    }
+  }];
+}
+
 def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
    AllowsTypeRefinement,
    HasValueSemantics,
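Each generated definition mirrors an upstream JIT schema, as the summary strings indicate. A quick way to cross-check the schemas from Python (a sketch; the exact default-argument rendering may vary across PyTorch versions):

import torch

print(torch.ops.aten.sym_constrain_range.default._schema)
# expected, approximately:
# aten::sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> ()
print(torch.ops.aten._assert_scalar.default._schema)
# expected, approximately:
# aten::_assert_scalar(Scalar self, str assert_msg) -> ()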

lib/Conversion/TorchToLinalg/Uncategorized.cpp (+66)
@@ -21,10 +21,12 @@
 #include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/ADT/APSInt.h"
 #include <numeric>
+#include <string>
 #include <type_traits>
 
 using namespace mlir;
@@ -3564,6 +3566,68 @@ class ConvertAtenPolarOp : public OpConversionPattern<AtenPolarOp> {
 };
 } // namespace
 
+namespace {
+class ConvertSymConstrainRangeOp
+    : public OpConversionPattern<AtenSymConstrainRangeOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenSymConstrainRangeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    auto loc = op.getLoc();
+    auto min = op.getMin();
+    auto max = op.getMax();
+
+    int64_t minValue = std::numeric_limits<int64_t>::min();
+    int64_t maxValue = std::numeric_limits<int64_t>::max();
+
+    Type operandType = getTypeConverter()->convertType(op.getSize().getType());
+
+    if (!isa<Torch::NoneType>(min.getType()))
+      if (!matchPattern(min, m_TorchConstantInt(&minValue)))
+        return rewriter.notifyMatchFailure(
+            op, "Expected min value to be constant integer");
+
+    if (!isa<Torch::NoneType>(max.getType()))
+      if (!matchPattern(max, m_TorchConstantInt(&maxValue)))
+        return rewriter.notifyMatchFailure(
+            op, "Expected max value to be constant integer");
+
+    if (maxValue < minValue) {
+      std::string errorMsg =
+          "Max must be greater than or equal to min, got min = " +
+          std::to_string(minValue) + ", max = " + std::to_string(maxValue);
+      return op.emitError(errorMsg);
+    }
+
+    min = getConstant(rewriter, loc, minValue, operandType);
+    max = getConstant(rewriter, loc, maxValue, operandType);
+
+    // Check min <= size <= max
+
+    // FIXME: Skip the below checks if constraint ops are already inserted as
+    // part of symbol expr evaluation
+    auto checkMin = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sle, min, adaptor.getSize());
+    auto checkMax = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sle, adaptor.getSize(), max);
+    auto compareVal = rewriter.create<arith::AndIOp>(loc, checkMin, checkMax);
+
+    std::string assertMessage = "Size constraint failed. Expected range: [" +
+                                std::to_string(minValue) + ", " +
+                                std::to_string(maxValue) + "]";
+    rewriter.create<cf::AssertOp>(loc, compareVal,
+                                  rewriter.getStringAttr(assertMessage));
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+} // namespace
+
 void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -3626,4 +3690,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   patterns.add<ConvertAtenLinalgDetOp>(typeConverter, context);
   target.addIllegalOp<AtenPolarOp>();
   patterns.add<ConvertAtenPolarOp>(typeConverter, context);
+  target.addIllegalOp<AtenSymConstrainRangeOp>();
+  patterns.add<ConvertSymConstrainRangeOp>(typeConverter, context);
 }
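In Python terms, the conversion materializes the following runtime check (a semantic sketch of ConvertSymConstrainRangeOp above, not code from the patch):

INT64_MIN, INT64_MAX = -(2**63), 2**63 - 1

def sym_constrain_range_check(size, min=None, max=None):
    # Unspecified bounds default to the full int64 range, matching the
    # minValue/maxValue initialization above.
    lo = INT64_MIN if min is None else min
    hi = INT64_MAX if max is None else max
    if hi < lo:
        raise ValueError(
            f"Max must be greater than or equal to min, got min = {lo}, max = {hi}"
        )
    # The emitted IR computes (lo <= size) AND (size <= hi) with two
    # arith.cmpi sle ops plus an arith.andi, fed into cf.assert.
    assert lo <= size <= hi, f"Size constraint failed. Expected range: [{lo}, {hi}]"

Note that both bounds must be compile-time constant integers; a non-constant bound makes the pattern report a match failure rather than lower.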

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp (+78)
@@ -11455,6 +11455,80 @@ class DecomposeAtenSpecialExpm1Op
 };
 } // namespace
 
+namespace {
+class DecomposeAtenConstrainRangeForSizeOp
+    : public OpRewritePattern<AtenSymConstrainRangeForSizeOp> {
+public:
+  using OpRewritePattern<AtenSymConstrainRangeForSizeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSymConstrainRangeForSizeOp op,
+                                PatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto min = op.getMin();
+    auto max = op.getMax();
+
+    int64_t minValue, maxValue;
+
+    if (isa<Torch::NoneType>(min.getType())) {
+      // Set min value to 0
+      min = rewriter.create<Torch::ConstantIntOp>(loc, 0);
+    } else {
+      // Check if min value is a constant
+      if (!matchPattern(min, m_TorchConstantInt(&minValue)))
+        return rewriter.notifyMatchFailure(
+            op, "Expected min value to be constant integer");
+    }
+
+    if (!isa<Torch::NoneType>(max.getType())) {
+      // Verify that max value is greater than 2
+      if (!matchPattern(max, m_TorchConstantInt(&maxValue)))
+        return rewriter.notifyMatchFailure(
+            op, "Expected max value to be constant integer");
+
+      if (maxValue <= 2) {
+        std::string errorMsg = "Max value to constrain_range_for_size must be "
+                               "greater than 2, got: " +
+                               std::to_string(maxValue);
+        return op.emitError(errorMsg);
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<AtenSymConstrainRangeOp>(op, op.getSize(), min,
+                                                         max);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class DecomposeAten_AssertScalarOp
+    : public OpRewritePattern<Aten_AssertScalarOp> {
+public:
+  using OpRewritePattern<Aten_AssertScalarOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(Aten_AssertScalarOp op,
+                                PatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto assertCond = op.getSelf();
+
+    if (isa<Torch::IntType>(assertCond.getType()))
+      assertCond = rewriter.create<AtenBoolIntOp>(loc, assertCond);
+    else if (isa<Torch::FloatType>(assertCond.getType()))
+      assertCond = rewriter.create<AtenBoolFloatOp>(loc, assertCond);
+    assert(isa<Torch::BoolType>(assertCond.getType()) &&
+           "Unhandled type encountered in aten._assert_scalar op");
+
+    std::string assertMessage;
+    if (!matchPattern(op.getAssertMsg(), m_TorchConstantStr(assertMessage)))
+      return rewriter.notifyMatchFailure(
+          op, "Assert message must be a constant string");
+
+    rewriter.replaceOpWithNewOp<RuntimeAssertOp>(op, assertCond, assertMessage);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -11753,6 +11827,10 @@ class DecomposeComplexOpsPass
     // Torchvision ops
     addPatternIfTargetOpIsIllegal<DecomposeTorchvisionNmsOp>(patterns);
 
+    addPatternIfTargetOpIsIllegal<DecomposeAtenConstrainRangeForSizeOp>(
+        patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);
+
     GreedyRewriteConfig config;
     config.useTopDownTraversal = true;
     config.maxIterations = GreedyRewriteConfig::kNoLimit;
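Sketched in Python, the two decompositions amount to the following (illustrative pseudocode of the rewrite logic, not part of the patch):

def decompose_constrain_range_for_size(size, min=None, max=None):
    # A None min defaults to a constant 0; a constant max must be > 2,
    # mirroring DecomposeAtenConstrainRangeForSizeOp.
    if min is None:
        min = 0
    if max is not None and max <= 2:
        raise ValueError(
            f"Max value to constrain_range_for_size must be greater than 2, got: {max}"
        )
    return ("aten.sym_constrain_range", size, min, max)

def decompose_assert_scalar(cond, assert_msg):
    # An int or float condition is first converted to bool (via
    # aten.Bool.int / aten.Bool.float), then the op is rewritten to a
    # torch.runtime.assert carrying the constant message.
    return ("torch.runtime.assert", bool(cond), assert_msg)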

projects/pt1/e2e_testing/xfail_sets.py (+11 −1)
@@ -35,6 +35,10 @@
     "Aten_TrilinearModuleZerodDimBug_basic",
     # missing lowering from aten.pow.Tensor_Tensor for integer result
     "PowIntIntModule_basic",
+    # Unknown builtin op: aten::_check_is_size in TorchScript
+    "AtenSymConstrainRange_basic",
+    "AtenSymConstrainRangeForSize_basic",
+    "Aten_AssertScalar_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.5.0.dev"):

@@ -623,7 +627,6 @@
     "AtenMmQMixedSigni8_basic",
     "AtenMmQint8_basic",
     "AtenMmQuint8_basic",
-    "AtenNonzero1DDynamicModule_basic",
     "AtenRealView128Module_basic",
     "AtenRealView64Module_basic",
     "AtenTopKModule_basic",

@@ -941,6 +944,9 @@
     "UniformModule_basic",
     "UniformStaticShapeModule_basic",
     "ScaledDotProductAttentionGQAModule_basic",
+    "AtenSymConstrainRange_basic",
+    "AtenSymConstrainRangeForSize_basic",
+    "Aten_AssertScalar_basic",
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {

@@ -964,6 +970,7 @@
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
     "CrossEntropyLossModule_basic",
     "CrossEntropyLossNoReductionModule_basic",
+    "AtenNonzero1DDynamicModule_basic",  # error: Mismatched ranks of types2 vs 1
 }
 
 STABLEHLO_PASS_SET = {

@@ -3254,6 +3261,9 @@
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
     "Aten_TrilinearModuleZerodDimBug_basic",
     "ScaledDotProductAttentionGQAModule_basic",
+    "AtenSymConstrainRange_basic",
+    "AtenSymConstrainRangeForSize_basic",
+    "Aten_AssertScalar_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py (+5)
@@ -1232,6 +1232,11 @@ def emit_with_mutating_variants(key, **kwargs):
     )
     emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)")
 
+    # Constraint ops
+    emit("aten::sym_constrain_range : (Scalar, int?, int?) -> ()")
+    emit("aten::sym_constrain_range_for_size : (Scalar, int?, int?) -> ()")
+    emit("aten::_assert_scalar : (Scalar, str) -> ()")
+
     # ==========================================================================
     # `prim::` namespace.
     # ==========================================================================

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py (+59)
@@ -6480,3 +6480,62 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: AtenNonzero1DDynamicModule())
 def AtenNonzero1DDynamicModule_basic(module, tu: TestUtils):
     module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool))
+
+
+# ==============================================================================
+
+
+class AtenSymConstrainRange(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1], torch.int, True)])
+    def forward(self, x):
+        a = x.item()
+        torch.ops.aten.sym_constrain_range(a, max=5)
+        return a
+
+
+@register_test_case(module_factory=lambda: AtenSymConstrainRange())
+def AtenSymConstrainRange_basic(module, tu: TestUtils):
+    module.forward(torch.tensor(4))
+
+
+# ==============================================================================
+
+
+class AtenSymConstrainRangeForSize(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1], torch.int, True)])
+    def forward(self, x):
+        a = x.item()
+        torch.ops.aten.sym_constrain_range_for_size(a, min=0, max=10)
+        return a
+
+
+@register_test_case(module_factory=lambda: AtenSymConstrainRangeForSize())
+def AtenSymConstrainRangeForSize_basic(module, tu: TestUtils):
+    module.forward(torch.tensor(4))
+
+
+# ==============================================================================
+class Aten_AssertScalar(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1], torch.int, True)])
+    def forward(self, x):
+        a = x.item()
+        assert_msg = "Assertion failed for condition x.item() > 3"
+        torch.ops.aten._assert_scalar(a > 3, assert_msg)
+        return a
+
+
+@register_test_case(module_factory=lambda: Aten_AssertScalar())
+def Aten_AssertScalar_basic(module, tu: TestUtils):
+    module.forward(torch.tensor(4))
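Because the e2e harness first runs each module eagerly to record golden values, the new tests can also be smoke-tested outside the harness (a quick sketch; assumes the three classes above are in scope, and the input 4 satisfies every constraint):

import torch

for cls in (AtenSymConstrainRange, AtenSymConstrainRangeForSize, Aten_AssertScalar):
    # No constraint fires for value 4, so each call just returns the scalar.
    print(cls.__name__, "->", cls().forward(torch.tensor(4)))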
