Skip to content

Commit 052404a

Browse files
committed
fix(ONNX): avoids resizing unsupported dimensions
1 parent 8d1b1cc commit 052404a

File tree

2 files changed

+144
-7
lines changed

2 files changed

+144
-7
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

+141-4
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,77 @@ LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter,
180180
return success();
181181
}
182182

183+
// Builds a Bool-typed comparison op that tests whether the scale factor
// stored at index `givenDimension` of the 1-D tensor `givenScaleFactors`
// equals the identity scale 1.0 (i.e. the dimension is not actually
// being resized). The returned Value is suitable as the condition of a
// Torch::RuntimeAssertOp.
Value scaleIdentityComparisonOpForFactorAtDimensionIn(
    Value givenScaleFactors, int64_t givenDimension, OpBinder binder,
    ConversionPatternRewriter &rewriter) {
  auto factorsTensorType =
      cast<Torch::BaseTensorType>(givenScaleFactors.getType());

  // Type of a single-element slice taken out of the scale-factor tensor.
  Type selectedFactorTensorType = factorsTensorType.getWithSizesAndDtype(
      ArrayRef<int64_t>{1}, factorsTensorType.getOptionalDtype());

  auto loc = binder.getLoc();

  // Axis 0 — the scale factors are a rank-1 tensor.
  Value selectionAxis = rewriter.create<Torch::ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(0));

  // The identity scale: a factor of exactly 1.0 leaves a dimension alone.
  Value identityScale = rewriter.create<Torch::ConstantFloatOp>(
      loc, rewriter.getF64FloatAttr(1.0));

  Value dimensionIndex = rewriter.create<Torch::ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(givenDimension));

  Type scalarFloatType = rewriter.getType<Torch::FloatType>();

  // Pull out the single factor for the dimension of interest, then
  // convert the one-element tensor to a scalar float.
  Value selectedFactorTensor = rewriter.create<Torch::AtenSelectIntOp>(
      loc, selectedFactorTensorType, givenScaleFactors, selectionAxis,
      dimensionIndex);

  Value selectedFactor = rewriter.create<Torch::AtenItemOp>(
      loc, scalarFloatType, selectedFactorTensor);

  Type boolType = rewriter.getType<Torch::BoolType>();

  // True iff the proposed scale factor is the identity.
  return rewriter.create<Torch::AtenEqFloatOp>(loc, boolType, selectedFactor,
                                               identityScale);
}
218+
219+
// Builds a Bool-typed comparison op that tests whether the target size
// stored at index `givenDimension` of the 1-D tensor `givenTargetSizes`
// equals the runtime size of that same dimension of
// `givenOriginalTensor` (i.e. the dimension is not actually being
// resized). The returned Value is suitable as the condition of a
// Torch::RuntimeAssertOp.
Value originalSizeComparisonOpForSizeAtDimensionIn(
    Value givenTargetSizes, Value givenOriginalTensor, int64_t givenDimension,
    OpBinder binder, ConversionPatternRewriter &rewriter) {
  auto sizesTensorType =
      cast<Torch::BaseTensorType>(givenTargetSizes.getType());

  // Type of a single-element slice taken out of the target-sizes tensor.
  Type selectedSizeTensorType = sizesTensorType.getWithSizesAndDtype(
      ArrayRef<int64_t>{1}, sizesTensorType.getOptionalDtype());

  auto loc = binder.getLoc();

  // Axis 0 — the target sizes are a rank-1 tensor.
  Value selectionAxis = rewriter.create<Torch::ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(0));

  Type scalarIntType = rewriter.getType<Torch::IntType>();

  Value dimensionIndex = rewriter.create<Torch::ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(givenDimension));

  // Pull out the single requested size for the dimension of interest,
  // then convert the one-element tensor to a scalar int.
  Value selectedSizeTensor = rewriter.create<Torch::AtenSelectIntOp>(
      loc, selectedSizeTensorType, givenTargetSizes, selectionAxis,
      dimensionIndex);

  Value requestedSize = rewriter.create<Torch::AtenItemOp>(
      loc, scalarIntType, selectedSizeTensor);

  // The tensor's current (possibly dynamic) size along that dimension.
  Value currentSize = rewriter.create<Torch::AtenSizeIntOp>(
      loc, givenOriginalTensor, dimensionIndex);

  Type boolType = rewriter.getType<Torch::BoolType>();

  // True iff the requested size matches the existing size.
  return rewriter.create<Torch::AtenEqIntOp>(loc, boolType, requestedSize,
                                             currentSize);
}
253+
183254
Value withUnsupportedDimensionsFilteredOut(
184255
Value givenTransformationVector, OpBinder binder,
185256
ConversionPatternRewriter &rewriter) {
@@ -2724,6 +2795,43 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27242795
"round_prefer_floor") ||
27252796
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
27262797
return failure();
2798+
2799+
Value inputTensor = operands[0];
2800+
auto typeOfInputTensor =
2801+
cast<Torch::BaseTensorType>(inputTensor.getType());
2802+
2803+
auto sizesOfInputTensor = typeOfInputTensor.getSizes();
2804+
ArrayRef<int64_t> sizesOfOutputTensor = typeOfOutputTensor.getSizes();
2805+
2806+
int64_t const dimensionAssumedToBeBatch = 0;
2807+
int64_t const dimensionAssumedToBeChannel = 1;
2808+
int64_t nonResizableDimensions[] = {
2809+
dimensionAssumedToBeBatch,
2810+
dimensionAssumedToBeChannel,
2811+
};
2812+
2813+
auto unknownSize = Torch::kUnknownSize;
2814+
2815+
// Compile-time check for dimensions of static size
2816+
for (auto eachDimension : nonResizableDimensions) {
2817+
auto eachSizeOfInputTensor = sizesOfInputTensor[eachDimension];
2818+
auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDimension];
2819+
2820+
if (eachSizeOfInputTensor == unknownSize ||
2821+
eachSizeOfOutputTensor == unknownSize) {
2822+
continue;
2823+
} else if (eachSizeOfInputTensor == eachSizeOfOutputTensor) {
2824+
continue;
2825+
}
2826+
2827+
auto scalingIntentErrorMessage =
2828+
"unsupported: non-trivial intent to scale dimension: " +
2829+
std::to_string(eachDimension);
2830+
2831+
return rewriter.notifyMatchFailure(binder.op,
2832+
scalingIntentErrorMessage);
2833+
};
2834+
27272835
if (antialias != 0) {
27282836
return rewriter.notifyMatchFailure(
27292837
binder.op,
@@ -2773,10 +2881,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27732881
rewriter.create<Torch::ConstantStrOp>(opLocation, modeStr);
27742882
}
27752883

2776-
Value inputTensor = operands[0];
2777-
auto typeOfInputTensor =
2778-
cast<Torch::BaseTensorType>(inputTensor.getType());
2779-
auto sizesOfInputTensor = typeOfInputTensor.getSizes();
27802884
unsigned rankOfInputTensor = sizesOfInputTensor.size();
27812885

27822886
// supported modes:
@@ -2824,10 +2928,43 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
28242928

28252929
if (operands.size() < 4) {
28262930
Value proposedScaleFactorsAsOp = operands[2];
2931+
2932+
// run-time scale factor check for dynamic sizes
2933+
for (auto eachDimension : nonResizableDimensions) {
2934+
auto eachScaleIdentityComparisonAsOp =
2935+
scaleIdentityComparisonOpForFactorAtDimensionIn(
2936+
proposedScaleFactorsAsOp, eachDimension, binder, rewriter);
2937+
2938+
auto eachErrorMessage =
2939+
"Unsupported: non-trivial scale factor for dimension " +
2940+
std::to_string(eachDimension);
2941+
2942+
rewriter.create<Torch::RuntimeAssertOp>(
2943+
opLocation, eachScaleIdentityComparisonAsOp,
2944+
rewriter.getStringAttr(eachErrorMessage));
2945+
};
2946+
28272947
filteredScaleFactorsAsOp = withUnsupportedDimensionsFilteredOut(
28282948
proposedScaleFactorsAsOp, binder, rewriter);
28292949
} else {
28302950
Value proposedSizesAsOp = operands[3];
2951+
2952+
// run-time target size check for dynamic sizes
2953+
for (auto eachDimension : nonResizableDimensions) {
2954+
auto eachSizeComparisonAsOp =
2955+
originalSizeComparisonOpForSizeAtDimensionIn(
2956+
proposedSizesAsOp, inputTensor, eachDimension, binder,
2957+
rewriter);
2958+
2959+
auto eachErrorMessage =
2960+
"Unsupported: non-trivial resizing of dimension " +
2961+
std::to_string(eachDimension);
2962+
2963+
rewriter.create<Torch::RuntimeAssertOp>(
2964+
opLocation, eachSizeComparisonAsOp,
2965+
rewriter.getStringAttr(eachErrorMessage));
2966+
};
2967+
28312968
filteredSizesAsOp = withUnsupportedDimensionsFilteredOut(
28322969
proposedSizesAsOp, binder, rewriter);
28332970
}

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

+3-3
Original file line numberDiff line numberDiff line change
@@ -2256,7 +2256,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
22562256
// CHECK-LABEL: func.func @test_resize_sizes_nearest
22572257
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
22582258
%none = torch.constant.none
2259-
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
2259+
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %12, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
22602260
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
22612261
return %0 : !torch.vtensor<[?,?,?,?],f32>
22622262
}
@@ -2267,7 +2267,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
22672267
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
22682268
%none = torch.constant.none
22692269
// CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
2270-
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
2270+
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %12, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
22712271
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
22722272
torch.onnx.coordinate_transformation_mode = "half_pixel",
22732273
torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
@@ -2280,7 +2280,7 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1
22802280
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],
22812281
f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
22822282
%none = torch.constant.none
2283-
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
2283+
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %12, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
22842284
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
22852285
return %0 : !torch.vtensor<[?,?,?,?],f32>
22862286
}

0 commit comments

Comments
 (0)