diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index 3d97b695f1ab..af3635c7639a 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -336,7 +336,8 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter,
 
 LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
                                    Value src, Type destType, Value &result) {
-  Type srcElemTy = dyn_cast<TensorType>(src.getType()).getElementType();
+  TensorType srcType = dyn_cast<TensorType>(src.getType());
+  Type srcElemTy = srcType.getElementType();
   Type destElemTy = dyn_cast<TensorType>(destType).getElementType();
 
   // Temporarily disable checkValidityOfCast as it's currently strictly
@@ -381,6 +382,23 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
     result =
         rewriter.create<tosa::CastOp>(op->getLoc(), destType, equalToZero);
   } else {
+    if (llvm::isa<FloatType>(srcElemTy) && destElemTy.isInteger()) {
+      // For float->int conversion, tosa.cast performs round-to-nearest,
+      // whereas torch performs round-to-zero. Generate a round-to-zero
+      // conversion prior to the tosa.cast to match the expected torch
+      // behavior.
+      auto floor = rewriter.create<tosa::FloorOp>(op->getLoc(), srcType, src);
+      auto ceil = rewriter.create<tosa::CeilOp>(op->getLoc(), srcType, src);
+
+      auto zeroValue =
+          tosa::getConstTensor<float>(rewriter, op, 0, {}, srcElemTy).value();
+
+      auto boolType = srcType.clone(rewriter.getIntegerType(1));
+      auto isNegative = tosa::CreateOpAndInfer<tosa::GreaterOp>(
+          rewriter, op->getLoc(), boolType, zeroValue, src);
+      src = tosa::CreateOpAndInfer<tosa::SelectOp>(
+          rewriter, op->getLoc(), srcType, isNegative, ceil, floor);
+    }
     result = rewriter.create<tosa::CastOp>(op->getLoc(), destType, src);
   }
   return success();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index b53611ff1e79..740286af6f6a 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1720,6 +1720,8 @@
     "TriuIndicesNegativeOffsetModule_basic",
     "BmmFloat16Module_basic",
     "ElementwiseRreluWithNoiseTrainStaticModule_basic",
+    "LinspaceDtypeModule_basic",
+    "Aten_CastLongModule_basic",
     "Unfold_Module_Rank_4",
     "Unfold_Module_Rank_Zero_basic",
     "Unfold_Module_basic",
@@ -2627,6 +2629,7 @@
 }
 
 ONNX_XFAIL_SET = {
+    "ToDtypeIntFromFloatModule_basic",
     # This test is expected to time out
     "TimeOutModule_basic",
     # Failure - cast error
@@ -3333,6 +3336,7 @@
 }
 
 FX_IMPORTER_TOSA_XFAIL_SET = {
+    "ScatterAddDynamicModule_basic",
     "UniformModule_basic",
     "UniformStaticShapeModule_basic",
     "AtenFftRfft2DLastDim_basic",
@@ -3444,7 +3448,6 @@
     "AtenSubFloatModule_basic",
     "AtenTopKModule_basic",
     "AtenTopKSmallestModule_basic",
-    "Aten_CastLongModule_basic",
     "Aten_EmbeddingBagExample_basic",
     "AvgPool1dFloatModule_basic",
     "AvgPool1dIntModule_basic",
@@ -3501,7 +3504,6 @@
     "ConvolutionModule2DTransposeStridedStatic_basic",
     "ConvolutionModule2DTransposeStrided_basic",
     "ConvolutionModule2DTranspose_basic",
-    "CopyWithDifferentDTypesModule_basic",
     "CumsumInputDtypeInt32Module_basic",
     "CumsumModule_basic",
     "CumsumStaticModule_basic",
@@ -3544,7 +3546,6 @@
     "ElementwiseQuantizePerTensorUIntModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
-    "ElementwiseToDtypeF32ToI64Module_basic",
     "ElementwiseToDtypeI64ToUI8Module_basic",
     "EqIntModule_basic",
     "FloatImplicitModule_basic",
@@ -3577,8 +3578,6 @@
     "IndexPutImpl2DNoneIndexStaticModule_basic",
     "IndexPutImpl3DFloatAccumulateModule_basic",
     "IndexPutImplIndexWithNoneModule_basic",
-    "InterpolateDynamicModule_sizes_bilinear",
-    "InterpolateDynamicModule_scales_recompute_bilinear",
     "IntFloatModule_basic",
     "IntImplicitModule_basic",
     "IsFloatingPointFloat_True",
@@ -3586,7 +3585,6 @@
     "LenStrModule_basic",
     "LinalgNormKeepDimComplexModule_basic",
     "LinalgVectorNormComplexModule_basic",
-    "LinspaceDtypeModule_basic",
     "LinspaceEmptyModule_basic",
     "MaskedScatterStaticBasic_basic",
     "MaxPool1dCeilModeTrueModule_basic",
@@ -3649,7 +3647,6 @@
     "PrimMaxIntModule_basic",
     "PrimMinIntDynamicModule_basic",
     "PrimMinIntModule_basic",
-    "PrimsConvertElementTypeModule_basic",
     "PrimsSqueezeEmptyDimensionsModule_basic",
     "PrimsSqueezeModule_basic",
     "PrimsViewOfModule_basic",
@@ -3734,8 +3731,6 @@
     "TensorToInt_basic",
     "TestMultipleTensorAndPrimitiveTypesReturn_basic",
     "ThresholdBackward2dMixedModule_basic",
-    "ToCopyWithDTypeFalsePinMemoryModule_basic",
-    "ToCopyWithDTypeModule_basic",
     "TorchPrimLoopForLikeModule_basic",
     "TorchPrimLoopWhileLikeModule_basic",
     "TraceModule_empty",
@@ -4002,7 +3997,6 @@
     "AtenTriuModule_basic",
     "AtenTriuWithNegDiagonalModule_basic",
     "AtenTriuWithPosDiagonalModule_basic",
-    "Aten_CastLongModule_basic",
     "Aten_EmbeddingBagExample_basic",
     "AvgPool1dFloatModule_basic",
     "AvgPool1dIntModule_basic",
@@ -4717,6 +4711,8 @@
     "ToDtypeLayoutCPUModule_basic",
     "ToDtypeLayoutNoneModule_basic",
     "ToDtypeLayoutStridedModule_basic",
+    "ToDtypeIntFromFloatModule_basic",
+    "ToDtypeFloatFromIntModule_basic",
     "TorchPrimLoopForLikeModule_basic",
     "TorchPrimLoopWhileLikeModule_basic",
     "TraceModule_basic",
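The C++ change above hinges on a rounding-mode mismatch: TOSA's float-to-int cast rounds to nearest, while torch's aten.to truncates toward zero. Below is a minimal eager-mode sketch in plain PyTorch (illustration only, not part of the patch) showing the discrepancy and checking that the patch's floor/ceil/select sequence reproduces truncation:

import torch

# A round-to-nearest cast and torch's truncation disagree on any value
# whose fractional part would round it away from zero.
x = torch.tensor([-1.7, -0.3, 0.3, 3.7])

nearest = torch.round(x)  # tensor([-2., -0., 0., 4.]) -- what a bare round-to-nearest cast yields
trunc = torch.trunc(x)    # tensor([-1., -0., 0., 3.]) -- torch's aten.to semantics

# Equivalent of the emitted tosa.greater + tosa.select over floor/ceil:
# take ceil for negative inputs, floor otherwise.
emulated = torch.where(torch.tensor(0.0) > x, torch.ceil(x), torch.floor(x))
assert torch.equal(emulated, trunc)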
"InterpolateDynamicModule_scales_recompute_bilinear", "IntFloatModule_basic", "IntImplicitModule_basic", "IsFloatingPointFloat_True", @@ -3586,7 +3585,6 @@ "LenStrModule_basic", "LinalgNormKeepDimComplexModule_basic", "LinalgVectorNormComplexModule_basic", - "LinspaceDtypeModule_basic", "LinspaceEmptyModule_basic", "MaskedScatterStaticBasic_basic", "MaxPool1dCeilModeTrueModule_basic", @@ -3649,7 +3647,6 @@ "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", - "PrimsConvertElementTypeModule_basic", "PrimsSqueezeEmptyDimensionsModule_basic", "PrimsSqueezeModule_basic", "PrimsViewOfModule_basic", @@ -3734,8 +3731,6 @@ "TensorToInt_basic", "TestMultipleTensorAndPrimitiveTypesReturn_basic", "ThresholdBackward2dMixedModule_basic", - "ToCopyWithDTypeFalsePinMemoryModule_basic", - "ToCopyWithDTypeModule_basic", "TorchPrimLoopForLikeModule_basic", "TorchPrimLoopWhileLikeModule_basic", "TraceModule_empty", @@ -4002,7 +3997,6 @@ "AtenTriuModule_basic", "AtenTriuWithNegDiagonalModule_basic", "AtenTriuWithPosDiagonalModule_basic", - "Aten_CastLongModule_basic", "Aten_EmbeddingBagExample_basic", "AvgPool1dFloatModule_basic", "AvgPool1dIntModule_basic", @@ -4717,6 +4711,8 @@ "ToDtypeLayoutCPUModule_basic", "ToDtypeLayoutNoneModule_basic", "ToDtypeLayoutStridedModule_basic", + "ToDtypeIntFromFloatModule_basic", + "ToDtypeFloatFromIntModule_basic", "TorchPrimLoopForLikeModule_basic", "TorchPrimLoopWhileLikeModule_basic", "TraceModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py index df78262fff96..f8deda462905 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py @@ -255,6 +255,45 @@ def ToDtypeBoolLayoutNoneStaticModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 5)) +class ToDtypeFloatFromIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.int64, True)]) + def forward(self, x): + return torch.ops.aten.to( + x, + dtype=torch.float32, + ) + + +@register_test_case(module_factory=lambda: ToDtypeFloatFromIntModule()) +def ToDtypeFloatFromIntModule_basic(module, tu: TestUtils): + input = torch.randint(low=-5, high=5, size=(2, 2)).to(torch.int64) + module.forward(input) + + +class ToDtypeIntFromFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.float64, True)]) + def forward(self, x): + return torch.ops.aten.to( + x, + dtype=torch.int64, + ) + + +@register_test_case(module_factory=lambda: ToDtypeIntFromFloatModule()) +def ToDtypeIntFromFloatModule_basic(module, tu: TestUtils): + input = tu.rand(2, 2, low=-5, high=5) + input[1][1] = tu.randint(1, 1) + 0.7 + module.forward(input) + + class TypeAsSameModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 2d9d95082a89..b9fa41379195 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1022,6 +1022,29 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vten return %0 : !torch.vtensor<[1,128],si64> } +// ----- +// CHECK-LABEL: func.func @torch.aten.to.dtype$floatToInt( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> { +// CHECK: 
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 2d9d95082a89..b9fa41379195 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -1022,6 +1022,29 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vtensor<[1,128],si64> {
   return %0 : !torch.vtensor<[1,128],si64>
 }
 
+// -----
+// CHECK-LABEL: func.func @torch.aten.to.dtype$floatToInt(
+// CHECK-SAME:  %[[ARG0:.*]]: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> {
+// CHECK:       %[[TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32>
+// CHECK:       %[[INT4:.*]] = torch.constant.int 4
+// CHECK:       %[[FALSE:.*]] = torch.constant.bool false
+// CHECK:       %[[NONE:.*]] = torch.constant.none
+// CHECK:       %[[FLOOR:.*]] = tosa.floor %[[TENSOR]] : (tensor<3x5xf32>) -> tensor<3x5xf32>
+// CHECK:       %[[CEIL:.*]] = tosa.ceil %[[TENSOR]] : (tensor<3x5xf32>) -> tensor<3x5xf32>
+// CHECK:       %[[F0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK:       %[[IS_NEG:.*]] = tosa.greater %[[F0]], %[[TENSOR]] : (tensor<f32>, tensor<3x5xf32>) -> tensor<3x5xi1>
+// CHECK:       %[[SELECT:.*]] = tosa.select %[[IS_NEG]], %[[CEIL]], %[[FLOOR]] : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32>
+// CHECK:       %[[CAST:.*]] = tosa.cast %[[SELECT]] : (tensor<3x5xf32>) -> tensor<3x5xi64>
+// CHECK:       %[[RES:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<3x5xi64> -> !torch.vtensor<[3,5],si64>
+// CHECK:       return %[[RES]] : !torch.vtensor<[3,5],si64>
+func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> {
+  %int4 = torch.constant.int 4
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.vtensor<[3,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],si64>
+  return %0 : !torch.vtensor<[3,5],si64>
+}
+
 // -----
 // CHECK-LABEL: func.func @torch.aten.gather(
 // CHECK-SAME:  %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>,
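To eyeball the lowering end to end rather than via FileCheck, something like the following should print the new floor/ceil/select chain ahead of tosa.cast. This is a sketch assuming the torch_mlir.fx.export_and_import entry point and its output_type keyword as found in recent torch-mlir builds; adjust if your build's API differs:

import torch
from torch_mlir import fx  # entry point assumed; see note above


class CastToInt(torch.nn.Module):
    def forward(self, x):
        return x.to(torch.int64)


# With this patch applied, the printed module should contain tosa.floor,
# tosa.ceil, tosa.greater, and tosa.select feeding the final tosa.cast,
# mirroring the CHECK lines in the test above.
module = fx.export_and_import(CastToInt(), torch.randn(3, 5), output_type="tosa")
print(module)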