Commit c461b86

Merge branch 'reducelogsumexp' of https://github.com/archana-ramalingam/torch-mlir into reducelogsumexp
2 parents: 7ac151c + 74db4a7

2 files changed

Diff for: lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

```diff
@@ -1040,6 +1040,61 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         return success();
       });
   patterns.onOp(
+      "ReduceLogSumExp", 1,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value data;
+        int64_t keepDims, noop_with_empty_axes;
+        if (binder.tensorOperandAtIndex(data, 0) ||
+            binder.tensorResultType(resultType) ||
+            binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
+            binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
+                                  0))
+          return failure();
+
+        // out = Log(ReduceSum(Exp(data))). The intermediate computation is
+        // done in f64 so that exp() does not overflow the input dtype.
+        Value castDType = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(/*Float64Type*/ 7));
+        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value constFalse =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        auto size = data.getType()
+                        .dyn_cast<Torch::ValueTensorType>()
+                        .getOptionalSizes();
+        auto f64ResultType = rewriter.getType<Torch::ValueTensorType>(
+            size, rewriter.getF64Type());
+        Value dataCast = rewriter.create<Torch::AtenToDtypeOp>(
+            binder.getLoc(), f64ResultType, data, castDType,
+            /*non_blocking=*/constFalse, /*copy=*/constFalse,
+            /*memory_format=*/noneVal);
+
+        Value dataExp = rewriter.create<Torch::AtenExpOp>(
+            binder.getLoc(), f64ResultType, dataCast);
+        // The reduced values have the result shape, not the input shape, so
+        // use a separate f64 type for the sum and the log.
+        auto f64ReduceType = rewriter.getType<Torch::ValueTensorType>(
+            resultType.getOptionalSizes(), rewriter.getF64Type());
+        auto reducedSumBool = reducedSumImpl(
+            binder, rewriter, dataExp, f64ReduceType,
+            /*storeValue=*/data, keepDims, noop_with_empty_axes, true);
+
+        if (failed(reducedSumBool))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Failed to perform sum operation on exp of operand");
+
+        Value finalResult = rewriter.create<Torch::AtenLogOp>(
+            binder.getLoc(), f64ReduceType, data);
+
+        Value resultDtype = Torch::getDtypeIntValueForType(
+            rewriter, binder.getLoc(), resultType.getDtype());
+        rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
+            binder.op, resultType, finalResult, resultDtype,
+            /*non_blocking=*/constFalse, /*copy=*/constFalse,
+            /*memory_format=*/noneVal);
+        return success();
+      });
+  patterns.onOp(
       "ReduceMean", 1,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
```
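
For context, the decomposition this pattern emits is ReduceLogSumExp(x) = Log(ReduceSum(Exp(x))), with the intermediate widened to f64 before the final cast back to the requested result dtype. Below is a minimal standalone C++ sketch of the same computation, illustrative only and not part of the commit; the helper name `reduceLogSumExp` and the sample values are made up:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Mirrors the lowering's decomposition ReduceLogSumExp(x) = Log(Sum(Exp(x))),
// accumulating in double just as the pattern's cast to f64 does.
double reduceLogSumExp(const std::vector<float> &data) {
  double sum = 0.0; // f64 accumulator; std::exp(double) returns double
  for (float v : data)
    sum += std::exp(static_cast<double>(v));
  return std::log(sum);
}

int main() {
  // exp(90.f) ~ 1.2e39 already exceeds FLT_MAX (~3.4e38), so an f32
  // intermediate would overflow to inf; the f64 path stays finite.
  std::vector<float> data = {90.0f, 91.0f, 92.0f};
  std::printf("%.6f\n", reduceLogSumExp(data)); // ~92.407606
  return 0;
}
```

A numerically safer formulation subtracts max(x) before exponentiating; the lowering instead leans on the f64 widening, which covers inputs up to roughly 709 (the overflow threshold of exp in f64).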

Diff for: test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

```diff
@@ -1069,6 +1069,92 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<
 
 // -----
 
+// CHECK-LABEL: func.func @test_reduce_log_sum_exp_default_axes_keepdims_example
+func.func @test_reduce_log_sum_exp_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT7:.+]] = torch.constant.int 7
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f64>
+  // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f64> -> !torch.vtensor<[1,1,1],f64>
+  // CHECK: %[[INT6:.+]] = torch.constant.int 6
+  // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,1,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
+  // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[1,1,1],f32>
+  %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
+  return %0 : !torch.vtensor<[1,1,1],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded
+func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT7:.+]] = torch.constant.int 7
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %{{.*}}, %{{.*}} : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[KEEPDIM:.+]] = torch.constant.bool false
+  // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[KEEPDIM]], %[[NONE]] : !torch.vtensor<[3,2,2],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f64>
+  // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f64> -> !torch.vtensor<[3,2],f64>
+  // CHECK: %[[INT6:.+]] = torch.constant.int 6
+  // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
+  // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2],f32>
+  %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
+  return %0 : !torch.vtensor<[3,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_log_sum_exp_keep_dims_example
+func.func @test_reduce_log_sum_exp_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT7:.+]] = torch.constant.int 7
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %{{.*}}, %{{.*}} : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f64>
+  // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f64> -> !torch.vtensor<[3,2,1],f64>
+  // CHECK: %[[INT6:.+]] = torch.constant.int 6
+  // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
+  // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32>
+  %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
+  return %0 : !torch.vtensor<[3,2,1],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_log_sum_exp_keep_dims_int_input_example
+func.func @test_reduce_log_sum_exp_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT7:.+]] = torch.constant.int 7
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %{{.*}}, %{{.*}} : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f64>
+  // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f64> -> !torch.vtensor<[3,2,1],f64>
+  // CHECK: %[[INT6:.+]] = torch.constant.int 6
+  // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
+  // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32>
+  %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
+  return %0 : !torch.vtensor<[3,2,1],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example
 func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
   // CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
```
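
These tests follow the conventions of the other reduce-op tests in this file: they run under llvm-lit via the file's existing RUN line (not shown in this hunk), which pipes the file through `torch-mlir-opt` with the `-convert-torch-onnx-to-torch` pass and `-split-input-file` (hence the `// -----` separators) into FileCheck.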
