@@ -180,6 +180,77 @@ LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter,
180
180
return success ();
181
181
}
182
182
183
+ Value scaleIdentityComparisonOpForFactorAtDimensionIn (
184
+ Value givenScaleFactors, int64_t givenDimension, OpBinder binder,
185
+ ConversionPatternRewriter &rewriter) {
186
+ auto typeOfScaleFactors =
187
+ cast<Torch::BaseTensorType>(givenScaleFactors.getType ());
188
+
189
+ Type typeOfSelectionFromScaleFactors =
190
+ typeOfScaleFactors.getWithSizesAndDtype (
191
+ ArrayRef<int64_t >{1 }, typeOfScaleFactors.getOptionalDtype ());
192
+
193
+ auto opLocation = binder.getLoc ();
194
+
195
+ Value zeroAsOp = rewriter.create <Torch::ConstantIntOp>(
196
+ opLocation, rewriter.getI64IntegerAttr (0 ));
197
+
198
+ Value scaleIdentityAsOp = rewriter.create <Torch::ConstantFloatOp>(
199
+ opLocation, rewriter.getF64FloatAttr (1.0 ));
200
+
201
+ Value givenDimensionAsOp = rewriter.create <Torch::ConstantIntOp>(
202
+ opLocation, rewriter.getI64IntegerAttr (givenDimension));
203
+
204
+ Type typeOfScaleFactor = rewriter.getType <Torch::FloatType>();
205
+
206
+ Value selectionFromScaleFactorsAsOp = rewriter.create <Torch::AtenSelectIntOp>(
207
+ opLocation, typeOfSelectionFromScaleFactors, givenScaleFactors, zeroAsOp,
208
+ givenDimensionAsOp);
209
+
210
+ Value scaleFactorAsOp = rewriter.create <Torch::AtenItemOp>(
211
+ opLocation, typeOfScaleFactor, selectionFromScaleFactorsAsOp);
212
+
213
+ Type typeOfComparisonResult = rewriter.getType <Torch::BoolType>();
214
+
215
+ return rewriter.create <Torch::AtenEqFloatOp>(
216
+ opLocation, typeOfComparisonResult, scaleFactorAsOp, scaleIdentityAsOp);
217
+ }
218
+
219
+ Value originalSizeComparisonOpForSizeAtDimensionIn (
220
+ Value givenTargetSizes, Value givenOriginalTensor, int64_t givenDimension,
221
+ OpBinder binder, ConversionPatternRewriter &rewriter) {
222
+ auto typeOfTargetSizes =
223
+ cast<Torch::BaseTensorType>(givenTargetSizes.getType ());
224
+
225
+ Type typeOfSelectionFromTargetSizes = typeOfTargetSizes.getWithSizesAndDtype (
226
+ ArrayRef<int64_t >{1 }, typeOfTargetSizes.getOptionalDtype ());
227
+
228
+ auto opLocation = binder.getLoc ();
229
+
230
+ Value zeroAsOp = rewriter.create <Torch::ConstantIntOp>(
231
+ opLocation, rewriter.getI64IntegerAttr (0 ));
232
+
233
+ Type typeOfTargetSize = rewriter.getType <Torch::IntType>();
234
+
235
+ Value givenDimensionAsOp = rewriter.create <Torch::ConstantIntOp>(
236
+ opLocation, rewriter.getI64IntegerAttr (givenDimension));
237
+
238
+ Value selectionFromTargetSizesAsOp = rewriter.create <Torch::AtenSelectIntOp>(
239
+ opLocation, typeOfSelectionFromTargetSizes, givenTargetSizes, zeroAsOp,
240
+ givenDimensionAsOp);
241
+
242
+ Value targetSizeAsOp = rewriter.create <Torch::AtenItemOp>(
243
+ opLocation, typeOfTargetSize, selectionFromTargetSizesAsOp);
244
+
245
+ Value originalSizeAsOp = rewriter.create <Torch::AtenSizeIntOp>(
246
+ opLocation, givenOriginalTensor, givenDimensionAsOp);
247
+
248
+ Type typeOfComparisonResult = rewriter.getType <Torch::BoolType>();
249
+
250
+ return rewriter.create <Torch::AtenEqIntOp>(opLocation, typeOfComparisonResult,
251
+ targetSizeAsOp, originalSizeAsOp);
252
+ }
253
+
183
254
Value withUnsupportedDimensionsFilteredOut (
184
255
Value givenTransformationVector, OpBinder binder,
185
256
ConversionPatternRewriter &rewriter) {
@@ -2724,6 +2795,43 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
2724
2795
" round_prefer_floor" ) ||
2725
2796
binder.f32FloatAttr (cubic_coeff_a, " cubic_coeff_a" , -0.75 ))
2726
2797
return failure ();
2798
+
2799
+ Value inputTensor = operands[0 ];
2800
+ auto typeOfInputTensor =
2801
+ cast<Torch::BaseTensorType>(inputTensor.getType ());
2802
+
2803
+ auto sizesOfInputTensor = typeOfInputTensor.getSizes ();
2804
+ ArrayRef<int64_t > sizesOfOutputTensor = typeOfOutputTensor.getSizes ();
2805
+
2806
+ int64_t const dimensionAssumedToBeBatch = 0 ;
2807
+ int64_t const dimensionAssumedToBeChannel = 1 ;
2808
+ int64_t nonResizableDimensions[] = {
2809
+ dimensionAssumedToBeBatch,
2810
+ dimensionAssumedToBeChannel,
2811
+ };
2812
+
2813
+ auto unknownSize = Torch::kUnknownSize ;
2814
+
2815
+ // Compile-time check for dimensions of static size
2816
+ for (auto eachDimension : nonResizableDimensions) {
2817
+ auto eachSizeOfInputTensor = sizesOfInputTensor[eachDimension];
2818
+ auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDimension];
2819
+
2820
+ if (eachSizeOfInputTensor == unknownSize ||
2821
+ eachSizeOfOutputTensor == unknownSize) {
2822
+ continue ;
2823
+ } else if (eachSizeOfInputTensor == eachSizeOfOutputTensor) {
2824
+ continue ;
2825
+ }
2826
+
2827
+ auto scalingIntentErrorMessage =
2828
+ " unsupported: non-trivial intent to scale dimension: " +
2829
+ std::to_string (eachDimension);
2830
+
2831
+ return rewriter.notifyMatchFailure (binder.op ,
2832
+ scalingIntentErrorMessage);
2833
+ };
2834
+
2727
2835
if (antialias != 0 ) {
2728
2836
return rewriter.notifyMatchFailure (
2729
2837
binder.op ,
@@ -2773,10 +2881,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
2773
2881
rewriter.create <Torch::ConstantStrOp>(opLocation, modeStr);
2774
2882
}
2775
2883
2776
- Value inputTensor = operands[0 ];
2777
- auto typeOfInputTensor =
2778
- cast<Torch::BaseTensorType>(inputTensor.getType ());
2779
- auto sizesOfInputTensor = typeOfInputTensor.getSizes ();
2780
2884
unsigned rankOfInputTensor = sizesOfInputTensor.size ();
2781
2885
2782
2886
// supported modes:
@@ -2824,10 +2928,43 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
2824
2928
2825
2929
if (operands.size () < 4 ) {
2826
2930
Value proposedScaleFactorsAsOp = operands[2 ];
2931
+
2932
+ // run-time scale factor check for dynamic sizes
2933
+ for (auto eachDimension : nonResizableDimensions) {
2934
+ auto eachScaleIdentityComparisonAsOp =
2935
+ scaleIdentityComparisonOpForFactorAtDimensionIn (
2936
+ proposedScaleFactorsAsOp, eachDimension, binder, rewriter);
2937
+
2938
+ auto eachErrorMessage =
2939
+ " Unsupported: non-trivial scale factor for dimension " +
2940
+ std::to_string (eachDimension);
2941
+
2942
+ rewriter.create <Torch::RuntimeAssertOp>(
2943
+ opLocation, eachScaleIdentityComparisonAsOp,
2944
+ rewriter.getStringAttr (eachErrorMessage));
2945
+ };
2946
+
2827
2947
filteredScaleFactorsAsOp = withUnsupportedDimensionsFilteredOut (
2828
2948
proposedScaleFactorsAsOp, binder, rewriter);
2829
2949
} else {
2830
2950
Value proposedSizesAsOp = operands[3 ];
2951
+
2952
+ // run-time target size check for dynamic sizes
2953
+ for (auto eachDimension : nonResizableDimensions) {
2954
+ auto eachSizeComparisonAsOp =
2955
+ originalSizeComparisonOpForSizeAtDimensionIn (
2956
+ proposedSizesAsOp, inputTensor, eachDimension, binder,
2957
+ rewriter);
2958
+
2959
+ auto eachErrorMessage =
2960
+ " Unsupported: non-trivial resizing of dimension " +
2961
+ std::to_string (eachDimension);
2962
+
2963
+ rewriter.create <Torch::RuntimeAssertOp>(
2964
+ opLocation, eachSizeComparisonAsOp,
2965
+ rewriter.getStringAttr (eachErrorMessage));
2966
+ };
2967
+
2831
2968
filteredSizesAsOp = withUnsupportedDimensionsFilteredOut (
2832
2969
proposedSizesAsOp, binder, rewriter);
2833
2970
}
0 commit comments