diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
index 07e4b23a167d..cede292ec1c3 100644
--- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -417,6 +417,21 @@ class ConvertAtenEmbeddingBagPaddingIdxOp
 };
 } // namespace
 
+static Value wrapIndicesAroundMax(OpBuilder &b, Location loc, Value index,
+                                  Value input, int64_t dim) {
+  // performs the operation : index = index % maxIndex to wrap index around
+  // maxIndex
+  Value maxIndexValue = getDimOp(b, loc, input, dim);
+  maxIndexValue =
+      b.createOrFold<arith::IndexCastOp>(loc, index.getType(), maxIndexValue);
+  Value isBeyondMaxIndices = b.createOrFold<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::sge, index, maxIndexValue);
+  Value wrappedIndices =
+      b.createOrFold<arith::RemSIOp>(loc, index, maxIndexValue);
+  return b.createOrFold<arith::SelectOp>(loc, isBeyondMaxIndices,
+                                         wrappedIndices, index);
+}
+
 namespace {
 // Let's say we have an input tensor: initialized with some random values of
 // size [4, 5, 6]. An index tensor (always 1-d): [0, 2] of size [2], and an
@@ -478,7 +493,6 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
 
     auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr},
                                                      rewriter.getContext());
-
     Value finalRes =
         rewriter
             .create<linalg::GenericOp>(
@@ -486,8 +500,10 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
                 /*indexingMaps=*/indexingMaps,
                 /*iteratorTypes=*/iteratorTypes,
                 [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value index = rewriter.create<arith::IndexCastOp>(
-                      loc, rewriter.getIndexType(), args[0]);
+                  Value index =
+                      wrapIndicesAroundMax(b, loc, args[0], input, dimInt);
+                  index = rewriter.create<arith::IndexCastOp>(
+                      loc, rewriter.getIndexType(), index);
                   SmallVector<Value> indexTarget;
                   for (unsigned i = 0; i < inputRank; i++)
                     indexTarget.push_back(b.create<linalg::IndexOp>(loc, i));
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 1c2f7d6f2a11..c512f480bd70 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -4006,6 +4006,41 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+Value wrapIndicesAroundMax(Value index, int maxIndex, Operation *op,
+                           ConversionPatternRewriter &rewriter) {
+  // performs the operation : index = index % maxIndex to wrap index around
+  // maxIndex
+
+  auto maxIndexValue =
+      tosa::getConstTensor<int32_t>(rewriter, op, maxIndex, {}).value();
+  auto maxIndexValueMinusOne =
+      tosa::getConstTensor<int32_t>(rewriter, op, maxIndex - 1, {}).value();
+
+  auto indexType = dyn_cast<RankedTensorType>(index.getType());
+  auto boolType = indexType.clone(rewriter.getIntegerType(1));
+
+  auto isBeyondMaxIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
+      rewriter, op->getLoc(), boolType, index, maxIndexValueMinusOne);
+  auto wrappedBeyondMaxIndicesQuotient =
+      tosa::CreateOpAndInfer<tosa::IntDivOp>(rewriter, op->getLoc(), indexType,
+                                             index, maxIndexValue)
+          .getResult();
+  auto wrappedBeyondMaxIndicesQuotientTimesIndices =
+      tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), indexType,
+                                          wrappedBeyondMaxIndicesQuotient,
+                                          maxIndexValue, /*shift=*/0)
+          .getResult();
+  auto wrappedBeyondMaxIndices =
+      tosa::CreateOpAndInfer<tosa::SubOp>(
+          rewriter, op->getLoc(), indexType, index,
+          wrappedBeyondMaxIndicesQuotientTimesIndices)
+          .getResult();
+
+  return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
+                                                indexType, isBeyondMaxIndices,
+                                                wrappedBeyondMaxIndices, index);
+}
+
 template <>
 LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
     AtenIndexSelectOp op, OpAdaptor adaptor,
@@ -4049,6 +4084,10 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
         RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
   }
 
+  int64_t selfNumElems = std::accumulate(inputShape.begin(), inputShape.end(),
+                                         1, std::multiplies<int64_t>());
+  index = wrapIndicesAroundMax(index, selfNumElems, op, rewriter);
+
   // Get positive dim
   int64_t dim;
   if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
@@ -7257,10 +7296,12 @@ LogicalResult ConvertAtenOp<AtenAsStridedOp>::matchAndRewrite(
     // coord_i_n * stride[n]
     int32_t index = offset;
     int64_t coordFinder = i;
+
     for (int64_t dim = 0; dim < outputRank; dim++) {
       int64_t indexCoord = coordFinder % outputSize[outputRank - dim - 1];
       index += indexCoord * stride[outputRank - dim - 1];
       coordFinder /= outputSize[outputRank - dim - 1];
+      index = (index % selfNumElems);
     }
     targetIndicesVec.push_back(index);
   }
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 1dce55f06158..666e1230f907 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -484,6 +484,7 @@
     "SplitTensorNegativeDimModule_basic",
     "SplitWithSizesListUnpackModule_basic",
     "SplitWithSizes_Module_basic",
+    "AsStridedWithOffsetModule_basic",
     "AdaptiveAvgPool1dGeneralDynamic_basic",
     "AdaptiveAvgPool1dStaticEvenMultiple_basic",
     "AdaptiveAvgPool1dStaticLargerOutput_basic",
@@ -880,6 +881,7 @@
     "SplitTensorNegativeDimModule_basic",
     "SplitWithSizesListUnpackModule_basic",
     "SplitWithSizes_Module_basic",
+    "AsStridedWithOffsetModule_basic",
     "Unfold_Module_basic",
     "Unfold_Module_Rank_4",
     "Unfold_Module_Rank_Zero_basic",
@@ -1748,6 +1750,7 @@
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
     "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
