Commit fb21a85
[TorchToLinalg] Lower grouped conv2d to linalg Op with correct dimension ordering (#2623)
The linalg op `linalg.conv_2d_ngchw_fgchw` had a bug where:

1. weights were accessed as G,F,C,H,W instead of as F,G,C,H,W, and
2. the output was accessed as N,F,G,H,W instead of as N,G,F,H,W.

This has now been fixed upstream in llvm/llvm-project#73855, which broke the torch-mlir lowering to that op. This patch switches the torch-mlir lowering to the newly introduced `linalg.conv_2d_ngchw_gfchw` op, which accesses weights in an order that is compatible with PyTorch's memory layout.

Fix #2622
1 parent: 8252656 · commit: fb21a85
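For context, PyTorch stores a grouped conv2d weight as (F, C/G, kH, kW) with the F output channels laid out group-major, so a plain reshape yields the (G, F/G, C/G, kH, kW) ordering that `linalg.conv_2d_ngchw_gfchw` expects. A minimal sketch checking this claim (sizes are illustrative, not taken from the test suite; only `torch` is assumed):

```python
import torch
import torch.nn.functional as F

N, G = 1, 2
Cin, Cout, H, W, k = 6, 4, 8, 8, 3
x = torch.randn(N, Cin, H, W)
w = torch.randn(Cout, Cin // G, k, k)  # PyTorch layout: F, C/G, kH, kW

ref = F.conv2d(x, w, groups=G)

# Viewing the weight with groups outermost (G, F/G, C/G, kH, kW -- the
# gfchw ordering) matches PyTorch's contiguous, group-major partitioning
# of output channels.
wg = w.view(G, Cout // G, Cin // G, k, k)
xg = x.view(N, G, Cin // G, H, W)
out = torch.cat([F.conv2d(xg[:, g], wg[g]) for g in range(G)], dim=1)
assert torch.allclose(ref, out, atol=1e-5)
```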

File tree

2 files changed (+7 -23 lines changed)

lib/Conversion/TorchToLinalg/Linear.cpp (+7 -8)
@@ -848,6 +848,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
             indices);
       };
 
+      // expand F,C,H,W -> G,F/G,C,H,W
       auto expandWeight = [&](Value tensor) {
         auto inType = tensor.getType().cast<RankedTensorType>();
         auto inShape = makeShapeTorchCompatible(inType.getShape());
@@ -868,21 +869,19 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
 
       Value paddedInputExpanded = expandGroups(paddedInput, 1);
       Value weightExpanded = expandWeight(weight);
-      Value outputTensorExpanded = expandGroups(outputTensor, 1);
+      auto expandOutputTensor = expandGroups(outputTensor, 1);
 
       // TODO: add 1D and 3D case
       conv = rewriter
-                 .create<linalg::Conv2DNgchwFgchwOp>(
-                     loc, outputTensorExpanded.getType(),
+                 .create<linalg::Conv2DNgchwGfchwOp>(
+                     loc, expandOutputTensor.getResultType(),
                      ValueRange{paddedInputExpanded, weightExpanded},
-                     outputTensorExpanded, stridesAttr, dilationAttr)
+                     expandOutputTensor.getResult(), stridesAttr, dilationAttr)
                 .getResult(0);
 
-      SmallVector<ReassociationIndices> indices{{0}, {1, 2}};
-      for (auto dim = 3; dim <= (int64_t)inRank; dim++)
-        indices.push_back({dim});
       conv = rewriter.create<tensor::CollapseShapeOp>(
-          loc, outputTensor.getType(), conv, indices);
+          loc, outputTensor.getType(), conv,
+          expandOutputTensor.getReassociationIndices());
     }
 
     Type newResultType = getTypeConverter()->convertType(op.getType());
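Note that the collapse now reuses `expandOutputTensor.getReassociationIndices()` instead of rebuilding the reassociation by hand. For intuition, here is a hedged PyTorch sketch of the contraction `linalg.conv_2d_ngchw_gfchw` performs on the expanded operands (unit stride and dilation; sizes are illustrative, not from the patch):

```python
import torch

N, G, Cg, Fg, H, W, k = 1, 2, 3, 2, 8, 8, 3
xg = torch.randn(N, G, Cg, H, W)   # expanded input,  ngchw
wg = torch.randn(G, Fg, Cg, k, k)  # expanded weight, gfchw

# Gather k x k patches so the conv becomes a pure einsum contraction:
patches = xg.unfold(3, k, 1).unfold(4, k, 1)  # N,G,C/G,OH,OW,kH,kW
out = torch.einsum("ngcxyhw,gfchw->ngfxy", patches, wg)  # N,G,F/G,OH,OW

# The tensor.collapse_shape step: N,G,F/G,OH,OW -> N,F,OH,OW.
out = out.reshape(N, G * Fg, H - k + 1, W - k + 1)
```

The group dimension g appears in both operands and the output, so each group's filters only ever see that group's input channels, which is exactly the ordering the old fgchw lowering got wrong.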

projects/pt1/e2e_testing/xfail_sets.py (-15)
@@ -23,14 +23,6 @@
     "IscloseStaticModuleTrue_basic"
 }
 
-if torch_version_for_comparison() >= version.parse("2.2.0.dev20231204"):
-    LINALG_XFAIL_SET |= {
-        "Conv2dWithPaddingDilationStrideStaticModule_grouped",
-        "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
-        "ConvolutionModule2DGroups_basic",
-    }
-
-
 TORCHDYNAMO_XFAIL_SET = {
     #### General TorchDynamo/PyTorch errors
 
@@ -316,13 +308,6 @@
     "ArangeStartOutViewModule_basic",
 }
 
-if torch_version_for_comparison() >= version.parse("2.2.0.dev20231204"):
-    TORCHDYNAMO_XFAIL_SET |= {
-        "Conv2dWithPaddingDilationStrideStaticModule_grouped",
-        "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
-        "ConvolutionModule2DGroups_basic",
-    }
-
 TORCHDYNAMO_CRASHING_SET = {
     # No upstream decompositions.
     # %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
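Removing these version-gated blocks re-enables the three grouped-conv e2e tests now that the lowering targets the fixed op. For reference, an illustrative module of the kind those tests exercise (a sketch only, not the actual test source; the layer configuration here is hypothetical):

```python
import torch

class GroupedConv2dModule(torch.nn.Module):
    """Illustrative only: any groups > 1 conv2d routes the lowering
    through linalg.conv_2d_ngchw_gfchw."""
    def __init__(self):
        super().__init__()
        # groups=4: weight layout is (F, C/G, kH, kW) = (8, 1, 3, 3)
        self.conv = torch.nn.Conv2d(4, 8, kernel_size=3, padding=1,
                                    stride=2, dilation=2, groups=4)

    def forward(self, x):
        return self.conv(x)

m = GroupedConv2dModule()
print(m(torch.randn(2, 4, 16, 16)).shape)  # torch.Size([2, 8, 7, 7])
```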
