Skip to content

Commit 2f49240

Browse files
daveliddell (Dave Liddell)
and
Dave Liddell
authored
[onnx] Added flatten (#2760)
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/328 --------- Co-authored-by: Dave Liddell <[email protected]>
1 parent b3a3ad4 commit 2f49240

File tree

2 files changed

+186
-0
lines changed

2 files changed

+186
-0
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

+73
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,79 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
13641364
binder.op, resultType, data, dimValueList);
13651365
return success();
13661366
});
1367+
patterns.onOp(
    "Flatten", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Flatten means to partition the input tensor's dimensions
      // into a "left range" spanning 0 to axis - 1 and a "right range"
      // spanning axis to rank - 1. Each range is then collapsed
      // into a single dimension, resulting in a 2-D tensor.
      // If either range is empty, it is replaced with a single
      // dimension of size 1.
      //
      // For example, for a 4-D input tensor of shape (a, b, c, d)
      // and axis==2, flatten produces a 2-D tensor of shape
      // (a*b, c*d).
      //
      // If instead axis==0, the left range is empty, and the result
      // is (1, a*b*c*d).

      Torch::ValueTensorType resultType;
      Value operand;
      int64_t axis;
      if (binder.tensorOperand(operand) ||
          binder.s64IntegerAttr(axis, "axis", 1) ||
          binder.tensorResultType(resultType))
        return failure();

      // If axis is negative, count from the right instead of left
      int64_t rank =
          cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
      if (axis < 0)
        axis = rank + axis;

      // Per the ONNX spec, axis must lie in [-rank, rank]; after the
      // normalization above that is [0, rank]. Fail the match for
      // anything else rather than emitting torch.prims.collapse /
      // torch.aten.unsqueeze ops with invalid dimension indices.
      if (axis < 0 || axis > rank)
        return rewriter.notifyMatchFailure(
            binder.op, "Flatten axis attribute out of range");

      Value collapsedRight;
      auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
          binder.op->getContext());

      if (axis >= rank) {
        // If the right range is empty, add a dim of size 1 to the
        // right side of the shape:
        //   cr = torch.unsqueeze(x, x.ndim)
        Value rankConst = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(rank));
        collapsedRight = rewriter.create<Torch::AtenUnsqueezeOp>(
            binder.getLoc(), baseType, operand, rankConst);
      } else {
        // Otherwise, collapse the right range into a single dimension:
        //   cr = torch._prims.collapse(x, axis, x.ndim - 1)
        Value axisConst = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(axis));
        Value rankLess1Const = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(rank - 1));
        collapsedRight = rewriter.create<Torch::PrimsCollapseOp>(
            binder.getLoc(), baseType, operand, axisConst, rankLess1Const);
      }

      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getI64IntegerAttr(0));

      if (axis <= 0) {
        // If the left range is empty, add a dim of size 1 to the
        // left side of the shape:
        //   torch.unsqueeze(cr, 0)
        rewriter.replaceOpWithNewOp<Torch::AtenUnsqueezeOp>(
            binder.op, resultType, collapsedRight, zero);
        return success();
      }

      // Otherwise, collapse the left range into a single dimension:
      //   torch._prims.collapse(cr, 0, axis - 1)
      Value axisLess1Const = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getI64IntegerAttr(axis - 1));
      rewriter.replaceOpWithNewOp<Torch::PrimsCollapseOp>(
          binder.op, resultType, collapsedRight, zero, axisLess1Const);
      return success();
    });
13671440
patterns.onOp("Floor", 13,
13681441
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
13691442
Torch::ValueTensorType resultType;

test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir

+113
Original file line numberDiff line numberDiff line change
@@ -1062,3 +1062,116 @@ func.func @ints_constant() -> !torch.vtensor<[2], si64> attributes {torch.onnx_m
10621062
return %0 : !torch.vtensor<[2],si64>
10631063
}
10641064

1065+
// CHECK-LABEL: @test_flatten_4d_axis_2
func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
  // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
  // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1
  // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32>
  return %0 : !torch.vtensor<[6,20],f32>
}

// CHECK-LABEL: @test_flatten_4d_axis_0
func.func @test_flatten_4d_axis_0(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
  // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
  // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32>
  return %0 : !torch.vtensor<[1,120],f32>
}

// CHECK-LABEL: @test_flatten_4d_axis_4
func.func @test_flatten_4d_axis_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 4
  // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 3
  // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[120,1],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32>
  return %0 : !torch.vtensor<[120,1],f32>
}

// CHECK-LABEL: @test_flatten_4d_axis_negative_2
func.func @test_flatten_4d_axis_negative_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
  // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
  // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1
  // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32>
  return %0 : !torch.vtensor<[6,20],f32>
}

// CHECK-LABEL: @test_flatten_4d_axis_negative_1
func.func @test_flatten_4d_axis_negative_1(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 3
  // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
  // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 2
  // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[24,5],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32>
  return %0 : !torch.vtensor<[24,5],f32>
}

// CHECK-LABEL: @test_flatten_4d_axis_negative_4
func.func @test_flatten_4d_axis_negative_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
  // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
  // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32>
  return %0 : !torch.vtensor<[1,120],f32>
}

// CHECK-LABEL: @test_flatten_2d_axis_1
func.func @test_flatten_2d_axis_1(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0
  // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32>
  return %0 : !torch.vtensor<[2,3],f32>
}

// CHECK-LABEL: @test_flatten_1d_axis_0
func.func @test_flatten_1d_axis_0(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
  // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32>
  return %0 : !torch.vtensor<[1,2],f32>
}

// CHECK-LABEL: @test_flatten_1d_axis_negative_1
func.func @test_flatten_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
  // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32>
  return %0 : !torch.vtensor<[1,2],f32>
}

// NOTE(review): the CHECK-LABEL below is disabled via COM:, so the CHECK
// lines of this test attach to the preceding label's scan range — confirm
// whether this test is intentionally disabled.
// COM: CHECK-LABEL: @test_flatten_1d_axis_1
func.func @test_flatten_1d_axis_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0
  // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32>
  return %0 : !torch.vtensor<[2,1],f32>
}

0 commit comments

Comments
 (0)