@@ -1062,3 +1062,116 @@ func.func @ints_constant() -> !torch.vtensor<[2], si64> attributes {torch.onnx_m
1062
1062
return %0 : !torch.vtensor <[2 ],si64 >
1063
1063
}
1064
1064
1065
+ // CHECK-LABEL: @test_flatten_4d_axis_2
1066
+ func.func @test_flatten_4d_axis_2 (%arg0: !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[6 ,20 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1067
+ // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
1068
+ // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
1069
+ // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
1070
+ // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
1071
+ // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1
1072
+ // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32>
1073
+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = 2 : si64 } : (!torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[6 ,20 ],f32 >
1074
+ return %0 : !torch.vtensor <[6 ,20 ],f32 >
1075
+ }
1076
+
1077
+ // CHECK-LABEL: @test_flatten_4d_axis_0
1078
+ func.func @test_flatten_4d_axis_0 (%arg0: !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[1 ,120 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1079
+ // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
1080
+ // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
1081
+ // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
1082
+ // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
1083
+ // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32>
1084
+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = 0 : si64 } : (!torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[1 ,120 ],f32 >
1085
+ return %0 : !torch.vtensor <[1 ,120 ],f32 >
1086
+ }
1087
+
1088
+ // CHECK-LABEL: @test_flatten_4d_axis_4
1089
+ func.func @test_flatten_4d_axis_4 (%arg0: !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[120 ,1 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1090
+ // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 4
1091
+ // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int -> !torch.vtensor
1092
+ // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
1093
+ // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 3
1094
+ // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[120,1],f32>
1095
+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = 4 : si64 } : (!torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[120 ,1 ],f32 >
1096
+ return %0 : !torch.vtensor <[120 ,1 ],f32 >
1097
+ }
1098
+
1099
+ // CHECK-LABEL: @test_flatten_4d_axis_negative_2
1100
+ func.func @test_flatten_4d_axis_negative_2 (%arg0: !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[6 ,20 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1101
+ // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
1102
+ // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
1103
+ // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
1104
+ // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
1105
+ // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1
1106
+ // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32>
1107
+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = -2 : si64 } : (!torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[6 ,20 ],f32 >
1108
+ return %0 : !torch.vtensor <[6 ,20 ],f32 >
1109
+ }
1110
+
1111
+ // CHECK-LABEL: @test_flatten_4d_axis_negative_1
1112
+ func.func @test_flatten_4d_axis_negative_1 (%arg0: !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[24 ,5 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1113
+ // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 3
1114
+ // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
1115
+ // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
1116
+ // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
1117
+ // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 2
1118
+ // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[24,5],f32>
1119
+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = -1 : si64 } : (!torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[24 ,5 ],f32 >
1120
+ return %0 : !torch.vtensor <[24 ,5 ],f32 >
1121
+ }
1122
+
1123
+ // CHECK-LABEL: @test_flatten_4d_axis_negative_4
1124
+ func.func @test_flatten_4d_axis_negative_4 (%arg0: !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[1 ,120 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1125
+ // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
1126
+ // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
1127
+ // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
1128
+ // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
1129
+ // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32>
1130
+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = -4 : si64 } : (!torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[1 ,120 ],f32 >
1131
+ return %0 : !torch.vtensor <[1 ,120 ],f32 >
1132
+ }
1133
+
1134
+ // CHECK-LABEL: @test_flatten_2d_axis_1
1135
+ func.func @test_flatten_2d_axis_1 (%arg0: !torch.vtensor <[2 ,3 ],f32 >) -> !torch.vtensor <[2 ,3 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1136
+ // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 1
1137
+ // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 1
1138
+ // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor
1139
+ // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
1140
+ // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0
1141
+ // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
1142
+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = 1 : si64 } : (!torch.vtensor <[2 ,3 ],f32 >) -> !torch.vtensor <[2 ,3 ],f32 >
1143
+ return %0 : !torch.vtensor <[2 ,3 ],f32 >
1144
+ }
1145
+
1146
+ // CHECK-LABEL: @test_flatten_1d_axis_0
1147
+ func.func @test_flatten_1d_axis_0 (%arg0: !torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[1 ,2 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1148
+ // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
1149
+ // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0
1150
+ // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor
1151
+ // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
1152
+ // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32>
1153
+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = 0 : si64 } : (!torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[1 ,2 ],f32 >
1154
+ return %0 : !torch.vtensor <[1 ,2 ],f32 >
1155
+ }
1156
+
1157
+ // CHECK-LABEL: @test_flatten_1d_axis_negative_1
1158
+ func.func @test_flatten_1d_axis_negative_1 (%arg0: !torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[1 ,2 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1159
+ // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
1160
+ // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0
1161
+ // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor
1162
+ // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
1163
+ // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32>
1164
+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = -1 : si64 } : (!torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[1 ,2 ],f32 >
1165
+ return %0 : !torch.vtensor <[1 ,2 ],f32 >
1166
+ }
1167
+
1168
+ // COM: CHECK-LABEL: @test_flatten_1d_axis_1
1169
+ func.func @test_flatten_1d_axis_1 (%arg0: !torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[2 ,1 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1170
+ // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 1
1171
+ // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor
1172
+ // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
1173
+ // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0
1174
+ // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32>
1175
+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = 1 : si64 } : (!torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[2 ,1 ],f32 >
1176
+ return %0 : !torch.vtensor <[2 ,1 ],f32 >
1177
+ }
0 commit comments