Skip to content

Commit

Permalink
fix(ONNX): avoids resizing unsupported dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
bjacobgordon committed Jan 28, 2025
1 parent d3146c5 commit fcf8bf8
Showing 1 changed file with 83 additions and 5 deletions.
88 changes: 83 additions & 5 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,16 @@ Value createScalarSublist(
} // namespace

namespace TorchImageTensor {
// int64_t const /* */ batchDim = 0;
// int64_t const /**/ channelDim = 1;
int64_t const /* */ batchDim = 0;
int64_t const /**/ channelDim = 1;
int64_t const /* */ heightDim = 2;
// int64_t const /* */ widthDim = 3;
// int64_t const /* */ depthDim = 4;

SmallVector<int64_t> nonResizableDims{
batchDim,
channelDim,
};
} // namespace TorchImageTensor

void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
Expand Down Expand Up @@ -2728,6 +2733,33 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
"round_prefer_floor") ||
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
return failure();

Value inputTensor = operands[0];
auto InputTensor = cast<Torch::BaseTensorType>(inputTensor.getType());
auto sizesOfInputTensor = InputTensor.getSizes();
ArrayRef<int64_t> sizesOfOutputTensor = OutputTensor.getSizes();

auto unknownSize = Torch::kUnknownSize;

// Compile-time check for dimensions of static size
for (auto &eachDim : TorchImageTensor::nonResizableDims) {
auto eachSizeOfInputTensor = sizesOfInputTensor[eachDim];
auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDim];

if (eachSizeOfInputTensor == unknownSize ||
eachSizeOfOutputTensor == unknownSize)
continue;
if (eachSizeOfInputTensor == eachSizeOfOutputTensor)
continue;

auto resizingIntentErrorMessage =
"unsupported: non-trivial intent to resize dimension: " +
std::to_string(eachDim);

return rewriter.notifyMatchFailure(binder.op,
resizingIntentErrorMessage);
};

if (antialias != 0) {
return rewriter.notifyMatchFailure(
binder.op,
Expand Down Expand Up @@ -2775,9 +2807,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
modeStrValue = rewriter.create<Torch::ConstantStrOp>(loc, modeStr);
}

Value inputTensor = operands[0];
auto InputTensor = cast<Torch::BaseTensorType>(inputTensor.getType());
auto sizesOfInputTensor = InputTensor.getSizes();
auto rankOfInputTensor = sizesOfInputTensor.size();

// supported modes:
Expand Down Expand Up @@ -2819,6 +2848,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(

auto numberOfOperands = operands.size();

Type Bool = rewriter.getType<Torch::BoolType>();

auto foremostSupportedDim = TorchImageTensor::heightDim;

Value supportedScaleFactors;
Expand All @@ -2828,11 +2859,58 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(

if (numberOfOperands == 3) {
Value proposedScaleFactors = operands[2];

Value scaleIdentity = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(1.0));

// run-time scale factor check for dynamic sizes
for (auto &eachDim : TorchImageTensor::nonResizableDims) {
Value eachProposedScaleFactor = createScalarAs<Torch::FloatType>(
loc, eachDim, proposedScaleFactors, rewriter);

Value eachScaleFactorIsIdentity =
rewriter.create<Torch::AtenEqFloatOp>(
loc, Bool, eachProposedScaleFactor, scaleIdentity);

auto errorMessageForEachDim =
"Unsupported: non-trivial scale factor for dimension " +
std::to_string(eachDim);

rewriter.create<Torch::RuntimeAssertOp>(
loc, eachScaleFactorIsIdentity,
rewriter.getStringAttr(errorMessageForEachDim));
};

supportedScaleFactors = createScalarSublist<Torch::FloatType>(
loc, proposedScaleFactors, foremostSupportedDim, rewriter);
supportedSizes = noneVal;
} else if (numberOfOperands == 4) {
Value proposedSizes = operands[3];

// run-time target size check for dynamic sizes
for (auto &eachDimAsInt : TorchImageTensor::nonResizableDims) {
Value eachDimAsValue =
rewriter.create<Torch::ConstantIntOp>(loc, eachDimAsInt);

Value eachSizeOfInputTensor = rewriter.create<Torch::AtenSizeIntOp>(
loc, inputTensor, eachDimAsValue);

Value eachProposedSize = createScalarAs<Torch::IntType>(
loc, eachDimAsInt, proposedSizes, rewriter);

Value eachProposedSizeIsTrivial =
rewriter.create<Torch::AtenEqIntOp>(loc, Bool, eachProposedSize,
eachSizeOfInputTensor);

auto errorMessageForEachDim =
"Unsupported: non-trivial resizing of dimension " +
std::to_string(eachDimAsInt);

rewriter.create<Torch::RuntimeAssertOp>(
loc, eachProposedSizeIsTrivial,
rewriter.getStringAttr(errorMessageForEachDim));
};

supportedScaleFactors = noneVal;
supportedSizes = createScalarSublist<Torch::IntType>(
loc, proposedSizes, foremostSupportedDim, rewriter);
Expand Down

0 comments on commit fcf8bf8

Please sign in to comment.