diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h index 163ed63008789..9235189fb0314 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h @@ -84,6 +84,10 @@ class BaseTensorType : public Type { /// Enable isa/dyn_cast for BaseTensorType. static bool classof(Type type); + /// The element-wise comparison of each dimension/size in `that` tensor + std::vector<std::optional<bool>> + shapeComparisonAgainst(BaseTensorType that) const; + /// Return true if this type has the same sizes and dtype as the other. bool hasSameSizesAndDtype(BaseTensorType other) const; diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index d82d9251029bc..8de0646b08482 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2717,6 +2717,32 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "round_prefer_floor") || binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75)) return failure(); + + Value inputTensor = operands[0]; + Torch::ValueTensorType inputTensor_blueprint = + cast<Torch::ValueTensorType>(inputTensor.getType()); + + std::vector<std::optional<bool>> shapeComparison = + inputTensor_blueprint.shapeComparisonAgainst( + outputTensor_blueprint); + + // Comparisons of the dimensions assumed to carry the batch and channel + auto shapeComparisonForFixedDimensions = + ArrayRef(shapeComparison).take_front(2); + + for (auto eachDimensionComparison : shapeComparisonForFixedDimensions) { + if (eachDimensionComparison == std::nullopt) { + return rewriter.notifyMatchFailure( + binder.op, "Sizes for batch and channel dimensions must be " "statically defined"); + } + if (eachDimensionComparison == false) { + return rewriter.notifyMatchFailure( + binder.op, + "Unexpected intent to resize the batch/channel dimensions"); + } + }; + if (antialias != 0) { return 
rewriter.notifyMatchFailure( binder.op, @@ -2749,9 +2775,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "unimplemented: cubic coeff must be -0.75"); } - Value inputTensor = operands[0]; - Torch::ValueTensorType inputTensor_blueprint = - cast<Torch::ValueTensorType>(inputTensor.getType()); ArrayRef inputTensor_dimensions = inputTensor_blueprint.getSizes(); unsigned rank = inputTensor_dimensions.size(); diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index c46865ee5fed5..dd55a29b73fd2 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -217,6 +217,30 @@ static bool isValidTorchDtype(Type dtype) { return false; } +std::vector<std::optional<bool>> +BaseTensorType::shapeComparisonAgainst(BaseTensorType that) const { + auto this_dimensions = /**/ getSizes(); + auto that_dimensions = that.getSizes(); + + auto this_rank = this_dimensions.size(); + auto that_rank = that_dimensions.size(); + + assert((this_rank == that_rank) && "Ranks must match to compare dimensions"); + + std::vector<std::optional<bool>> runningComparison = {}; + auto dimensionPairs = llvm::zip(this_dimensions, that_dimensions); + + for (auto [eachLHDimension, eachRHDimension] : dimensionPairs) { + if (eachLHDimension == kUnknownSize || eachRHDimension == kUnknownSize) { + runningComparison.push_back(std::nullopt); + } else { + runningComparison.push_back(eachLHDimension == eachRHDimension); + } + } + + return runningComparison; +} + bool BaseTensorType::hasSameSizesAndDtype(BaseTensorType other) const { return getOptionalSizes() == other.getOptionalSizes() && getOptionalDtype() == other.getOptionalDtype(); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 16c86218dbc8b..ac1e82f0c6e6f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2254,35 +2254,35 @@ func.func @test_sce_mean_3d_log_prob(%arg0: 
!torch.vtensor<[3,5,2],f32>, %arg1: // ----- // CHECK-LABEL: func.func @test_resize_sizes_nearest - func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none - // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> - %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> - return %0 : !torch.vtensor<[?,?,?,?],f32> + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,1,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> 
!torch.vtensor<[1,1,?,?],f32> + return %0 : !torch.vtensor<[1,1,?,?],f32> } // ----- // CHECK-LABEL: func.func @test_resize_sizes_nearest -func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none // CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor" - // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,1,?,?],f32> %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { torch.onnx.coordinate_transformation_mode = "half_pixel", - torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> - return %0 : !torch.vtensor<[?,?,?,?],f32> + torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32> + return %0 : !torch.vtensor<[1,1,?,?],f32> } // ----- // CHECK-LABEL: func.func @test_resize_sizes_linear - func.func 
@test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?], + func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?], f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none - // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> - %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> - return %0 : !torch.vtensor<[?,?,?,?],f32> + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,1,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32> + return %0 : !torch.vtensor<[1,1,?,?],f32> } // -----