diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 69f86c7cca88..909022fad6e5 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -234,11 +234,16 @@ Value createScalarSublist( } // namespace namespace TorchImageTensor { -// int64_t const /* */ batchDim = 0; -// int64_t const /**/ channelDim = 1; +int64_t const /* */ batchDim = 0; +int64_t const /**/ channelDim = 1; int64_t const /* */ heightDim = 2; // int64_t const /* */ widthDim = 3; // int64_t const /* */ depthDim = 4; + +SmallVector nonResizableDims{ + batchDim, + channelDim, +}; } // namespace TorchImageTensor void mlir::torch::onnx_c::populateDefaultDomainQtoZ( @@ -2728,6 +2733,33 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "round_prefer_floor") || binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75)) return failure(); + + Value inputTensor = operands[0]; + auto InputTensor = cast(inputTensor.getType()); + auto sizesOfInputTensor = InputTensor.getSizes(); + ArrayRef sizesOfOutputTensor = OutputTensor.getSizes(); + + auto unknownSize = Torch::kUnknownSize; + + // Compile-time check for dimensions of static size + for (auto &eachDim : TorchImageTensor::nonResizableDims) { + auto eachSizeOfInputTensor = sizesOfInputTensor[eachDim]; + auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDim]; + + if (eachSizeOfInputTensor == unknownSize || + eachSizeOfOutputTensor == unknownSize) + continue; + if (eachSizeOfInputTensor == eachSizeOfOutputTensor) + continue; + + auto resizingIntentErrorMessage = + "unsupported: non-trivial intent to resize dimension: " + + std::to_string(eachDim); + + return rewriter.notifyMatchFailure(binder.op, + resizingIntentErrorMessage); + }; + if (antialias != 0) { return rewriter.notifyMatchFailure( binder.op, @@ -2775,9 +2807,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( modeStrValue = rewriter.create(loc, modeStr); } - Value inputTensor = operands[0]; - auto InputTensor = cast(inputTensor.getType()); - auto sizesOfInputTensor = InputTensor.getSizes(); auto rankOfInputTensor = sizesOfInputTensor.size(); // supported modes: @@ -2819,6 +2848,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto numberOfOperands = operands.size(); + Type Bool = rewriter.getType(); + auto foremostSupportedDim = TorchImageTensor::heightDim; Value supportedScaleFactors; @@ -2828,11 +2859,58 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (numberOfOperands == 3) { Value proposedScaleFactors = operands[2]; + + Value scaleIdentity = rewriter.create( + loc, rewriter.getF64FloatAttr(1.0)); + + // run-time scale factor check for dynamic sizes + for (auto &eachDim : TorchImageTensor::nonResizableDims) { + Value eachProposedScaleFactor = createScalarAs( + loc, eachDim, proposedScaleFactors, rewriter); + + Value eachScaleFactorIsIdentity = + rewriter.create( + loc, Bool, eachProposedScaleFactor, scaleIdentity); + + auto errorMessageForEachDim = + "Unsupported: non-trivial scale factor for dimension " + + std::to_string(eachDim); + + rewriter.create( + loc, eachScaleFactorIsIdentity, + rewriter.getStringAttr(errorMessageForEachDim)); + }; + supportedScaleFactors = createScalarSublist( loc, proposedScaleFactors, foremostSupportedDim, rewriter); supportedSizes = noneVal; } else if (numberOfOperands == 4) { Value proposedSizes = operands[3]; + + // run-time target size check for dynamic sizes + for (auto &eachDimAsInt : TorchImageTensor::nonResizableDims) { + Value eachDimAsValue = + rewriter.create(loc, eachDimAsInt); + + Value eachSizeOfInputTensor = rewriter.create( + loc, inputTensor, eachDimAsValue); + + Value eachProposedSize = createScalarAs( + loc, eachDimAsInt, proposedSizes, rewriter); + + Value eachProposedSizeIsTrivial = + rewriter.create(loc, Bool, eachProposedSize, + eachSizeOfInputTensor); + + auto errorMessageForEachDim = + "Unsupported: non-trivial resizing of dimension " + + std::to_string(eachDimAsInt); + + rewriter.create( + loc, eachProposedSizeIsTrivial, + rewriter.getStringAttr(errorMessageForEachDim)); + }; + supportedScaleFactors = noneVal; supportedSizes = createScalarSublist( loc, proposedSizes, foremostSupportedDim, rewriter);