
Commit 481da8d

[TOSA] : Fix float to integer cast for torch.ops.aten.to lowering. (#3946)
The behavior of the float -> integer cast in PyTorch (though I haven't found the actual code implementing the cast) appears, based on the results PyTorch produces, to be:

1. round the float toward zero (similar to `arith.fptosi/ui`),
2. then perform the conversion.

Currently we only emit `tosa.cast` for this operation, but per the spec (https://www.mlplatform.org/tosa/tosa_spec.html#_cast) the rounding performed for float -> integer is round to nearest integer, not toward zero. Hence, the current TOSA lowering for `torch.ops.aten.to` produces incorrect answers.
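For illustration (example values are mine, not from the commit), a minimal PyTorch snippet showing the truncation behavior described above:

```python
import torch

x = torch.tensor([-1.7, -0.3, 0.3, 1.7])
print(x.to(torch.int64))  # tensor([-1, 0, 0, 1]): rounded toward zero
# Round-to-nearest, which tosa.cast specifies, would instead give [-2, 0, 0, 2].
```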
Parent: 9dd94fb

File tree: 4 files changed (+87 -11)

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp (+19 -1)
@@ -336,7 +336,8 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
 LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
                                    Value src, Type destType, Value &result) {

-  Type srcElemTy = dyn_cast<TensorType>(src.getType()).getElementType();
+  TensorType srcType = dyn_cast<TensorType>(src.getType());
+  Type srcElemTy = srcType.getElementType();
   Type destElemTy = dyn_cast<TensorType>(destType).getElementType();

   // Temporarily disable checkValidityOfCast as it's currently strictly
@@ -381,6 +382,23 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
     result = rewriter.create<tosa::LogicalNotOp>(op->getLoc(), destType,
                                                  equalToZero);
   } else {
+    if (llvm::isa<FloatType>(srcElemTy) && destElemTy.isInteger()) {
+      // For float -> int conversion, tosa.cast performs round-to-nearest,
+      // but torch performs round-toward-zero instead.
+      // Generate a round-toward-zero conversion prior to tosa.cast to
+      // match the expected torch behavior.
+      auto floor = rewriter.create<tosa::FloorOp>(op->getLoc(), srcType, src);
+      auto ceil = rewriter.create<tosa::CeilOp>(op->getLoc(), srcType, src);
+
+      auto zeroValue =
+          tosa::getConstTensor<float>(rewriter, op, 0, {}, srcElemTy).value();
+
+      auto boolType = srcType.clone(rewriter.getIntegerType(1));
+      auto isNegative = tosa::CreateOpAndInfer<tosa::GreaterOp>(
+          rewriter, op->getLoc(), boolType, zeroValue, src);
+      src = tosa::CreateOpAndInfer<tosa::SelectOp>(
+          rewriter, op->getLoc(), srcType, isNegative, ceil, floor);
+    }
     result = rewriter.create<tosa::CastOp>(op->getLoc(), destType, src);
   }
   return success();
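The floor/ceil/select sequence above is a branch-free round-toward-zero. A rough Python model of the emitted ops (a sketch with my own naming; `torch.where` stands in for `tosa.select`):

```python
import torch

def round_toward_zero(x: torch.Tensor) -> torch.Tensor:
    # Mirror the emitted ops: tosa.greater(0, src) picks ceil for negative
    # elements and floor for the rest, i.e. truncation toward zero.
    return torch.where(x < 0, x.ceil(), x.floor())

x = torch.tensor([-1.7, -0.3, 0.3, 1.7])
# Matches PyTorch's own float -> int cast semantics.
assert torch.equal(round_toward_zero(x).to(torch.int64), x.to(torch.int64))
```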

projects/pt1/e2e_testing/xfail_sets.py (+6 -10)
@@ -1720,6 +1720,8 @@
     "TriuIndicesNegativeOffsetModule_basic",
     "BmmFloat16Module_basic",
     "ElementwiseRreluWithNoiseTrainStaticModule_basic",
+    "LinspaceDtypeModule_basic",
+    "Aten_CastLongModule_basic",
     "Unfold_Module_Rank_4",
     "Unfold_Module_Rank_Zero_basic",
     "Unfold_Module_basic",
@@ -2627,6 +2629,7 @@
 }

 ONNX_XFAIL_SET = {
+    "ToDtypeIntFromFloatModule_basic",
     # This test is expected to time out
     "TimeOutModule_basic",
     # Failure - cast error
@@ -3333,6 +3336,7 @@
 }

 FX_IMPORTER_TOSA_XFAIL_SET = {
+    "ScatterAddDynamicModule_basic",
     "UniformModule_basic",
     "UniformStaticShapeModule_basic",
     "AtenFftRfft2DLastDim_basic",
@@ -3444,7 +3448,6 @@
     "AtenSubFloatModule_basic",
     "AtenTopKModule_basic",
     "AtenTopKSmallestModule_basic",
-    "Aten_CastLongModule_basic",
     "Aten_EmbeddingBagExample_basic",
     "AvgPool1dFloatModule_basic",
     "AvgPool1dIntModule_basic",
@@ -3501,7 +3504,6 @@
     "ConvolutionModule2DTransposeStridedStatic_basic",
     "ConvolutionModule2DTransposeStrided_basic",
     "ConvolutionModule2DTranspose_basic",
-    "CopyWithDifferentDTypesModule_basic",
     "CumsumInputDtypeInt32Module_basic",
     "CumsumModule_basic",
     "CumsumStaticModule_basic",
@@ -3544,7 +3546,6 @@
     "ElementwiseQuantizePerTensorUIntModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
-    "ElementwiseToDtypeF32ToI64Module_basic",
     "ElementwiseToDtypeI64ToUI8Module_basic",
     "EqIntModule_basic",
     "FloatImplicitModule_basic",
@@ -3577,16 +3578,13 @@
     "IndexPutImpl2DNoneIndexStaticModule_basic",
     "IndexPutImpl3DFloatAccumulateModule_basic",
     "IndexPutImplIndexWithNoneModule_basic",
-    "InterpolateDynamicModule_sizes_bilinear",
-    "InterpolateDynamicModule_scales_recompute_bilinear",
     "IntFloatModule_basic",
     "IntImplicitModule_basic",
     "IsFloatingPointFloat_True",
     "IsFloatingPointInt_False",
     "LenStrModule_basic",
     "LinalgNormKeepDimComplexModule_basic",
     "LinalgVectorNormComplexModule_basic",
-    "LinspaceDtypeModule_basic",
     "LinspaceEmptyModule_basic",
     "MaskedScatterStaticBasic_basic",
     "MaxPool1dCeilModeTrueModule_basic",
@@ -3649,7 +3647,6 @@
     "PrimMaxIntModule_basic",
     "PrimMinIntDynamicModule_basic",
     "PrimMinIntModule_basic",
-    "PrimsConvertElementTypeModule_basic",
     "PrimsSqueezeEmptyDimensionsModule_basic",
     "PrimsSqueezeModule_basic",
     "PrimsViewOfModule_basic",
@@ -3734,8 +3731,6 @@
     "TensorToInt_basic",
     "TestMultipleTensorAndPrimitiveTypesReturn_basic",
     "ThresholdBackward2dMixedModule_basic",
-    "ToCopyWithDTypeFalsePinMemoryModule_basic",
-    "ToCopyWithDTypeModule_basic",
     "TorchPrimLoopForLikeModule_basic",
     "TorchPrimLoopWhileLikeModule_basic",
     "TraceModule_empty",
@@ -4002,7 +3997,6 @@
     "AtenTriuModule_basic",
     "AtenTriuWithNegDiagonalModule_basic",
     "AtenTriuWithPosDiagonalModule_basic",
-    "Aten_CastLongModule_basic",
     "Aten_EmbeddingBagExample_basic",
     "AvgPool1dFloatModule_basic",
     "AvgPool1dIntModule_basic",
@@ -4717,6 +4711,8 @@
     "ToDtypeLayoutCPUModule_basic",
     "ToDtypeLayoutNoneModule_basic",
     "ToDtypeLayoutStridedModule_basic",
+    "ToDtypeIntFromFloatModule_basic",
+    "ToDtypeFloatFromIntModule_basic",
     "TorchPrimLoopForLikeModule_basic",
     "TorchPrimLoopWhileLikeModule_basic",
     "TraceModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py (+39)
@@ -255,6 +255,45 @@ def ToDtypeBoolLayoutNoneStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(3, 5))


+class ToDtypeFloatFromIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.to(
+            x,
+            dtype=torch.float32,
+        )
+
+
+@register_test_case(module_factory=lambda: ToDtypeFloatFromIntModule())
+def ToDtypeFloatFromIntModule_basic(module, tu: TestUtils):
+    input = torch.randint(low=-5, high=5, size=(2, 2)).to(torch.int64)
+    module.forward(input)
+
+
+class ToDtypeIntFromFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1], torch.float64, True)])
+    def forward(self, x):
+        return torch.ops.aten.to(
+            x,
+            dtype=torch.int64,
+        )
+
+
+@register_test_case(module_factory=lambda: ToDtypeIntFromFloatModule())
+def ToDtypeIntFromFloatModule_basic(module, tu: TestUtils):
+    input = tu.rand(2, 2, low=-5, high=5)
+    input[1][1] = tu.randint(1, 1) + 0.7
+    module.forward(input)
+
+
 class TypeAsSameModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
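The `tu.randint(1, 1) + 0.7` line above plants at least one element whose fractional part separates the two rounding modes, so a round-to-nearest lowering would fail the comparison against eager PyTorch. A quick sanity check of the property the test relies on (my snippet, not part of the suite):

```python
import torch

x = torch.tensor([2.7], dtype=torch.float64)
print(x.to(torch.int64))          # tensor([2]): truncation, PyTorch behavior
print(x.round().to(torch.int64))  # tensor([3]): what round-to-nearest yields
```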

test/Conversion/TorchToTosa/basic.mlir (+23)
@@ -1022,6 +1022,29 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vtensor<[1,128],si64> {
   return %0 : !torch.vtensor<[1,128],si64>
 }

+// -----
+// CHECK-LABEL: func.func @torch.aten.to.dtype$floatToInt(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> {
+// CHECK:         %[[TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32>
+// CHECK:         %[[INT4:.*]] = torch.constant.int 4
+// CHECK:         %[[FALSE:.*]] = torch.constant.bool false
+// CHECK:         %[[NONE:.*]] = torch.constant.none
+// CHECK:         %[[FLOOR:.*]] = tosa.floor %[[TENSOR]] : (tensor<3x5xf32>) -> tensor<3x5xf32>
+// CHECK:         %[[CEIL:.*]] = tosa.ceil %[[TENSOR]] : (tensor<3x5xf32>) -> tensor<3x5xf32>
+// CHECK:         %[[F0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK:         %[[IS_NEG:.*]] = tosa.greater %[[F0]], %[[TENSOR]] : (tensor<f32>, tensor<3x5xf32>) -> tensor<3x5xi1>
+// CHECK:         %[[SELECT:.*]] = tosa.select %[[IS_NEG]], %[[CEIL]], %[[FLOOR]] : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32>
+// CHECK:         %[[CAST:.*]] = tosa.cast %[[SELECT]] : (tensor<3x5xf32>) -> tensor<3x5xi64>
+// CHECK:         %[[RES:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<3x5xi64> -> !torch.vtensor<[3,5],si64>
+// CHECK:         return %[[RES]] : !torch.vtensor<[3,5],si64>
+func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> {
+  %int4 = torch.constant.int 4
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.vtensor<[3,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],si64>
+  return %0 : !torch.vtensor<[3,5],si64>
+}
+
 // -----
 // CHECK-LABEL: func.func @torch.aten.gather(
 // CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>,