+    "AsStridedWithOffsetModule_basic",
     "ElementwiseAtenLogicalNotOpPromoteModule_basic",
     "ElementwiseCosIntModule_basic",
     "ElementwiseReciprocalIntModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py
index da5212e30a6c..629c64ddd904 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -1144,3 +1144,32 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: TensorSplitSections_ListUnpackModule())
 def TensorSplitSections_ListUnpackModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 5))
+
+
+# ==============================================================================
+
+
+class AsStridedWithOffsetModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 6, 60], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        output_size = [6, 20]
+        stride = [60, 1]
+        slice = torch.ops.aten.slice.Tensor(x, 0, 1, 2)
+        squeeze = torch.ops.aten.squeeze.dim(slice, 0)
+        return torch.ops.aten.as_strided(
+            squeeze, size=output_size, stride=stride, storage_offset=360
+        )
+
+
+@register_test_case(module_factory=lambda: AsStridedWithOffsetModule())
+def AsStridedWithOffsetModule_basic(module, tu: TestUtils):
+    module.forward(torch.rand(2, 6, 60))
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index a3d52166385a..a2e33fc08983 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -1895,22 +1895,29 @@ func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) ->
 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32>
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 2
 // CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<2xi64>) -> tensor<2xi32>
-// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1, 2>} : (tensor<2xi32>) -> tensor<1x1x2xi32>
-// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array<i64: 4, 5, 1>} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32>
-// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 4, 5, 2, 1>} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32>
-// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
-// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
-// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32>
-// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 120, 1>} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32>
-// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 40, 3>} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32>
-// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32>
-// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32>
-// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 40>} : (tensor<40x1xi32>) -> tensor<1x40xi32>
-// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32>
-// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 4, 5, 2>} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32>
-// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32>
-// CHECK: return %[[VAL_20]] : !torch.vtensor<[4,5,2],f32>
+// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<120> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<119> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %[[VAL_8:.*]] = tosa.greater %[[VAL_5]], %[[VAL_7]] : (tensor<2xi32>, tensor<i32>) -> tensor<2xi1>
+// CHECK: %[[VAL_9:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor<2xi32>, tensor<i32>) -> tensor<2xi32>
+// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_9]], %[[VAL_6]] {shift = 0 : i8} : (tensor<2xi32>, tensor<i32>) -> tensor<2xi32>
+// CHECK: %[[VAL_11:.*]] = tosa.sub %[[VAL_5]], %[[VAL_10]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+// CHECK: %[[VAL_12:.*]] = tosa.select %[[VAL_8]], %[[VAL_11]], %[[VAL_5]] : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array<i64: 1, 1, 2>} : (tensor<2xi32>) -> tensor<1x1x2xi32>
+// CHECK: %[[VAL_14:.*]] = tosa.tile %[[VAL_13]] {multiples = array<i64: 4, 5, 1>} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32>
+// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 4, 5, 2, 1>} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
+// CHECK: %[[VAL_17:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
+// CHECK: %[[VAL_18:.*]] = tosa.concat %[[VAL_16]], %[[VAL_17]], %[[VAL_15]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32>
+// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 120, 1>} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32>
+// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 40, 3>} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32>
+// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK: %[[VAL_22:.*]] = tosa.mul %[[VAL_20]], %[[VAL_21]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32>
+// CHECK: %[[VAL_23:.*]] = tosa.reduce_sum %[[VAL_22]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32>
+// CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_23]] {new_shape = array<i64: 1, 40>} : (tensor<40x1xi32>) -> tensor<1x40xi32>
+// CHECK: %[[VAL_25:.*]] = tosa.gather %[[VAL_19]], %[[VAL_24]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32>
+// CHECK: %[[VAL_26:.*]] = tosa.reshape %[[VAL_25]] {new_shape = array<i64: 4, 5, 2>} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32>
+// CHECK: %[[VAL_27:.*]] = torch_c.from_builtin_tensor %[[VAL_26]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32>
+// CHECK: return %[[VAL_27]] : !torch.vtensor<[4,5,2],f32>
 // CHECK: }
 func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> {
   %int2 = torch.constant.int 2
@@ -2308,6 +2315,42 @@ func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !tor
   return %2 : !torch.vtensor<[3,3],f32>
 }
 
+// -----
+// CHECK-LABEL: func.func @torch.aten.as_strided$offset(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 30
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 25>} : (tensor<5x5xf32>) -> tensor<25xf32>
+// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[5, 6, 7, 7, 8, 9, 9, 10, 11]> : tensor<9xi32>}> : () -> tensor<9xi32>
+// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 9, 1>} : (tensor<9xi32>) -> tensor<9x1xi32>
+// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_10]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
+// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 25, 1>} : (tensor<25xf32>) -> tensor<1x25x1xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 9, 1>} : (tensor<9x1xi32>) -> tensor<9x1xi32>
+// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<9x1xi32>, tensor<1xi32>) -> tensor<9x1xi32>
+// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
+// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 9>} : (tensor<9x1xi32>) -> tensor<1x9xi32>
+// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x25x1xf32>, tensor<1x9xi32>) -> tensor<1x9x1xf32>
+// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 9>} : (tensor<1x9x1xf32>) -> tensor<9xf32>
+// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array<i64: 3, 3>} : (tensor<9xf32>) -> tensor<3x3xf32>
+// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32>
+// CHECK: return %[[VAL_21]] : !torch.vtensor<[3,3],f32>
+func.func @torch.aten.as_strided$offset(%arg0: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> {
+  %int30 = torch.constant.int 30
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.aten.as_strided %arg0, %0, %1, %int30 : !torch.vtensor<[5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[3,3],f32>
+  return %2 : !torch.vtensor<[3,3],f32>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.max_pool1d$basic(