Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix(ONNX): avoids resizing conventionally fixed dimensions
Browse files Browse the repository at this point in the history
bjacobgordon committed Jan 9, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 76368bd commit 7aec80b
Showing 4 changed files with 66 additions and 15 deletions.
4 changes: 4 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/TorchTypes.h
Original file line number Diff line number Diff line change
@@ -84,6 +84,10 @@ class BaseTensorType : public Type {
/// Enable isa/dyn_cast for BaseTensorType.
static bool classof(Type type);

/// The element-wise comparison of each dimension/size in `that` tensor
std::vector<std::optional<bool>>
shapeComparisonAgainst(BaseTensorType that) const;

/// Return true if this type has the same sizes and dtype as the other.
bool hasSameSizesAndDtype(BaseTensorType other) const;

29 changes: 26 additions & 3 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
@@ -2717,6 +2717,32 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
"round_prefer_floor") ||
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
return failure();

Value inputTensor = operands[0];
Torch::ValueTensorType inputTensor_blueprint =
cast<Torch::ValueTensorType>(inputTensor.getType());

std::vector<std::optional<bool>> shapeComparison =
inputTensor_blueprint.shapeComparisonAgainst(
outputTensor_blueprint);

// Comparisons of the dimensions assumed to carry the batch and channel
auto shapeComparisonForFixedDimensions =
ArrayRef(shapeComparison).take_front(2);

for (auto eachDimensionComparison : shapeComparisonForFixedDimensions) {
if (eachDimensionComparison == std::nullopt) {
return rewriter.notifyMatchFailure(
binder.op, "Sizes for batch and channel dimensions must be "
"statically defined");
}
if (eachDimensionComparison == false) {
return rewriter.notifyMatchFailure(
binder.op,
"Unexpected intent to resize the batch/channel dimensions");
}
};

if (antialias != 0) {
return rewriter.notifyMatchFailure(
binder.op,
@@ -2749,9 +2775,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, "unimplemented: cubic coeff must be -0.75");
}

Value inputTensor = operands[0];
Torch::ValueTensorType inputTensor_blueprint =
cast<Torch::ValueTensorType>(inputTensor.getType());
ArrayRef<int64_t> inputTensor_dimensions =
inputTensor_blueprint.getSizes();
unsigned rank = inputTensor_dimensions.size();
24 changes: 24 additions & 0 deletions lib/Dialect/Torch/IR/TorchTypes.cpp
Original file line number Diff line number Diff line change
@@ -217,6 +217,30 @@ static bool isValidTorchDtype(Type dtype) {
return false;
}

std::vector<std::optional<bool>>
BaseTensorType::shapeComparisonAgainst(BaseTensorType that) const {
auto this_dimensions = /**/ getSizes();
auto that_dimensions = that.getSizes();

auto this_rank = this_dimensions.size();
auto that_rank = that_dimensions.size();

assert((this_rank == that_rank) && "Ranks must match to compare dimensions");

std::vector<std::optional<bool>> runningComparison = {};
auto dimensionPairs = llvm::zip(this_dimensions, that_dimensions);

for (auto [eachLHDimension, eachRHDimension] : dimensionPairs) {
if (eachLHDimension == kUnknownSize || eachRHDimension == kUnknownSize) {
runningComparison.push_back(std::nullopt);
} else {
runningComparison.push_back(eachLHDimension == eachRHDimension);
}
}

return runningComparison;
}

bool BaseTensorType::hasSameSizesAndDtype(BaseTensorType other) const {
return getOptionalSizes() == other.getOptionalSizes() &&
getOptionalDtype() == other.getOptionalDtype();
24 changes: 12 additions & 12 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
Original file line number Diff line number Diff line change
@@ -2254,35 +2254,35 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
// -----

// CHECK-LABEL: func.func @test_resize_sizes_nearest
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,1,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32>
return %0 : !torch.vtensor<[1,1,?,?],f32>
}

// -----

// CHECK-LABEL: func.func @test_resize_sizes_nearest
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,1,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
torch.onnx.coordinate_transformation_mode = "half_pixel",
torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32>
return %0 : !torch.vtensor<[1,1,?,?],f32>
}

// -----

// CHECK-LABEL: func.func @test_resize_sizes_linear
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],
f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,1,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32>
return %0 : !torch.vtensor<[1,1,?,?],f32>
}

// -----

0 comments on commit 7aec80b

Please sign in to comment.