Skip to content

Commit 4de4d38

Browse files
authored
Initial commit of NonZero op (#2766)
1 parent b5387c0 commit 4de4d38

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

+13-2
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
100100
rewriter.replaceOpWithNewOp<Torch::AtenLtTensorOp>(
101101
binder.op, resultType, lhs, rhs);
102102
return success();
103-
});
104-
103+
});
105104
patterns.onOp("LessOrEqual", 1,
106105
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
107106
Torch::ValueTensorType resultType;
@@ -149,6 +148,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
149148
binder.op, resultType, lhs, rhs);
150149
return success();
151150
});
151+
patterns.onOp("NonZero", 13,
152+
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
153+
Torch::ValueTensorType resultType;
154+
Value operand;
155+
if (binder.tensorOperand(operand) ||
156+
binder.tensorResultType(resultType)) {
157+
return failure();
158+
}
159+
rewriter.replaceOpWithNewOp<Torch::AtenNonzeroOp>(
160+
binder.op, resultType, operand);
161+
return success();
162+
});
152163
patterns.onOp(
153164
"MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
154165
std::string autoPad;

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

+9
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,15 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],
450450

451451
// -----
452452

453+
// CHECK-LABEL: func.func @test_nonzero
454+
func.func @test_nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
455+
// CHECK: torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64>
456+
%0 = torch.operator "onnx.NonZero"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64>
457+
return %0 : !torch.vtensor<[3,4,5],si64>
458+
}
459+
460+
// -----
461+
453462
// CHECK-LABEL: func.func @test_or2d
454463
func.func @test_or2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
455464
// CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>

0 commit comments

Comments
 (0)