@@ -1069,6 +1069,82 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<
1069
1069
1070
1070
// -----
1071
1071
1072
+ // CHECK-LABEL: func.func @test_reduce_log_sum_exp_default_axes_keepdims_example
1073
+ func.func @test_reduce_log_sum_exp_default_axes_keepdims_example (%arg0: !torch.vtensor <[3 ,2 ,2 ],f32 >, %arg1: !torch.vtensor <[0 ],si64 >) -> !torch.vtensor <[1 ,1 ,1 ],f32 > attributes {torch.onnx_meta.ir_version = 8 : si64 , torch.onnx_meta.opset_version = 18 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1074
+ // CHECK: %[[INT7:.+]] = torch.constant.int 7
1075
+ // CHECK: %[[NONE:.+]] = torch.constant.none
1076
+ // CHECK: %[[FALSE:.+]] = torch.constant.bool false
1077
+ // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,1,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
1078
+ // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32>
1079
+ // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
1080
+ // CHECK: %[[TRUE:.+]] = torch.constant.bool true
1081
+ // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
1082
+ // CHECK: %[[LOG:.+]] = torch.aten.log1p %[[SUM]] : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
1083
+ // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,1,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
1084
+ // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[1,1,1],f32>
1085
+ %0 = torch.operator " onnx.ReduceLogSumExp" (%arg0 , %arg1 ) {torch.onnx.keepdims = 1 : si64 } : (!torch.vtensor <[3 ,2 ,2 ],f32 >, !torch.vtensor <[0 ],si64 >) -> !torch.vtensor <[1 ,1 ,1 ],f32 >
1086
+ return %0 : !torch.vtensor <[1 ,1 ,1 ],f32 >
1087
+ }
1088
+
1089
+ // -----
1090
+
1091
+ // CHECK-LABEL: func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded
1092
+ func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded (%arg0: !torch.vtensor <[3 ,2 ,2 ],f32 >, %arg1: !torch.vtensor <[1 ],si64 >) -> !torch.vtensor <[3 ,2 ],f32 > attributes {torch.onnx_meta.ir_version = 8 : si64 , torch.onnx_meta.opset_version = 18 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1093
+ // CHECK: %[[INT7:.+]] = torch.constant.int 7
1094
+ // CHECK: %[[NONE:.+]] = torch.constant.none
1095
+ // CHECK: %[[FALSE:.+]] = torch.constant.bool false
1096
+ // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f32>
1097
+ // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
1098
+ // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
1099
+ // CHECK: %[[TRUE:.+]] = torch.constant.bool true
1100
+ // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
1101
+ // CHECK: %[[LOG:.+]] = torch.aten.log1p %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32>
1102
+ // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
1103
+ // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2],f32>
1104
+ %0 = torch.operator " onnx.ReduceLogSumExp" (%arg0 , %arg1 ) {torch.onnx.keepdims = 0 : si64 } : (!torch.vtensor <[3 ,2 ,2 ],f32 >, !torch.vtensor <[1 ],si64 >) -> !torch.vtensor <[3 ,2 ],f32 >
1105
+ return %0 : !torch.vtensor <[3 ,2 ],f32 >
1106
+ }
1107
+
1108
+ // -----
1109
+
1110
+ // CHECK-LABEL: func.func @test_reduce_log_sum_exp_keep_dims_example
1111
+ func.func @test_reduce_log_sum_exp_keep_dims_example (%arg0: !torch.vtensor <[3 ,2 ,2 ],f32 >, %arg1: !torch.vtensor <[1 ],si64 >) -> !torch.vtensor <[3 ,2 ,1 ],f32 > attributes {torch.onnx_meta.ir_version = 8 : si64 , torch.onnx_meta.opset_version = 18 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1112
+ // CHECK: %[[INT7:.+]] = torch.constant.int 7
1113
+ // CHECK: %[[NONE:.+]] = torch.constant.none
1114
+ // CHECK: %[[FALSE:.+]] = torch.constant.bool false
1115
+ // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f32>
1116
+ // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
1117
+ // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
1118
+ // CHECK: %[[TRUE:.+]] = torch.constant.bool true
1119
+ // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
1120
+ // CHECK: %[[LOG:.+]] = torch.aten.log1p %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32>
1121
+ // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
1122
+ // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32>
1123
+ %0 = torch.operator " onnx.ReduceLogSumExp" (%arg0 , %arg1 ) {torch.onnx.keepdims = 1 : si64 } : (!torch.vtensor <[3 ,2 ,2 ],f32 >, !torch.vtensor <[1 ],si64 >) -> !torch.vtensor <[3 ,2 ,1 ],f32 >
1124
+ return %0 : !torch.vtensor <[3 ,2 ,1 ],f32 >
1125
+ }
1126
+
1127
+ // -----
1128
+
1129
+ // CHECK-LABEL: func.func @test_reduce_log_sum_exp_keep_dims_int_input_example
1130
+ func.func @test_reduce_log_sum_exp_keep_dims_int_input_example (%arg0: !torch.vtensor <[3 ,2 ,2 ],si64 >, %arg1: !torch.vtensor <[1 ],si64 >) -> !torch.vtensor <[3 ,2 ,1 ],f32 > attributes {torch.onnx_meta.ir_version = 8 : si64 , torch.onnx_meta.opset_version = 18 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1131
+ // CHECK: %[[INT7:.+]] = torch.constant.int 7
1132
+ // CHECK: %[[NONE:.+]] = torch.constant.none
1133
+ // CHECK: %[[FALSE:.+]] = torch.constant.bool false
1134
+ // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],si64>
1135
+ // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],si64> -> !torch.vtensor<[3,2,2],si64>
1136
+ // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
1137
+ // CHECK: %[[TRUE:.+]] = torch.constant.bool true
1138
+ // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],si64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
1139
+ // CHECK: %[[LOG:.+]] = torch.aten.log1p %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32>
1140
+ // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
1141
+ // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32>
1142
+ %0 = torch.operator " onnx.ReduceLogSumExp" (%arg0 , %arg1 ) {torch.onnx.keepdims = 1 : si64 } : (!torch.vtensor <[3 ,2 ,2 ],si64 >, !torch.vtensor <[1 ],si64 >) -> !torch.vtensor <[3 ,2 ,1 ],f32 >
1143
+ return %0 : !torch.vtensor <[3 ,2 ,1 ],f32 >
1144
+ }
1145
+
1146
+ // -----
1147
+
1072
1148
// CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example
1073
1149
func.func @test_reduce_mean_negative_axes_keepdims_example (%arg0: !torch.vtensor <[3 ,2 ,2 ],f32 >) -> !torch.vtensor <[3 ,1 ,2 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 } {
1074
1150
// CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
0 commit comments