diff --git a/externals/llvm-project b/externals/llvm-project index e2402615a5a7..5d6d982df61d 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit e2402615a5a76d46a433dfcc1de10b38a1263c9d +Subproject commit 5d6d982df61d16b6d498e6d59dd91c059679d3d8 diff --git a/externals/stablehlo b/externals/stablehlo index 8cd9444b78cc..b62dc66da994 160000 --- a/externals/stablehlo +++ b/externals/stablehlo @@ -1 +1 @@ -Subproject commit 8cd9444b78ccec3e42a4b21105a5a547c021e823 +Subproject commit b62dc66da9946b4c400c0d99c9d5bb8e04edaee6 diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h index 15f29fbc3cab..c4f6054c0c90 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h @@ -127,11 +127,6 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter, RankedTensorType weightTy, RankedTensorType outputTy, TypeAttr &accType); -// Temporary function to get TOSA const shape -// TODO: Remove this function when getTosaConstShape is available in -// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h -Value getTosaConstShape(PatternRewriter &rewriter, Location loc, - llvm::ArrayRef shape); } // namespace tosa } // namespace mlir diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 4ec703d892ad..ace593bf4f0a 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -114,16 +114,29 @@ class ConvertAtenBinaryOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Only Tensor types supported in TOSA"); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, rhs).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto outTy = cast( OpConversionPattern::getTypeConverter()->convertType( op.getType())); Value binaryOp; - // TOSA ArithmeticRightShiftOp has a round parameter. if constexpr (std::is_same()) { + // TOSA ArithmeticRightShiftOp has a round parameter. 
binaryOp = rewriter.create(op->getLoc(), outTy, lhs, rhs, /*round=*/false); + } else if constexpr (std::is_same() || + std::is_same()) { + lhs = tosa::promoteType(rewriter, lhs, outTy); + rhs = tosa::promoteType(rewriter, rhs, outTy); + // Use default NaN Propagation mode "PROPAGATE" for tosa.maximum and + // tosa.minimum + binaryOp = rewriter.create( + op->getLoc(), outTy, lhs, rhs, + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); } else { binaryOp = tosa::createBinaryOpAndCast(rewriter, op, outTy, lhs, rhs); @@ -318,16 +331,25 @@ class ConvertAtenAddSubOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "Currently only scalar constants are supported for " "conversion in TOSA operation"); - } else if (rhsType.getElementType() != rhsAlphaMulElemType) { - // right is tensor, rhsType == tensor - // right must be cast to same type as the alpha, so MulOp success - rhs = rewriter.create( - op->getLoc(), - RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType), rhs); - // reinitialize right value type to tensor - rhsType = dyn_cast(rhs.getType()); + } else { + if (rhsType.getElementType() != rhsAlphaMulElemType) { + // right is tensor, rhsType == tensor + // right must be cast to same type as the alpha, so MulOp success + rhs = rewriter.create( + op->getLoc(), + RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType), + rhs); + // reinitialize right value type to tensor + rhsType = dyn_cast(rhs.getType()); + } } auto rhsTensor = rhsType ? rhs : rhsAsTensor; + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, rhsTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + auto rhsTensorType = dyn_cast(rhsTensor.getType()); // Handle scalar value alpha. // It should be either f32/i32 @@ -340,11 +362,13 @@ class ConvertAtenAddSubOp : public OpConversionPattern { op, "Currently only scalar constants are supported for " "alpha in conversion to TOSA operation"); } + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, alphaTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); auto mulAlphaOp = tosa::createMulOpAndCast( - rewriter, op, - rhsType ? rhsType : RankedTensorType::get({}, rhsAlphaMulElemType), - rhsTensor, alphaTensor, /*shift=*/0); + rewriter, op, rhsTensorType, rhsTensor, alphaTensor, /*shift=*/0); if (outElemTy.isInteger(64)) { // Tosa doesn't support 64-bit elementwise addition and subtraction. @@ -411,7 +435,13 @@ class ConvertAtenCompareOp : public OpConversionPattern { op, "Currently only scalar constants are supported for " "conversion in TOSA operation"); } + auto rhsTensor = rhsTy ? rhs : rhsAsTensor; + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, rhsTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto rhsTensorTy = dyn_cast(rhsTensor.getType()); auto rhsElemTy = rhsTensorTy.getElementType(); @@ -467,9 +497,7 @@ class ConvertAtenCompareOp : public OpConversionPattern { std::is_same()) { rewriter.replaceOpWithNewOp(op, resultTy, resultOp.getResult()); - } - - else { + } else { rewriter.replaceOp(op, resultOp.getResult()); } @@ -520,6 +548,11 @@ class ConvertAtenMulOp : public OpConversionPattern { rhsTensor = rhsType ? 
rhs : rhsAsTensor; } + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, rhsTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + if (isa(outElemTy) || isa(outElemTy)) { auto outType = cast( OpConversionPattern::getTypeConverter()->convertType( @@ -542,8 +575,10 @@ class ConvertAtenMulOp : public OpConversionPattern { // towards zero) for float type inputs. // This function takes in the division result between lhs and rhs rather // than takes in the original lhs and rhs tensors as parameters. -Value truncFloatDivWithDivResult(PatternRewriter &rewriter, Operation *op, - TensorType outType, Value divResult) { +std::optional truncFloatDivWithDivResult(PatternRewriter &rewriter, + Operation *op, + TensorType outType, + Value divResult) { // To implement trunc mode for float inputs, multiply the floored abs // of the tensor with the elementwise signedness of the tensor. // div_result = lhs / rhs @@ -560,6 +595,14 @@ Value truncFloatDivWithDivResult(PatternRewriter &rewriter, Operation *op, outType.getElementType()) .value(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), divResult, one) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), divResult, zero) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), divResult, minusOne) + .failed()) + return std::nullopt; + auto cond = rewriter.create( op->getLoc(), RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)), @@ -594,18 +637,21 @@ Value truncFloatDiv(PatternRewriter &rewriter, Operation *op, auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs, rhsRcp, /*shift=*/0); - return truncFloatDivWithDivResult(rewriter, op, outType, divResult); + return truncFloatDivWithDivResult(rewriter, op, outType, divResult).value(); } // Function to perform division with floor rounding mode (rounding result // down) for integer type inputs. -Value floorIntDiv(PatternRewriter &rewriter, Operation *op, TensorType outType, - Value lhs, Value rhs) { +std::optional floorIntDiv(PatternRewriter &rewriter, Operation *op, + TensorType outType, Value lhs, Value rhs) { // To implement floor mode int input, utilize tosa::IntDivOp (trunc div // result) with the following formula elementwise: // floor_val = trunc_val - ((trunc_val * rhs != lhs) // && (sign(lhs) != sign(rhs))) + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, rhs).failed()) + return std::nullopt; + // TOSA IntDiv requires inputs to be i32 auto i32Type = RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(32)); @@ -619,6 +665,10 @@ Value floorIntDiv(PatternRewriter &rewriter, Operation *op, TensorType outType, auto one = tosa::getConstTensor(rewriter, op, 1, {}).value(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, one).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, zero).failed()) + return std::nullopt; + auto boolType = RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)); @@ -682,6 +732,11 @@ class ConvertAtenDivOp : public OpConversionPattern { "conversion in TOSA operation"); } auto rhsTensor = rhsTy ? 
rhs : rhsAsTensor; + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, rhsTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto outType = cast( OpConversionPattern::getTypeConverter()->convertType( op.getType())); @@ -718,7 +773,8 @@ class ConvertAtenDivOp : public OpConversionPattern { } else if (roundMode.compare("trunc") == 0) { // "trunc": rounds the results of the division towards zero. Equivalent // to C-style integer division. - result = truncFloatDivWithDivResult(rewriter, op, outType, divResult); + result = truncFloatDivWithDivResult(rewriter, op, outType, divResult) + .value(); } else { // None: No rounding mode result = divResult.getResult(); @@ -727,7 +783,7 @@ class ConvertAtenDivOp : public OpConversionPattern { if (roundMode.compare("floor") == 0) { // "floor": rounds the results of the division down. Equivalent to floor // division in Python (the // operator). - result = floorIntDiv(rewriter, op, outType, lhs, rhsTensor); + result = floorIntDiv(rewriter, op, outType, lhs, rhsTensor).value(); } else { // "trunc": rounds the results of the division towards zero. Equivalent // to C-style integer division. @@ -815,12 +871,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization currently supported"); } + + // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), clampIn, rewriter.getI64IntegerAttr(clampMin), rewriter.getI64IntegerAttr(std::numeric_limits::max()), rewriter.getF32FloatAttr(0.0f), - rewriter.getF32FloatAttr(std::numeric_limits::max())); + rewriter.getF32FloatAttr(std::numeric_limits::max()), + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); return success(); } @@ -843,10 +902,18 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Negative slope needs to be a scalar constant for conversion to " "TOSA LeakyReLU operation"); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), alphaTensor, self) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); auto zero = tosa::getConstTensor(rewriter, op, 0, {}, selfTy.getElementType()) .value(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), zero, self).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto cond = rewriter.create( op->getLoc(), RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)), @@ -1131,10 +1198,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.getI32Type()); auto reduceDimAttr = rewriter.getIntegerAttr(rewriter.getI64Type(), reduceDim); + + // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax return rewriter - .create(op->getLoc(), - getTypeConverter()->convertType(outputReduceTy), - input, reduceDimAttr) + .create( + op->getLoc(), getTypeConverter()->convertType(outputReduceTy), + input, reduceDimAttr, + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")) .getResult(); }; @@ -1340,6 +1410,11 @@ class ConvertAtenPowOp : public OpConversionPattern { expTensor = tosa::promoteType(rewriter, expTensor, outType); } + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), selfTensor, expTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto powOp = tosa::createBinaryOpAndCast( rewriter, op, 
outType, selfTensor, expTensor); rewriter.replaceOp(op, powOp.getResult()); @@ -2053,6 +2128,10 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { auto bias = adaptor.getBias(); auto biasTy = bias.getType(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, bias).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + // TOSA does not mandate that elementwise op tensors need to be ranked. if (!isa(biasTy) && !isa(biasTy)) return rewriter.notifyMatchFailure( @@ -2151,6 +2230,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( /*checkForUnity=*/true))) return failure(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, otherTensor) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, alphaTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto multTensor = rewriter.create(op->getLoc(), resultTy, self, alphaTensor, /*shift=*/0); @@ -2458,9 +2544,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -Value computeBatchNorm(Operation *op, ConversionPatternRewriter &rewriter, - Type outType, Value input, Value variance, Value eps, - Value mean, Value weight, Value bias) { +std::optional computeBatchNorm(Operation *op, + ConversionPatternRewriter &rewriter, + Type outType, Value input, Value variance, + Value eps, Value mean, Value weight, + Value bias) { // For PyTorch: // scale = gamma = weight // offset = beta = bias @@ -2484,6 +2572,15 @@ Value computeBatchNorm(Operation *op, ConversionPatternRewriter &rewriter, // op5 = mul(op4, bscale) // op6 = add(op5, boffset) + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input, mean).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input, variance) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input, eps).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input, weight) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input, bias).failed()) + return std::nullopt; + auto op1SubInputMean = rewriter.create(op->getLoc(), outType, input, mean); @@ -2592,7 +2689,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto batchNorm = computeBatchNorm(op, rewriter, outType, adaptor.getInput(), varianceVal, - epsilonConst, meanVal, weightVal, biasVal); + epsilonConst, meanVal, weightVal, biasVal) + .value(); rewriter.replaceOp(op, {batchNorm}); @@ -2612,11 +2710,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // eventually being reshaped for broadcasting. // Not a ranked tensor output - if (!dyn_cast(adaptor.getInput().getType())) + auto input = adaptor.getInput(); + auto inputType = dyn_cast(input.getType()); + + if (!inputType) return rewriter.notifyMatchFailure( op, "Only ranked tensor types are supported"); - auto inputType = cast(adaptor.getInput().getType()); if (inputType.getRank() > 4) return rewriter.notifyMatchFailure(op, "Only up to 4D tensors are supported"); @@ -2626,13 +2726,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Note: cudnn_enabled is not handled. // FIXME: Handle the None cases for the optional parameters. 
- if (isa(adaptor.getWeight().getType())) + auto weight = adaptor.getWeight(); + if (isa(weight.getType())) return rewriter.notifyMatchFailure(op, "Unsupported None for weight"); - if (isa(adaptor.getBias().getType())) + + auto bias = adaptor.getBias(); + if (isa(bias.getType())) return rewriter.notifyMatchFailure(op, "Unsupported None for bias"); - auto weightType = cast(adaptor.getWeight().getType()); - auto biasType = cast(adaptor.getBias().getType()); + auto weightType = cast(weight.getType()); + auto biasType = cast(bias.getType()); int64_t inputRank = inputType.getRank(); Type elemTy = inputType.getElementType(); SmallVector inputTypeShape( @@ -2697,6 +2800,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value elemCntRcp = rewriter.create( op.getLoc(), elemCntConst.getType(), elemCntConst); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input, elemCntRcp) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + // Broadcast type and shape for various intermediate values. SmallVector bcastOutShape; for (auto en : llvm::enumerate(inputTypeShape)) { @@ -2708,14 +2816,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( RankedTensorType::get(makeShapeLLVMCompatible(bcastOutShape), elemTy); // Compute mean. - Value sum = computeSumAndReshape(adaptor.getInput(), inputType, bcastOutType, - bcastOutShape); + Value sum = + computeSumAndReshape(input, inputType, bcastOutType, bcastOutShape); Value meanVal = rewriter.create(op.getLoc(), bcastOutType, sum, elemCntRcp, /*shift=*/0); // Compute variance. - Value squareSumSub = rewriter.create( - op.getLoc(), inputType, adaptor.getInput(), meanVal); + Value squareSumSub = + rewriter.create(op.getLoc(), inputType, input, meanVal); Value squareSum = rewriter.create(op.getLoc(), inputType, squareSumSub, squareSumSub, 0); @@ -2736,11 +2844,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( makeShapeLLVMCompatible(weightAndBiasBcastShape), elemTy); Value weightVal = rewriter.create( - op.getLoc(), weightAndMeanBcastType, adaptor.getWeight(), + op.getLoc(), weightAndMeanBcastType, weight, rewriter.getDenseI64ArrayAttr(weightAndBiasBcastShape)); Value biasVal = rewriter.create( - op.getLoc(), weightAndMeanBcastType, adaptor.getBias(), + op.getLoc(), weightAndMeanBcastType, bias, rewriter.getDenseI64ArrayAttr(weightAndBiasBcastShape)); double eps; @@ -2752,9 +2860,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .value(); // Compute layer norm. 
- auto layerNorm = - computeBatchNorm(op, rewriter, outType, adaptor.getInput(), varianceVal, - epsilonConst, meanVal, weightVal, biasVal); + auto layerNorm = computeBatchNorm(op, rewriter, outType, input, varianceVal, + epsilonConst, meanVal, weightVal, biasVal) + .value(); rewriter.replaceOp(op, {layerNorm, meanVal, varianceVal}); @@ -2974,6 +3082,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ln2Shape, outType.getElementType()) .value(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, ln2Op).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto rcpOp = rewriter.create(op.getLoc(), ln2Op.getType(), ln2Op); @@ -3017,6 +3129,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only scalar constant is supported for value"); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, threshold) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, value).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto cmpOp = rewriter.create( op.getLoc(), RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)), @@ -3178,8 +3296,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -static Value approximateErfOp(ConversionPatternRewriter &rewriter, - Operation *op, Value x, Type dtype) { +static std::optional +approximateErfOp(ConversionPatternRewriter &rewriter, Operation *op, Value x, + Type dtype) { // Using: // https://en.wikipedia.org/wiki/Error_function#Numerical_approximations with // maximum error as 5 x 10^-4 where a1 = 0.278393, a2 = 0.230389, a3 = @@ -3192,26 +3311,34 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter, auto absX = rewriter.create(loc, outType, x); auto zero = tosa::getConstTensor(rewriter, op, 0, {}, dtype).value(); auto one = tosa::getConstTensor(rewriter, op, 1, {}, dtype).value(); - auto a1 = tosa::getConstTensor(rewriter, op, 0.278393f, {}, dtype).value(); + auto a2 = + tosa::getConstTensor(rewriter, op, 0.230389f, {}, dtype).value(); + auto a3 = + tosa::getConstTensor(rewriter, op, 0.000972f, {}, dtype).value(); + auto a4 = + tosa::getConstTensor(rewriter, op, 0.078108f, {}, dtype).value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, zero).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, one).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a1).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a2).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a3).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a4).failed()) + return std::nullopt; + auto a1X = rewriter.create(loc, outType, a1, absX, /*shift=*/0); auto sum = rewriter.create(loc, outType, a1X, one); - auto a2 = - tosa::getConstTensor(rewriter, op, 0.230389f, {}, dtype).value(); auto x2 = rewriter.create(loc, outType, absX, absX, /*shift=*/0); auto a2X = rewriter.create(loc, outType, a2, x2, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a2X); - auto a3 = - tosa::getConstTensor(rewriter, op, 0.000972f, {}, dtype).value(); auto x3 = rewriter.create(loc, outType, x2, absX, /*shift=*/0); auto a3X = rewriter.create(loc, outType, a3, x3, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a3X); - auto a4 = - tosa::getConstTensor(rewriter, op, 0.078108f, {}, dtype).value(); auto x4 = rewriter.create(loc, outType, x3, absX, /*shift=*/0); auto a4X = 
rewriter.create(loc, outType, a4, x4, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a4X); @@ -3233,10 +3360,22 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter, return rewriter.create(loc, outType, cond, erf, negateErf); } -static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter, - Operation *op, Value x, Type dtype) { +static std::optional +buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x, + Type dtype) { auto zero = tosa::getConstTensor(rewriter, op, 0, {}, dtype).value(); auto one = tosa::getConstTensor(rewriter, op, 1, {}, dtype).value(); + auto oneHalf = + tosa::getConstTensor(rewriter, op, 0.5, {}, dtype).value(); + // rsqrt of 2 + auto rsqrt2 = + tosa::getConstTensor(rewriter, op, 0.70710678f, {}, dtype).value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, zero).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, one).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, oneHalf).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, rsqrt2).failed()) + return std::nullopt; auto loc = op->getLoc(); @@ -3244,16 +3383,11 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter, auto outType = x.getType(); auto mean = zero; Value xMinusMean = rewriter.create(loc, outType, x, mean); - // rsqrt of 2 - Value rsqrt2 = - tosa::getConstTensor(rewriter, op, 0.70710678f, {}, dtype).value(); Value erfArg = rewriter.create(loc, outType, xMinusMean, rsqrt2, /*shift=*/0); - Value erf = approximateErfOp(rewriter, op, erfArg, dtype); + Value erf = approximateErfOp(rewriter, op, erfArg, dtype).value(); Value erfPlus1 = rewriter.create(loc, outType, one, erf); - Value oneHalf = - tosa::getConstTensor(rewriter, op, 0.5, {}, dtype).value(); Value normalCdf = rewriter.create(loc, outType, oneHalf, erfPlus1, /*shift=*/0); @@ -3290,7 +3424,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (approximate.compare("none") == 0) { // GELU(x) = x * CDF(x) - Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); + Value cdf = + buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy).value(); cdf = rewriter.createOrFold( op->getLoc(), cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); @@ -3388,7 +3523,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. 
- auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto self = adaptor.getSelf(); + auto selfType = dyn_cast(self.getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -3418,15 +3554,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .value(); Value negOneHalf = tosa::getConstTensor(rewriter, op, -0.5f, {}, selfElemTy).value(); - Value inputSquared = rewriter.create( - loc, selfType, adaptor.getSelf(), adaptor.getSelf(), /*shift=*/0); + + if (mlir::tosa::EqualizeRanks(rewriter, loc, self, kAlphaHalf).failed() || + mlir::tosa::EqualizeRanks(rewriter, loc, self, negOneHalf).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + Value inputSquared = + rewriter.create(loc, selfType, self, self, /*shift=*/0); Value negHalfInputSquared = rewriter.create( loc, selfType, inputSquared, negOneHalf, /*shift=*/0); Value dinput = rewriter.create(loc, selfType, negHalfInputSquared); - Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); - Value dinputInput = rewriter.create( - loc, selfType, dinput, adaptor.getSelf(), /*shift=*/0); + Value cdf = buildUnitNormalCdf(rewriter, op, self, selfElemTy).value(); + Value dinputInput = + rewriter.create(loc, selfType, dinput, self, /*shift=*/0); Value dinputInputAlpha = rewriter.create( loc, selfType, dinputInput, kAlphaHalf, /*shift=*/0); Value cdfExt = @@ -3445,7 +3587,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto self = adaptor.getSelf(); + auto selfType = dyn_cast(self.getType()); if (!selfType) { return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -3465,7 +3608,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } Value gradOutput = adaptor.getGradOutput(); - auto gradOutputType = dyn_cast(adaptor.getSelf().getType()); + auto gradOutputType = dyn_cast(gradOutput.getType()); Type gradOutputElemType = gradOutputType.getElementType(); @@ -3490,17 +3633,28 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value replace = tosa::getConstTensor(rewriter, op, 0, {}, selfElemTy).value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, minVal) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, maxVal) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, gradOutput) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, replace).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + Type outType = getTypeConverter()->convertType(op.getType()); Value lesser = rewriter.create( op.getLoc(), RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)), - minVal, adaptor.getSelf()); + minVal, self); Value greater = rewriter.create( op.getLoc(), RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)), - adaptor.getSelf(), maxVal); + self, maxVal); Value cmp = rewriter.create( op.getLoc(), @@ -3708,11 +3862,23 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim); auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape); - Value reduceOp = rewriter.create( - op->getLoc(), - RankedTensorType::get(makeShapeLLVMCompatible(reducedShape), - selfElemType), - self, dimAttr); + Value reduceOp; + if constexpr (std::is_same() || + 
std::is_same()) { + // Use default NaN Propagation mode "PROPAGATE" for tosa.reduce_min + // and tosa.reduce_max + reduceOp = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeLLVMCompatible(reducedShape), + selfElemType), + self, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + } else { + reduceOp = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeLLVMCompatible(reducedShape), + selfElemType), + self, dimAttr); + } // To handle ReduceMinDim indices, we apply ArgMaxOp on the negate // of the input tensor, which will return indices of input's min values @@ -3721,17 +3887,19 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { Value negateOp = rewriter.create(op->getLoc(), selfType, self); + // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax argMaxOp = rewriter.create( op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(prunedShape), indicesElemType), - negateOp, dimAttr); + negateOp, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); } else { + // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax argMaxOp = rewriter.create( op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(prunedShape), indicesElemType), - self, dimAttr); + self, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); } if (argMaxOp.getType() != indicesType) { @@ -4249,13 +4417,20 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -Value wrapNegativeIndices(Value index, int maxIndex, Operation *op, - ConversionPatternRewriter &rewriter) { +std::optional wrapNegativeIndices(Value index, int maxIndex, + Operation *op, + ConversionPatternRewriter &rewriter) { auto zeroValue = tosa::getConstTensor(rewriter, op, 0, {}).value(); auto maxIndexValue = tosa::getConstTensor(rewriter, op, maxIndex, {}).value(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), index, zeroValue) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), index, maxIndexValue) + .failed()) + return std::nullopt; + auto indexType = dyn_cast(index.getType()); auto wrappedIndicesOp = tosa::CreateOpAndInfer( @@ -4335,7 +4510,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } index = wrapNegativeIndices(index, inputTensorType.getShape()[i], op, - rewriter); + rewriter) + .value(); // Expand last dim of index to tf indices [2,3] -> [2,3,1] SmallVector indiceShapeOneDim; for (auto shape : indexShape) { @@ -4504,7 +4680,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } index = - wrapNegativeIndices(index, inputTensorType.getShape()[0], op, rewriter); + wrapNegativeIndices(index, inputTensorType.getShape()[0], op, rewriter) + .value(); // Expand last dim of index to tf indices [2,3] -> [2,3,1] SmallVector indicesShape; @@ -4772,19 +4949,33 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. 
- auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto self = adaptor.getSelf(); + auto selfType = dyn_cast(self.getType()); if (!selfType) return rewriter.notifyMatchFailure( - op, "Only tensor types input are currently supported"); - auto condType = dyn_cast(adaptor.getCondition().getType()); + op, "Only tensor types inputs are currently supported"); + + auto cond = adaptor.getCondition(); + auto condType = dyn_cast(cond.getType()); if (!condType) return rewriter.notifyMatchFailure( - op, "Only tensor types condition are currently supported"); + op, "Only tensor types conditions are currently supported"); - auto outType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp( - op, outType, adaptor.getCondition(), adaptor.getSelf(), - adaptor.getOther()); + auto other = adaptor.getOther(); + auto otherType = dyn_cast(other.getType()); + if (!otherType) + return rewriter.notifyMatchFailure( + op, "Only tensor types inputs are currently supported"); + + auto outType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), cond, self).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), cond, other).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + rewriter.replaceOpWithNewOp(op, outType, cond, self, other); return success(); } @@ -4805,8 +4996,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "unimplemented: equal_nan is expected to be false"); // check tensor type. - auto selfType = dyn_cast(adaptor.getSelf().getType()); - auto otherType = dyn_cast(adaptor.getOther().getType()); + auto self = adaptor.getSelf(); + auto selfType = dyn_cast(self.getType()); + auto other = adaptor.getOther(); + auto otherType = dyn_cast(other.getType()); if (!selfType || !otherType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); @@ -4818,20 +5011,31 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "unimplemented: only FP element type is supported"); } + auto rtolConstOp = + tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast(rtol)); + auto atolConstOp = + tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast(atol)); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, other).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, rtolConstOp) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, atolConstOp) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + // Reinitialize selfType and otherType after equalizing ranks + selfType = dyn_cast(self.getType()); + otherType = dyn_cast(other.getType()); - auto rhsSubOp = rewriter.create( - op->getLoc(), selfType, adaptor.getSelf(), adaptor.getOther()); + auto rhsSubOp = + rewriter.create(op->getLoc(), selfType, self, other); auto rhsAbsOp = rewriter.create(op->getLoc(), selfType, rhsSubOp); - auto lhsAbsOp = - rewriter.create(op->getLoc(), otherType, adaptor.getOther()); - auto rtolConstOp = - tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast(rtol)); + auto lhsAbsOp = rewriter.create(op->getLoc(), otherType, other); auto mulOp = rewriter.create(op->getLoc(), otherType, rtolConstOp, lhsAbsOp, /*shift=*/0); - auto atolConstOp = - tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast(atol)); auto addOp = rewriter.create(op->getLoc(), otherType, atolConstOp, mulOp); @@ -4895,9 
+5099,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( "max attr should be a torch constant"); } + // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp auto outType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, outType, adaptor.getSelf(), - min_int, max_int, min_fp, max_fp); + rewriter.replaceOpWithNewOp( + op, outType, adaptor.getSelf(), min_int, max_int, min_fp, max_fp, + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); return success(); } @@ -4992,13 +5198,26 @@ LogicalResult ConvertAtenOp::matchAndRewrite( }); } + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, min).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, max).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + self = tosa::promoteType(rewriter, self, resultType); + min = tosa::promoteType(rewriter, min, resultType); + max = tosa::promoteType(rewriter, max, resultType); + // max(xi, min_valuei) - auto minThresholdCheck = tosa::createBinaryOpAndCast( - rewriter, op, resultType, self, min); + // Use default NaN Propagation mode "PROPAGATE" for tosa.maximum + auto minThresholdCheck = rewriter.create( + op->getLoc(), resultType, self, min, + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); // yi = min(max(xi, min_valuei), max_valuei) - auto result = tosa::createBinaryOpAndCast( - rewriter, op, resultType, minThresholdCheck, max); + // Use default NaN Propagation mode "PROPAGATE" for tosa.minimum + auto result = rewriter.create( + op->getLoc(), resultType, minThresholdCheck, max, + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); rewriter.replaceOp(op, result); return success(); @@ -5339,6 +5558,11 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern { op, "Only ranked tensor types supported in TOSA Remainder/Fmod"); } + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, otherTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + constexpr bool isRemainderOp = std::is_same() || std::is_same() || @@ -5358,7 +5582,8 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern { divTensor = rewriter.create(op.getLoc(), outType, divTensor); } else { - divTensor = floorIntDiv(rewriter, op, outType, self, otherTensor); + divTensor = + floorIntDiv(rewriter, op, outType, self, otherTensor).value(); } } else { // torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b @@ -5493,9 +5718,11 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { std::is_same::value, "Expected either tosa::MaxPool2dOp or tosa::AvgPool2dOp"); if constexpr (std::is_same::value) { + // Use default NaN Propagation mode "PROPAGATE" for tosa.max_pool2d pooledOutput = rewriter - .create(op->getLoc(), outputTy, input, kernel, - stride, pad) + .create( + op->getLoc(), outputTy, input, kernel, stride, pad, + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")) .getResult(); } else if constexpr (std::is_same::value) { TypeAttr accType; @@ -6086,7 +6313,8 @@ class ConvertAtenMaskedFillOp : public OpConversionPattern { } // Not a tensor type. 
- auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto self = adaptor.getSelf(); + auto selfType = dyn_cast(self.getType()); if (!selfType || !outType.hasStaticShape()) return rewriter.notifyMatchFailure( op, @@ -6118,8 +6346,13 @@ class ConvertAtenMaskedFillOp : public OpConversionPattern { RankedTensorType::get(rhsTensorType.getShape(), outElemTy), rhsTensor); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, rhsTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + rewriter.replaceOpWithNewOp(op, outType, adaptor.getMask(), - rhsTensor, adaptor.getSelf()); + rhsTensor, self); return success(); } }; @@ -6197,12 +6430,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( translatePadsList.push_back(highPadding[i]); } - DenseElementsAttr paddingAttr = DenseIntElementsAttr::get( - RankedTensorType::get({2 * rank}, rewriter.getI64Type()), - translatePadsList); - - Value padsList1 = rewriter.create( - loc, paddingAttr.getType(), paddingAttr); + Value padsList1 = tosa::getTosaConstShape(rewriter, loc, translatePadsList); Value padValue = adaptor.getValue(); Operation *padOp = padValue.getDefiningOp(); @@ -6289,6 +6517,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto oneHalf = tosa::getConstTensor(rewriter, op, 0.5, {}, elementType).value(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, oneHalf).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + rewriter.replaceOpWithNewOp(op, resultType, self, oneHalf); return success(); } @@ -6572,6 +6804,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( llvm_unreachable("Invalid integer width"); }); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, trilMask) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + rewriter.replaceOpWithNewOp(op, resultType, self, trilMask, /*shift=*/0); @@ -6653,6 +6890,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto two = tosa::getConstTensor(rewriter, op, 2, {}, resultElemTy).value(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, oneHalf) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, two).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto floorInput = rewriter.create(op->getLoc(), resultTy, self); @@ -6847,6 +7090,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( llvm_unreachable("Invalid integer width"); }); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, diagonalMask) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + Value diagonalTensor = rewriter.create( op->getLoc(), transposedInputType, selfTransposed, diagonalMask, /*shift=*/0); @@ -7200,6 +7448,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( self = tosa::promoteType(rewriter, self, resultType); grad = tosa::promoteType(rewriter, grad, resultType); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, zero).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, grad).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto result = rewriter.create(op->getLoc(), resultType, cond.getResult(), zero, grad); @@ -8107,6 +8360,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto zi = self; // Clamp input to [eps, 1 - eps] when eps is not None + // Use 
default NaN Propagation mode "PROPAGATE" for tosa.clamp if (!isEpsNone) { zi = rewriter .create( @@ -8114,13 +8368,18 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.getI64IntegerAttr(static_cast(eps)), rewriter.getI64IntegerAttr(static_cast(1 - eps)), rewriter.getF32FloatAttr(static_cast(eps)), - rewriter.getF32FloatAttr(static_cast(1 - eps))) + rewriter.getF32FloatAttr(static_cast(1 - eps)), + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")) .getResult(); } auto one = tosa::getConstTensor(rewriter, op, 1.0f, {}, resultElemTy).value(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, one).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto oneMinusZi = rewriter.create(op->getLoc(), resultType, one, zi); @@ -8168,6 +8427,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto one = tosa::getConstTensor(rewriter, op, 1.0f, {}, resultElemTy).value(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, one).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto addOp = rewriter.create(op->getLoc(), resultType, self, one); @@ -8209,14 +8472,19 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto ten = tosa::getConstTensor(rewriter, op, 10.0f, {}, resultElemTy) .value(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, ten).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto logOfSelf = rewriter.create(op->getLoc(), resultType, self); - auto constType = RankedTensorType::get({}, resultElemTy); + auto constTenType = RankedTensorType::get( + dyn_cast(ten.getType()).getShape(), resultElemTy); - auto logOfTen = rewriter.create(op->getLoc(), constType, ten); + auto logOfTen = rewriter.create(op->getLoc(), constTenType, ten); auto reciprocalOp = rewriter.create( - op->getLoc(), constType, logOfTen.getResult()); + op->getLoc(), constTenType, logOfTen.getResult()); auto result = rewriter.create( op->getLoc(), resultType, logOfSelf.getResult(), reciprocalOp.getResult(), @@ -8258,6 +8526,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto one = tosa::getConstTensor(rewriter, op, 1.0f, {}, resultElemTy).value(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, one).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto expOp = rewriter.create(op->getLoc(), resultType, self); auto result = rewriter.create(op->getLoc(), resultType, diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index 9dedf457096a..ffbc75ecd5c7 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -351,7 +351,7 @@ std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, // %3 = "tosa.reshape"(%1) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) -> // tensor<8x3xi32> Flatten the input indices tensor to an [W, ND] matrix. 
- auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer( + Value indicesMatrixReshapeOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape)); @@ -378,13 +378,18 @@ std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, if (!flattenedCoeffValue) return std::nullopt; + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), indicesMatrixReshapeOp, + flattenedCoeffValue.value()) + .failed()) + return std::nullopt; + // Multiply the coefficients by the coordinates // %5 = "tosa.mul"(%3, %4) {shift = 0 : i32} : (tensor<8x3xi32>, // tensor<3xi32>) -> tensor<8x3xi32> auto flattenedIndicesMulOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), - indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0); + indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0); // Sum up the products of the coefficients and coordinates // %6 = "tosa.reduce_sum"(%5) {axis = 1 : i64} : (tensor<8x3xi32>) -> @@ -616,7 +621,7 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, // [[0, 1], [0, 2], [0, 3]] -> [[0, 1], [0, 2], [0, 3]] // %11 = "tosa.reshape"(%8) {new_shape = array} : (tensor<3x2xi32>) // -> tensor<3x2xi32> - auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer( + Value indicesMatrixReshapeOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape)); @@ -643,6 +648,11 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, if (!flattenedCoeffValue) return std::nullopt; + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), indicesMatrixReshapeOp, + flattenedCoeffValue.value()) + .failed()) + return std::nullopt; + // Multiply the coefficients by the coordinates. 
// [[0, 1], [0, 2], [0, 3]] X [4, 1] -> [[4*0, 1*1], [4*0, 1*2], [4*0, 1*3]] // %13 = "tosa.mul"(%11, %12) {shift = 0 : i32} : (tensor<3x2xi32>, @@ -650,7 +660,7 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, auto flattenedIndicesMulOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), - indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0); + indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0); // Sum up the products of the coefficients and coordinates // [[4*0 + 1*1], [4*0 + 1*2], [4*0 + 1*3]] = [[1],[2],[3]] @@ -734,10 +744,20 @@ std::optional convertReduceOpCommon( RankedTensorType reduce_type = RankedTensorType::get(shape_vec, reduce_element_type); - auto reduce_op = CreateOpAndInfer(rewriter, op->getLoc(), reduce_type, - val, axis_attr); + Value reduce_op; + if constexpr (std::is_same() || + std::is_same()) { + // Use default NaN Propagation mode "PROPAGATE" for tosa.reduce_min + // and tosa.reduce_max + reduce_op = CreateOpAndInfer( + rewriter, op->getLoc(), reduce_type, val, axis_attr, + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + } else { + reduce_op = CreateOpAndInfer(rewriter, op->getLoc(), reduce_type, + val, axis_attr); + } - val = reduce_op.getResult(); + val = reduce_op; } if (is_quantized) { @@ -973,6 +993,12 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op, if (!input_is_qtype) { Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), val.value(), + div_const) + .failed()) + return std::nullopt; + return CreateOpAndInfer(rewriter, op->getLoc(), output_type, val.value(), div_const, 0) .getResult(); @@ -1021,6 +1047,11 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op, return std::nullopt; } + Value ordValRank0 = ordVal; + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input_value, ordVal) + .failed()) + return std::nullopt; + if (fabs(ordLiteralFloat) < epsilon || fabs(static_cast(ordLiteralInt)) < epsilon) { op->emitOpError("unimplemented: L0 norm"); @@ -1049,9 +1080,17 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op, rewriter, op, output_type, powVal, axes_elems, keep_dims); if (!result) return std::nullopt; - auto reciprocalVal = CreateOpAndInfer( - rewriter, op->getLoc(), ordVal.getType(), ordVal) - .getResult(); + + Value reciprocalVal = + CreateOpAndInfer(rewriter, op->getLoc(), + ordValRank0.getType(), ordValRank0) + .getResult(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), result.value(), + reciprocalVal) + .failed()) + return std::nullopt; + return CreateOpAndInfer(rewriter, op->getLoc(), output_type, result.value(), reciprocalVal) .getResult(); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 1ed360ddae61..3e4e6089389a 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -8,7 +8,8 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project namespace mlir { @@ -301,31 +302,31 @@ std::optional getConstTensor(PatternRewriter 
&rewriter, (src.isF32() && dest.isInteger(8)) || (src.isF32() && dest.isBF16()) || (src.isF32() && dest.isF16()) || - (src.isF32() && dest.isFloat8E4M3()) || - (src.isF32() && dest.isFloat8E5M2()) || + (src.isF32() && isa(dest)) || + (src.isF32() && isa(dest)) || // f16 -> * (src.isF16() && dest.isInteger(32)) || (src.isF16() && dest.isInteger(16)) || (src.isF16() && dest.isInteger(8)) || (src.isF16() && dest.isBF16()) || (src.isF16() && dest.isF32()) || - (src.isF16() && dest.isFloat8E4M3()) || - (src.isF16() && dest.isFloat8E5M2()) || + (src.isF16() && isa(dest)) || + (src.isF16() && isa(dest)) || // bf16 -> * (src.isBF16() && dest.isInteger(32)) || (src.isBF16() && dest.isInteger(16)) || (src.isBF16() && dest.isInteger(8)) || (src.isBF16() && dest.isF32()) || - (src.isBF16() && dest.isFloat8E4M3()) || - (src.isBF16() && dest.isFloat8E5M2()) || + (src.isBF16() && isa(dest)) || + (src.isBF16() && isa(dest)) || // fp8e4m3 -> * - (src.isFloat8E4M3() && dest.isBF16()) || - (src.isFloat8E4M3() && dest.isF32()) || - (src.isFloat8E4M3() && dest.isF16()) || + (isa(src) && dest.isBF16()) || + (isa(src) && dest.isF32()) || + (isa(src) && dest.isF16()) || // fp8e5m2 -> * - (src.isFloat8E5M2() && dest.isBF16()) || - (src.isFloat8E5M2() && dest.isF32()) || - (src.isFloat8E5M2() && dest.isF16())) { + (isa(src) && dest.isBF16()) || + (isa(src) && dest.isF32()) || + (isa(src) && dest.isF16())) { return success(); } // clang-format on @@ -393,6 +394,11 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, auto zeroValue = tosa::getConstTensor(rewriter, op, 0, {}, srcElemTy).value(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto boolType = srcType.clone(rewriter.getIntegerType(1)); auto isNegative = tosa::CreateOpAndInfer( rewriter, op->getLoc(), boolType, zeroValue, src); @@ -488,10 +494,10 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter, } else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) && outputElemTy.isInteger(48)) { accType = mlir::TypeAttr::get(rewriter.getIntegerType(48)); - } else if ((inputElemTy.isFloat8E4M3() && weightElemTy.isFloat8E4M3() && - outputElemTy.isF16()) || - (inputElemTy.isFloat8E5M2() && weightElemTy.isFloat8E5M2() && - outputElemTy.isF16())) { + } else if ((isa(inputElemTy) && + isa(weightElemTy) && outputElemTy.isF16()) || + (isa(inputElemTy) && + isa(weightElemTy) && outputElemTy.isF16())) { accType = mlir::TypeAttr::get(rewriter.getF16Type()); } else { accType = mlir::TypeAttr::get(outputElemTy); @@ -500,17 +506,5 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter, return success(); } -// Temporary function to get TOSA const shape -// TODO: Remove this function when getTosaConstShape is available in -// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h -Value getTosaConstShape(PatternRewriter &rewriter, Location loc, - llvm::ArrayRef shape) { - auto attr = rewriter.getIndexTensorAttr(shape); - auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size()); - mlir::Operation *mlir_op = - rewriter.create(loc, type, attr); - return mlir_op->getResult(0); -} - } // namespace tosa } // namespace mlir diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index c0984efffd9c..7f80e84044df 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -652,13 +652,13 @@ Type 
Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) { return rewriter.getF32Type(); if (isa(inputType)) return rewriter.getF64Type(); - if (inputType.isFloat8E5M2()) + if (isa(inputType)) return rewriter.getF32Type(); - if (inputType.isFloat8E4M3FN()) + if (isa(inputType)) return rewriter.getF32Type(); - if (inputType.isFloat8E5M2FNUZ()) + if (isa(inputType)) return rewriter.getF32Type(); - if (inputType.isFloat8E4M3FNUZ()) + if (isa(inputType)) return rewriter.getF32Type(); if (inputType.isInteger(8)) // this is an intentional deviation from CUDA (which accumulates i8 to i64) diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 2993ae76b547..49a862ac7756 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -45,12 +45,14 @@ func.func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.float 1.000000e-01 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e-01> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_4]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_5]], %[[VAL_1]], %[[VAL_6]] : (tensor, tensor, tensor) -> tensor -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_6]] : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_1]], %[[VAL_4]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_9:.*]] = tosa.select %[[VAL_7]], %[[VAL_1]], %[[VAL_8]] : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.leaky_relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %fp0 = torch.constant.float 1.000000e-01 @@ -157,14 +159,15 @@ func.func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor // CHECK-LABEL: func.func @torch.aten.add$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () 
-> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.add %[[VAL_2]], %[[VAL_6]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_3]], %[[VAL_7]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { %int1 = torch.constant.int 1 @@ -177,14 +180,15 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // CHECK-LABEL: func.func @torch.aten.sub$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_2]], %[[VAL_6]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_3]], %[[VAL_7]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { %int1 = torch.constant.int 1 @@ -227,6 +231,35 @@ func.func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // ----- +// CHECK-LABEL: func.func @torch.aten.rsqrt$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = tosa.rsqrt %[[VAL_1]] : (tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> +// CHECK: } +func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + 
%0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_mean_dim$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_4:.*]] = torch.constant.bool false +// CHECK: %[[VAL_5:.*]] = torch.constant.none +// CHECK: %[[VAL_6:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 0 : i32} : (tensor) -> tensor<1x?x?x?xf32> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x?x?x?xf32>) -> tensor +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<-1.08420217E-19> : tensor}> : () -> tensor +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_7]], %[[VAL_9]] {shift = 0 : i8} : (tensor, tensor<1x1x1xf32>) -> tensor +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor -> !torch.vtensor<[?,?,?],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?],f32> +// CHECK: } func.func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { %dim0 = torch.constant.int 0 %reducedims = torch.prim.ListConstruct %dim0 : (!torch.int) -> !torch.list @@ -262,21 +295,24 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! // ----- // CHECK-LABEL: func.func @test_linalg_vector_norm$basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,151,64],f32>) -> !torch.vtensor<[3,151,1],f32> { -// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[3,151,64],f32> -> tensor<3x151x64xf32> -// CHECK: %[[ARG1:.*]] = torch.constant.float 2.000000e+00 -// CHECK: %[[ARG2:.*]] = torch.constant.int -1 -// CHECK: %[[ARG3:.*]] = torch.constant.bool true -// CHECK: %[[ARG4:.*]] = torch.constant.none -// CHECK: %[[ARG5:.*]] = torch.prim.ListConstruct %[[ARG2]] : (!torch.int) -> !torch.list -// CHECK: %[[ARG6:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[ARG7:.*]] = tosa.abs %[[ARG0_BUILTIN]] : (tensor<3x151x64xf32>) -> tensor<3x151x64xf32> -// CHECK: %[[ARG8:.*]] = tosa.pow %[[ARG7]], %[[ARG6]] : (tensor<3x151x64xf32>, tensor) -> tensor<3x151x64xf32> -// CHECK: %[[ARG9:.*]] = tosa.reduce_sum %[[ARG8]] {axis = 2 : i32} : (tensor<3x151x64xf32>) -> tensor<3x151x1xf32> -// CHECK: %[[ARG10:.*]] = tosa.reciprocal %[[ARG6]] : (tensor) -> tensor -// CHECK: %[[ARG11:.*]] = tosa.pow %[[ARG9]], %[[ARG10]] : (tensor<3x151x1xf32>, tensor) -> tensor<3x151x1xf32> -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ARG11]] : tensor<3x151x1xf32> -> !torch.vtensor<[3,151,1],f32> -// CHECK: return %[[RESULT]] : !torch.vtensor<[3,151,1],f32> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,151,64],f32>) -> !torch.vtensor<[3,151,1],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,151,64],f32> -> tensor<3x151x64xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.float 2.000000e+00 +// CHECK: %[[VAL_3:.*]] = torch.constant.int -1 +// CHECK: %[[VAL_4:.*]] = torch.constant.bool true +// CHECK: %[[VAL_5:.*]] = torch.constant.none +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : 
(!torch.int) -> !torch.list +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.abs %[[VAL_1]] : (tensor<3x151x64xf32>) -> tensor<3x151x64xf32> +// CHECK: %[[VAL_10:.*]] = tosa.pow %[[VAL_9]], %[[VAL_8]] : (tensor<3x151x64xf32>, tensor<1x1x1xf32>) -> tensor<3x151x64xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reduce_sum %[[VAL_10]] {axis = 2 : i32} : (tensor<3x151x64xf32>) -> tensor<3x151x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.reciprocal %[[VAL_7]] : (tensor) -> tensor +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.pow %[[VAL_11]], %[[VAL_13]] : (tensor<3x151x1xf32>, tensor<1x1x1xf32>) -> tensor<3x151x1xf32> +// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<3x151x1xf32> -> !torch.vtensor<[3,151,1],f32> +// CHECK: return %[[VAL_15]] : !torch.vtensor<[3,151,1],f32> +// CHECK: } func.func @test_linalg_vector_norm$basic(%arg0: !torch.vtensor<[3,151,64],f32>) -> (!torch.vtensor<[3,151,1],f32>) { %float2.000000e00 = torch.constant.float 2.000000e+00 %int-1 = torch.constant.int -1 @@ -407,13 +443,14 @@ func.func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !to // ----- // CHECK-LABEL: func.func @torch.aten.pow.Tensor_Scalar$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.pow %[[VAL_1]], %[[VAL_3]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.pow %[[VAL_1]], %[[VAL_4]] : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %fp0 = torch.constant.float 3.123400e+00 @@ -430,10 +467,12 @@ func.func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) // CHECK: %[[VAL_3:.*]] = torch.constant.float 6.432100e+00 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor}> : () -> tensor // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<6.432100e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_4]], %[[VAL_6]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : 
(tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_1]], %[[VAL_7]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_9:.*]] = tosa.sub %[[VAL_6]], %[[VAL_8]] : (tensor<1x1xf32>, tensor) -> tensor +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %other = torch.constant.float 3.123400e+00 @@ -444,19 +483,21 @@ func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to // ----- -// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$float_int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor}> : () -> tensor // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_4]], %[[VAL_6]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> -// CHECK: } -func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_1]], %[[VAL_7]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_9:.*]] = tosa.sub %[[VAL_6]], %[[VAL_8]] : (tensor<1x1xf32>, tensor) -> tensor +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?],f32> +// CHECK: } +func.func @torch.aten.rsub.Scalar$float_int(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %other = torch.constant.float 3.123400e+00 %alpha = torch.constant.int 1 %0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.int -> !torch.vtensor<[?,?],f32> @@ -545,14 +586,19 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !to // CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> // CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> // CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor}> : () -> tensor -// CHECK: %[[VAL_13:.*]] = tosa.sub %[[VAL_1]], %[[VAL_8]] : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_14:.*]] = tosa.add %[[VAL_9]], %[[VAL_12]] : (tensor<4x1xf32>, tensor) -> tensor<4x1xf32> -// CHECK: %[[VAL_15:.*]] = tosa.rsqrt %[[VAL_14]] : (tensor<4x1xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] 
{shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_10]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_18:.*]] = tosa.add %[[VAL_17]], %[[VAL_11]] : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32> -// CHECK: return %[[VAL_19]] : !torch.vtensor<[10,4,3],f32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<4x1xf32>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<4x1xf32>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<4x1xf32>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<4x1xf32>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_18:.*]] = tosa.sub %[[VAL_1]], %[[VAL_13]] : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_19:.*]] = tosa.add %[[VAL_14]], %[[VAL_15]] : (tensor<1x4x1xf32>, tensor<1x1x1xf32>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_20:.*]] = tosa.rsqrt %[[VAL_19]] : (tensor<1x4x1xf32>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_21:.*]] = tosa.mul %[[VAL_18]], %[[VAL_20]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_22:.*]] = tosa.mul %[[VAL_21]], %[[VAL_16]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_23:.*]] = tosa.add %[[VAL_22]], %[[VAL_17]] : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32> +// CHECK: return %[[VAL_24]] : !torch.vtensor<[10,4,3],f32> // CHECK: } func.func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32> ) -> !torch.vtensor<[10,4,3],f32> { %0 = torch.vtensor.literal(dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>) : !torch.vtensor<[4],f32> @@ -608,44 +654,46 @@ func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3 // ----- -// CHECK-LABEL: func.func @forward( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>, -// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,2,3],f32>) -> !torch.vtensor<[5,2,2,3],f32> { -// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,2,2,3],f32> -> tensor<5x2x2x3xf32> -// CHECK-DAG: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32> -// CHECK-DAG: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32> +// CHECK-LABEL: func.func @torch.aten.native_layer_norm$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,2,3],f32>) -> !torch.vtensor<[5,2,2,3],f32> { +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32> +// CHECK: %[[VAL_5:.*]] = 
torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,2,2,3],f32> -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_6:.*]] = torch.constant.float 5.000000e-01 // CHECK: %[[VAL_7:.*]] = torch.constant.int 3 // CHECK: %[[VAL_8:.*]] = torch.constant.int 2 // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_8]], %[[VAL_8]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<1.200000e+01> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_11:.*]] = tosa.reciprocal %[[VAL_10]] : (tensor<1xf32>) -> tensor<1xf32> -// CHECK: %[[VAL_12:.*]] = tosa.reduce_sum %[[VAL_3]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> -// CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_12]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> -// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_15]], %[[VAL_11]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_17:.*]] = tosa.sub %[[VAL_3]], %[[VAL_16]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_18:.*]] = tosa.mul %[[VAL_17]], %[[VAL_17]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_19:.*]] = tosa.reduce_sum %[[VAL_18]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> -// CHECK: %[[VAL_20:.*]] = tosa.reduce_sum %[[VAL_19]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> -// CHECK: %[[VAL_21:.*]] = tosa.reduce_sum %[[VAL_20]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_22:.*]] = tosa.reshape %[[VAL_21]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_23:.*]] = tosa.mul %[[VAL_22]], %[[VAL_11]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> -// CHECK: %[[VAL_25:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> -// CHECK: %[[VAL_26:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor}> : () -> tensor -// CHECK: %[[VAL_27:.*]] = tosa.sub %[[VAL_3]], %[[VAL_16]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_28:.*]] = tosa.add %[[VAL_23]], %[[VAL_26]] : (tensor<5x1x1x1xf32>, tensor) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_29:.*]] = tosa.rsqrt %[[VAL_28]] : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_30:.*]] = tosa.mul %[[VAL_27]], %[[VAL_29]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_31:.*]] = tosa.mul %[[VAL_30]], %[[VAL_24]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_32:.*]] = tosa.add %[[VAL_31]], %[[VAL_25]] : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_33:.*]] = torch_c.from_builtin_tensor %[[VAL_32]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32> -// CHECK: return %[[VAL_33]] : !torch.vtensor<[5,2,2,3],f32> -// CHECK: } -func.func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,3],f32> , %arg2: !torch.vtensor<[2,2,3],f32> ) -> 
!torch.vtensor<[5,2,2,3],f32> { +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_5]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.reduce_sum %[[VAL_14]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_12]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_18:.*]] = tosa.sub %[[VAL_5]], %[[VAL_17]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_19:.*]] = tosa.mul %[[VAL_18]], %[[VAL_18]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reduce_sum %[[VAL_19]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> +// CHECK: %[[VAL_21:.*]] = tosa.reduce_sum %[[VAL_20]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> +// CHECK: %[[VAL_22:.*]] = tosa.reduce_sum %[[VAL_21]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_23:.*]] = tosa.reshape %[[VAL_22]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_24:.*]] = tosa.mul %[[VAL_23]], %[[VAL_12]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_25:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> +// CHECK: %[[VAL_26:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> +// CHECK: %[[VAL_27:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor}> : () -> tensor +// CHECK: %[[VAL_28:.*]] = tosa.reshape %[[VAL_27]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_29:.*]] = tosa.sub %[[VAL_5]], %[[VAL_17]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_30:.*]] = tosa.add %[[VAL_24]], %[[VAL_28]] : (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_31:.*]] = tosa.rsqrt %[[VAL_30]] : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_32:.*]] = tosa.mul %[[VAL_29]], %[[VAL_31]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_33:.*]] = tosa.mul %[[VAL_32]], %[[VAL_25]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_34:.*]] = tosa.add %[[VAL_33]], %[[VAL_26]] : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_35:.*]] = torch_c.from_builtin_tensor %[[VAL_34]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32> +// CHECK: return %[[VAL_35]] : !torch.vtensor<[5,2,2,3],f32> +// CHECK: } +func.func @torch.aten.native_layer_norm$basic(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,3],f32> , %arg2: !torch.vtensor<[2,2,3],f32> ) -> !torch.vtensor<[5,2,2,3],f32> { %float5.000000e-01 = torch.constant.float 5.000000e-01 %int3 = torch.constant.int 3 %int2 = torch.constant.int 2 @@ -1024,19 +1072,21 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vten // ----- // CHECK-LABEL: func.func 
@torch.aten.to.dtype$floatToInt( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> { -// CHECK: %[[TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32> -// CHECK: %[[INT4:.*]] = torch.constant.int 4 -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[FLOOR:.*]] = tosa.floor %[[TENSOR]] : (tensor<3x5xf32>) -> tensor<3x5xf32> -// CHECK: %[[CEIL:.*]] = tosa.ceil %[[TENSOR]] : (tensor<3x5xf32>) -> tensor<3x5xf32> -// CHECK: %[[F0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[IS_NEG:.*]] = tosa.greater %[[F0]], %[[TENSOR]] : (tensor, tensor<3x5xf32>) -> tensor<3x5xi1> -// CHECK: %[[SELECT:.*]] = tosa.select %[[IS_NEG]], %[[CEIL]], %[[FLOOR]] : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32> -// CHECK: %[[CAST:.*]] = tosa.cast %[[SELECT]] : (tensor<3x5xf32>) -> tensor<3x5xi64> -// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<3x5xi64> -> !torch.vtensor<[3,5],si64> -// CHECK: return %[[RES]] : !torch.vtensor<[3,5],si64> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 4 +// CHECK: %[[VAL_3:.*]] = torch.constant.bool false +// CHECK: %[[VAL_4:.*]] = torch.constant.none +// CHECK: %[[VAL_5:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x5xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_6:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x5xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.greater %[[VAL_8]], %[[VAL_1]] : (tensor<1x1xf32>, tensor<3x5xf32>) -> tensor<3x5xi1> +// CHECK: %[[VAL_10:.*]] = tosa.select %[[VAL_9]], %[[VAL_6]], %[[VAL_5]] : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_11:.*]] = tosa.cast %[[VAL_10]] : (tensor<3x5xf32>) -> tensor<3x5xi64> +// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<3x5xi64> -> !torch.vtensor<[3,5],si64> +// CHECK: return %[[VAL_12]] : !torch.vtensor<[3,5],si64> +// CHECK: } func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> { %int4 = torch.constant.int 4 %false = torch.constant.bool false @@ -1049,25 +1099,26 @@ func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> ! 
// CHECK-LABEL: func.func @torch.aten.gather( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],f32> { -// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,4,3],f32> -> tensor<1x4x3xf32> -// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,4,2],si64> -> tensor<1x4x2xi64> +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,4,2],si64> -> tensor<1x4x2xi64> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,4,3],f32> -> tensor<1x4x3xf32> // CHECK: %[[VAL_4:.*]] = torch.constant.int -1 // CHECK: %[[VAL_5:.*]] = torch.constant.bool false -// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_3]] : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32> +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32> // CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x4x2xi32>) -> tensor<1x4x2x1xi32> // CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32> // CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]]]]> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32> // CHECK: %[[VAL_10:.*]] = tosa.concat %[[VAL_8]], %[[VAL_9]], %[[VAL_7]] {axis = 3 : i32} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32> -// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32> // CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32> // CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[12, 3, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i8} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32> -// CHECK: %[[VAL_15:.*]] = tosa.reduce_sum %[[VAL_14]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32> -// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<8x1xi32>) -> tensor<1x8xi32> -// CHECK: %[[VAL_17:.*]] = tosa.gather %[[VAL_11]], %[[VAL_16]] : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32> -// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32> -// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32> -// CHECK: return %[[VAL_19]] : !torch.vtensor<[1,4,2],f32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor<3xi32>) -> tensor<1x3xi32> +// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]] {shift = 0 : i8} : (tensor<8x3xi32>, tensor<1x3xi32>) -> tensor<8x3xi32> +// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<8x1xi32>) -> tensor<1x8xi32> +// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_11]], %[[VAL_17]] : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32> +// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor 
%[[VAL_19]] : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32> +// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,4,2],f32> // CHECK: } func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],f32> { %int-1 = torch.constant.int -1 @@ -1080,15 +1131,16 @@ func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.v // CHECK-LABEL: func.func @torch.aten.add$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,2],si32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2],si32>) -> !torch.vtensor<[2,2],si64> { -// CHECK-DAG- %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32> -// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32> +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32> // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<2x2xi32>, tensor) -> tensor<2x2xi32> -// CHECK: %[[VAL_7:.*]] = tosa.add %[[VAL_2]], %[[VAL_6]] : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> -// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<2x2xi32>) -> tensor<2x2xi64> -// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x2xi64> -> !torch.vtensor<[2,2],si64> -// CHECK: return %[[VAL_9]] : !torch.vtensor<[2,2],si64> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xi32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor<2x2xi32>, tensor<1x1xi32>) -> tensor<2x2xi32> +// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_3]], %[[VAL_7]] : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> +// CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_8]] : (tensor<2x2xi32>) -> tensor<2x2xi64> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<2x2xi64> -> !torch.vtensor<[2,2],si64> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[2,2],si64> // CHECK: } func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torch.vtensor<[2, 2],si32>) -> !torch.vtensor<[2, 2],si64> { %int1 = torch.constant.int 1 @@ -1103,13 +1155,15 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torc // CHECK: %[[VAL_2:.*]] = torch.constant.int 1 // CHECK: %[[VAL_3:.*]] = torch.constant.int 256 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<256> : tensor}> : () -> tensor -// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_4]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32> -// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_6]] : (tensor<1x1x128x128xi32>, tensor) -> tensor<1x1x128x128xi32> -// CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_8]] : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64> -// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64> -// CHECK: return %[[VAL_10]] : !torch.vtensor<[1,1,128,128],si64> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} 
: (tensor) -> tensor<1x1x1x1xi32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xi32> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5]], %[[VAL_7]] {shift = 0 : i8} : (tensor<1x1x1x1xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x1x1xi32> +// CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32> +// CHECK: %[[VAL_10:.*]] = tosa.add %[[VAL_9]], %[[VAL_8]] : (tensor<1x1x128x128xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x128x128xi32> +// CHECK: %[[VAL_11:.*]] = tosa.cast %[[VAL_10]] : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64> +// CHECK: return %[[VAL_12]] : !torch.vtensor<[1,1,128,128],si64> // CHECK: } func.func @torch.aten.Scalar$basic(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { %int1 = torch.constant.int 1 @@ -1211,14 +1265,15 @@ func.func @torch.aten.clamp.float(%arg0: !torch.vtensor<[1,1,128,128],f32>) -> ! // CHECK-LABEL: func.func @torch.aten.masked_fill.Scalar( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,128,128],i1>) -> !torch.vtensor<[1,12,128,128],f32> { -// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32> -// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1> +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32> // CHECK: %[[VAL_4:.*]] = torch.constant.int 0 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_3]], %[[VAL_6]], %[[VAL_2]] : (tensor<1x1x128x128xi1>, tensor, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[1,12,128,128],f32> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_2]], %[[VAL_7]], %[[VAL_3]] : (tensor<1x1x128x128xi1>, tensor<1x1x1x1xf32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[1,12,128,128],f32> // CHECK: } func.func @torch.aten.masked_fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1,1,128,128],i1>) -> !torch.vtensor<[1,12,128,128],f32> { %int0 = torch.constant.int 0 @@ -1231,12 +1286,13 @@ func.func @torch.aten.masked_fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f3 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,128,128],i1>, // CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,128,128],f32> { -// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor 
%[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32> -// CHECK-DAG: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1> -// CHECK-DAG: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_4]], %[[VAL_5]], %[[VAL_3]] : (tensor<1x1x128x128xi1>, tensor, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,12,128,128],f32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1> +// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_4]], %[[VAL_6]], %[[VAL_5]] : (tensor<1x1x128x128xi1>, tensor<1x1x1x1xf32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[1,12,128,128],f32> // CHECK: } func.func @torch.aten.masked_fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1,1,128,128],i1>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,128,128],f32> { %0 = torch.aten.masked_fill.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32> @@ -1261,12 +1317,13 @@ func.func @torch.aten.abs(%arg0: !torch.vtensor<[15,15],si64>) -> !torch.vtensor // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,5,5],i1>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,12,5,5],f32>, // CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,5,5],f32> { -// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,5,5],i1> -> tensor<1x1x5x5xi1> -// CHECK-DAG: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,12,5,5],f32> -> tensor<1x12x5x5xf32> -// CHECK-DAG: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor) -> tensor<1x12x5x5xf32> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,12,5,5],f32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,12,5,5],f32> -> tensor<1x12x5x5xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,5,5],i1> -> tensor<1x1x5x5xi1> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_5]], %[[VAL_4]], %[[VAL_6]] : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x5x5xf32> +// CHECK: %[[VAL_8:.*]] = 
torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[1,12,5,5],f32> // CHECK: } func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !torch.vtensor<[1,12,5,5],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,5,5],f32> { %0 = torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,12,5,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,5,5],f32> @@ -1279,13 +1336,14 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_1]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor) -> tensor<2x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.floor %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_3]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_1]], %[[VAL_7]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> -// CHECK: return %[[VAL_9]] : !torch.vtensor<[2,4],f32> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_4]] : (tensor<1x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<1x1xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.floor %[[VAL_6]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_4]], %[[VAL_7]] {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_9:.*]] = tosa.sub %[[VAL_1]], %[[VAL_8]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[2,4],f32> // CHECK: } func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> { %int2 = torch.constant.int 2 @@ -1295,26 +1353,28 @@ func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !to // ----- -// CHECK-LABEL: func.func @forward( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> { -// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> -// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> +// CHECK-LABEL: func.func @torch.aten.isclose$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> // CHECK: %[[VAL_4:.*]] = torch.constant.float 
1.000000e-08 // CHECK: %[[VAL_5:.*]] = torch.constant.float 1.000000e-05 // CHECK: %[[VAL_6:.*]] = torch.constant.bool false -// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_2]], %[[VAL_3]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_8:.*]] = tosa.abs %[[VAL_7]] : (tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_9:.*]] = tosa.abs %[[VAL_3]] : (tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor}> : () -> tensor -// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], %[[VAL_9]] {shift = 0 : i8} : (tensor, tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<9.99999993E-9> : tensor}> : () -> tensor -// CHECK: %[[VAL_13:.*]] = tosa.add %[[VAL_12]], %[[VAL_11]] : (tensor, tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_14:.*]] = tosa.greater_equal %[[VAL_13]], %[[VAL_8]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1> -// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1> -// CHECK: return %[[VAL_15]] : !torch.vtensor<[5,5],i1> -// CHECK: } -func.func @forward(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> { +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<9.99999993E-9> : tensor}> : () -> tensor +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_11:.*]] = tosa.sub %[[VAL_3]], %[[VAL_2]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_12:.*]] = tosa.abs %[[VAL_11]] : (tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_13:.*]] = tosa.abs %[[VAL_2]] : (tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_9]], %[[VAL_13]] {shift = 0 : i8} : (tensor<1x1xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_15:.*]] = tosa.add %[[VAL_10]], %[[VAL_14]] : (tensor<1x1xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_16:.*]] = tosa.greater_equal %[[VAL_15]], %[[VAL_12]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1> +// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1> +// CHECK: return %[[VAL_17]] : !torch.vtensor<[5,5],i1> +// CHECK: } +func.func @torch.aten.isclose$basic(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> { %float1.000000e-08 = torch.constant.float 1.000000e-08 %float1.000000e-05 = torch.constant.float 1.000000e-05 %false = torch.constant.bool false @@ -1505,13 +1565,16 @@ func.func @torch.aten.all.dim$basic(%arg0: !torch.vtensor<[3,2,3],i1>) -> !torch // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor // CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor // CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_10:.*]] = tosa.greater_equal %[[VAL_6]], %[[VAL_7]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_11:.*]] = tosa.select %[[VAL_10]], %[[VAL_8]], %[[VAL_9]] : (tensor, tensor, tensor) -> tensor -// CHECK: %[[VAL_12:.*]] = tosa.abs %[[VAL_6]] : (tensor) -> tensor -// CHECK: %[[VAL_13:.*]] = tosa.floor %[[VAL_12]] : (tensor) -> 
tensor -// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_13]], %[[VAL_11]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_15]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.greater_equal %[[VAL_6]], %[[VAL_11]] : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_14:.*]] = tosa.select %[[VAL_13]], %[[VAL_10]], %[[VAL_12]] : (tensor, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_15:.*]] = tosa.abs %[[VAL_6]] : (tensor) -> tensor +// CHECK: %[[VAL_16:.*]] = tosa.floor %[[VAL_15]] : (tensor) -> tensor +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_14]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_18]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.div.Tensor_mode$float_trunc(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { %str = torch.constant.str "trunc" @@ -1573,17 +1636,19 @@ func.func @torch.aten.div.Tensor_mode$float_floor(%arg0: !torch.vtensor<[?, ?],f // CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_5]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_11:.*]] = tosa.greater %[[VAL_8]], %[[VAL_10]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_12:.*]] = tosa.mul %[[VAL_7]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_13:.*]] = tosa.equal %[[VAL_12]], %[[VAL_5]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_14:.*]] = tosa.logical_not %[[VAL_13]] : (tensor) -> tensor -// CHECK: %[[VAL_15:.*]] = tosa.sub %[[VAL_7]], %[[VAL_9]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_16:.*]] = tosa.logical_and %[[VAL_11]], %[[VAL_14]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_17:.*]] = tosa.select %[[VAL_16]], %[[VAL_15]], %[[VAL_7]] : (tensor, tensor, tensor) -> tensor -// CHECK: %[[VAL_18:.*]] = tosa.cast %[[VAL_17]] : (tensor) -> tensor -// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor -> !torch.vtensor<[?,?],si64> -// CHECK: return %[[VAL_19]] : !torch.vtensor<[?,?],si64> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor) -> tensor<1x1xi32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1x1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.mul %[[VAL_5]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_13:.*]] = tosa.greater %[[VAL_11]], %[[VAL_12]] : (tensor<1x1xi32>, tensor) -> tensor +// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_7]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_15:.*]] = tosa.equal %[[VAL_14]], %[[VAL_5]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_16:.*]] = tosa.logical_not %[[VAL_15]] : (tensor) -> tensor +// CHECK: %[[VAL_17:.*]] = tosa.sub 
%[[VAL_7]], %[[VAL_10]] : (tensor, tensor<1x1xi32>) -> tensor +// CHECK: %[[VAL_18:.*]] = tosa.logical_and %[[VAL_13]], %[[VAL_16]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_19:.*]] = tosa.select %[[VAL_18]], %[[VAL_17]], %[[VAL_7]] : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_20:.*]] = tosa.cast %[[VAL_19]] : (tensor) -> tensor +// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor -> !torch.vtensor<[?,?],si64> +// CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?],si64> // CHECK: } func.func @torch.aten.div.Tensor_mode$int_floor(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> { %str = torch.constant.str "floor" @@ -1679,15 +1744,18 @@ func.func @torch.aten.remainder.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: // CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor // CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_9:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_6]] : (tensor<2x4xf32>, tensor) -> tensor<2x4xi1> -// CHECK: %[[VAL_10:.*]] = tosa.select %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : (tensor<2x4xi1>, tensor, tensor) -> tensor<2x4xf32> -// CHECK: %[[VAL_11:.*]] = tosa.abs %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_12:.*]] = tosa.floor %[[VAL_11]] : (tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_12]], %[[VAL_10]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_2]], %[[VAL_13]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_15:.*]] = tosa.sub %[[VAL_3]], %[[VAL_14]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> -// CHECK: return %[[VAL_16]] : !torch.vtensor<[2,4],f32> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_10]] : (tensor<2x4xf32>, tensor<1x1xf32>) -> tensor<2x4xi1> +// CHECK: %[[VAL_13:.*]] = tosa.select %[[VAL_12]], %[[VAL_9]], %[[VAL_11]] : (tensor<2x4xi1>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_14:.*]] = tosa.abs %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_15:.*]] = tosa.floor %[[VAL_14]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_15]], %[[VAL_13]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_2]], %[[VAL_16]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_18:.*]] = tosa.sub %[[VAL_3]], %[[VAL_17]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> +// CHECK: return %[[VAL_19]] : !torch.vtensor<[2,4],f32> // CHECK: } func.func @torch.aten.fmod.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> { %0 
= torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32> @@ -1743,9 +1811,10 @@ func.func @torch.aten.sin(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.float 2.000000e+00 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.pow %[[VAL_3]], %[[VAL_1]] : (tensor, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.pow %[[VAL_4]], %[[VAL_1]] : (tensor<1x1xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.pow.Scalar(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { %float2.000000e00 = torch.constant.float 2.000000e+00 @@ -1790,10 +1859,11 @@ func.func @torch.aten.erf$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.bitwise_and %[[VAL_1]], %[[VAL_4]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],si32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],si32> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xi64> +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x1xi64>) -> tensor<1x1xi32> +// CHECK: %[[VAL_6:.*]] = tosa.bitwise_and %[[VAL_1]], %[[VAL_5]] : (tensor, tensor<1x1xi32>) -> tensor +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],si32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],si32> // CHECK: } func.func @torch.aten.bitwise_and.Scalar$basic(%arg0: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { %int2 = torch.constant.int 2 @@ -1824,10 +1894,11 @@ func.func @torch.aten.le.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! 
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.greater_equal %[[VAL_4]], %[[VAL_1]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],i1> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],i1> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xi64> +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x1xi64>) -> tensor<1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_1]] : (tensor<1x1xf32>, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],i1> // CHECK: } func.func @torch.aten.le.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { %int2 = torch.constant.int 2 @@ -1928,13 +1999,14 @@ func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) -> // CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32> // CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32> // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_14]], %[[VAL_15]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32> -// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32> -// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<40x1xi32>) -> tensor<1x40xi32> -// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_13]], %[[VAL_18]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32> -// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32> -// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32> -// CHECK: return %[[VAL_21]] : !torch.vtensor<[4,5,2],f32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<3xi32>) -> tensor<1x3xi32> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_14]], %[[VAL_16]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<1x3xi32>) -> tensor<40x3xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<40x1xi32>) -> tensor<1x40xi32> +// CHECK: %[[VAL_20:.*]] = tosa.gather %[[VAL_13]], %[[VAL_19]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32> +// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32> +// CHECK: return %[[VAL_22]] : !torch.vtensor<[4,5,2],f32> // CHECK: } func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> { %int2 = torch.constant.int 2 @@ -2004,20 +2076,22 @@ func.func @torch.aten.flip(%arg0: 
!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32> // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor}> : () -> tensor // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_1]], %[[VAL_4]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: %[[VAL_6:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_2]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor) -> tensor<3x4x5xf32> -// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_7]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_8]], %[[VAL_3]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor) -> tensor<3x4x5xf32> -// CHECK: %[[VAL_10:.*]] = tosa.equal %[[VAL_4]], %[[VAL_9]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1> -// CHECK: %[[VAL_11:.*]] = tosa.equal %[[VAL_5]], %[[VAL_2]] : (tensor<3x4x5xf32>, tensor) -> tensor<3x4x5xi1> -// CHECK: %[[VAL_12:.*]] = tosa.greater %[[VAL_2]], %[[VAL_5]] : (tensor, tensor<3x4x5xf32>) -> tensor<3x4x5xi1> -// CHECK: %[[VAL_13:.*]] = tosa.logical_and %[[VAL_11]], %[[VAL_10]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1> -// CHECK: %[[VAL_14:.*]] = tosa.logical_or %[[VAL_12]], %[[VAL_13]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1> -// CHECK: %[[VAL_15:.*]] = tosa.select %[[VAL_14]], %[[VAL_4]], %[[VAL_6]] : (tensor<3x4x5xi1>, tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32> -// CHECK: return %[[VAL_16]] : !torch.vtensor<[3,4,5],f32> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_8:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_6]], %[[VAL_4]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<1x1x1xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_10:.*]] = tosa.floor %[[VAL_9]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], %[[VAL_5]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<1x1x1xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_12:.*]] = tosa.equal %[[VAL_6]], %[[VAL_11]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_13:.*]] = tosa.equal %[[VAL_7]], %[[VAL_4]] : (tensor<3x4x5xf32>, tensor<1x1x1xf32>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_14:.*]] = tosa.greater %[[VAL_4]], %[[VAL_7]] : (tensor<1x1x1xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_15:.*]] = tosa.logical_and %[[VAL_13]], %[[VAL_12]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_16:.*]] = tosa.logical_or %[[VAL_14]], %[[VAL_15]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_17:.*]] = tosa.select %[[VAL_16]], %[[VAL_6]], %[[VAL_8]] : (tensor<3x4x5xi1>, 
tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32> +// CHECK: return %[[VAL_18]] : !torch.vtensor<[3,4,5],f32> // CHECK: } func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { %0 = torch.aten.round %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -2109,13 +2183,14 @@ func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> // CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<10x8x6xf32>) -> tensor<1x480x1xf32> // CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<2x4x3x3xi32>) -> tensor<24x3xi32> // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[48, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_14]], %[[VAL_15]] {shift = 0 : i8} : (tensor<24x3xi32>, tensor<3xi32>) -> tensor<24x3xi32> -// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<24x3xi32>) -> tensor<24x1xi32> -// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> -// CHECK: %[[VAL_19:.*]] = tosa.scatter %[[VAL_13]], %[[VAL_18]], %[[VAL_12]] : (tensor<1x480x1xf32>, tensor<1x24xi32>, tensor<1x36x1xf32>) -> tensor<1x480x1xf32> -// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x480x1xf32>) -> tensor<10x8x6xf32> -// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<10x8x6xf32> -> !torch.vtensor<[10,8,6],f32> -// CHECK: return %[[VAL_21]] : !torch.vtensor<[10,8,6],f32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<3xi32>) -> tensor<1x3xi32> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_14]], %[[VAL_16]] {shift = 0 : i8} : (tensor<24x3xi32>, tensor<1x3xi32>) -> tensor<24x3xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<24x3xi32>) -> tensor<24x1xi32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> +// CHECK: %[[VAL_20:.*]] = tosa.scatter %[[VAL_13]], %[[VAL_19]], %[[VAL_12]] : (tensor<1x480x1xf32>, tensor<1x24xi32>, tensor<1x36x1xf32>) -> tensor<1x480x1xf32> +// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x480x1xf32>) -> tensor<10x8x6xf32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<10x8x6xf32> -> !torch.vtensor<[10,8,6],f32> +// CHECK: return %[[VAL_22]] : !torch.vtensor<[10,8,6],f32> // CHECK: } func.func @torch.aten.scatter.src$basic(%arg0: !torch.vtensor<[10,8,6],f32>, %arg1: !torch.vtensor<[2,4,3],si64>, %arg2: !torch.vtensor<[3,4,3],f32>) -> !torch.vtensor<[10,8,6],f32> { %int1 = torch.constant.int 1 @@ -2140,13 +2215,14 @@ func.func @torch.aten.scatter.src$basic(%arg0: !torch.vtensor<[10,8,6],f32>, %ar // CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<6x8xf32>) -> tensor<1x48x1xf32> // CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<6x1x2xi32>) -> tensor<6x2xi32> // CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[8, 1]> : tensor<2xi32>}> : () -> tensor<2xi32> -// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i8} : (tensor<6x2xi32>, tensor<2xi32>) -> tensor<6x2xi32> -// CHECK: %[[VAL_15:.*]] = tosa.reduce_sum %[[VAL_14]] {axis = 1 : i32} : (tensor<6x2xi32>) -> 
tensor<6x1xi32> -// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<6x1xi32>) -> tensor<1x6xi32> -// CHECK: %[[VAL_17:.*]] = tosa.scatter %[[VAL_11]], %[[VAL_16]], %[[VAL_10]] : (tensor<1x48x1xf32>, tensor<1x6xi32>, tensor<1x6x1xf32>) -> tensor<1x48x1xf32> -// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<1x48x1xf32>) -> tensor<6x8xf32> -// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<6x8xf32> -> !torch.vtensor<[6,8],f32> -// CHECK: return %[[VAL_19]] : !torch.vtensor<[6,8],f32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor<2xi32>) -> tensor<1x2xi32> +// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]] {shift = 0 : i8} : (tensor<6x2xi32>, tensor<1x2xi32>) -> tensor<6x2xi32> +// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<6x2xi32>) -> tensor<6x1xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<6x1xi32>) -> tensor<1x6xi32> +// CHECK: %[[VAL_18:.*]] = tosa.scatter %[[VAL_11]], %[[VAL_17]], %[[VAL_10]] : (tensor<1x48x1xf32>, tensor<1x6xi32>, tensor<1x6x1xf32>) -> tensor<1x48x1xf32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<1x48x1xf32>) -> tensor<6x8xf32> +// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<6x8xf32> -> !torch.vtensor<[6,8],f32> +// CHECK: return %[[VAL_20]] : !torch.vtensor<[6,8],f32> // CHECK: } func.func @torch.aten.slice_scatter$basic(%arg0: !torch.vtensor<[6,8],f32>, %arg1: !torch.vtensor<[6,1],f32>) -> !torch.vtensor<[6,8],f32> { %int1 = torch.constant.int 1 @@ -2175,15 +2251,16 @@ func.func @torch.aten.slice_scatter$basic(%arg0: !torch.vtensor<[6,8],f32>, %arg // CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<2x3x4x4xf32>) -> tensor<1x96x1xf32> // CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<2x3x4x1x4xi32>) -> tensor<24x4xi32> // CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<[48, 16, 4, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_15]], %[[VAL_16]] {shift = 0 : i8} : (tensor<24x4xi32>, tensor<4xi32>) -> tensor<24x4xi32> -// CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<24x4xi32>) -> tensor<24x1xi32> -// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> -// CHECK: %[[VAL_20:.*]] = tosa.scatter %[[VAL_14]], %[[VAL_19]], %[[VAL_13]] : (tensor<1x96x1xf32>, tensor<1x24xi32>, tensor<1x24x1xf32>) -> tensor<1x96x1xf32> -// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x96x1xf32>) -> tensor<2x3x4x4xf32> -// CHECK: %[[VAL_22:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_21]], %[[VAL_22]] : (tensor<2x3x4x4xf32>, tensor<4xi32>) -> tensor<2x3x4x4xf32> -// CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<2x3x4x4xf32> -> !torch.vtensor<[2,3,4,4],f32> -// CHECK: return %[[VAL_24]] : !torch.vtensor<[2,3,4,4],f32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<4xi32>) -> tensor<1x4xi32> +// CHECK: %[[VAL_18:.*]] = tosa.mul %[[VAL_15]], %[[VAL_17]] {shift = 0 : i8} : (tensor<24x4xi32>, tensor<1x4xi32>) -> tensor<24x4xi32> +// CHECK: %[[VAL_19:.*]] = tosa.reduce_sum %[[VAL_18]] {axis = 1 : i32} : (tensor<24x4xi32>) -> 
tensor<24x1xi32> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> +// CHECK: %[[VAL_21:.*]] = tosa.scatter %[[VAL_14]], %[[VAL_20]], %[[VAL_13]] : (tensor<1x96x1xf32>, tensor<1x24xi32>, tensor<1x24x1xf32>) -> tensor<1x96x1xf32> +// CHECK: %[[VAL_22:.*]] = tosa.reshape %[[VAL_21]] {new_shape = array} : (tensor<1x96x1xf32>) -> tensor<2x3x4x4xf32> +// CHECK: %[[VAL_23:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_24:.*]] = tosa.transpose %[[VAL_22]], %[[VAL_23]] : (tensor<2x3x4x4xf32>, tensor<4xi32>) -> tensor<2x3x4x4xf32> +// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<2x3x4x4xf32> -> !torch.vtensor<[2,3,4,4],f32> +// CHECK: return %[[VAL_25]] : !torch.vtensor<[2,3,4,4],f32> // CHECK: } func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,4],f32> { %int0 = torch.constant.int 0 @@ -2196,29 +2273,30 @@ func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !t // ----- // CHECK-LABEL: func.func @torch.aten.index.Tensor_hacked_twin( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,2],si64>, -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> { -// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,4,2],si64> -> tensor<2x4x2xi64> -// CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.vtensor<[],si64>) -> !torch.list -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[],si64> -> tensor -// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor) -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor -// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_3]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.greater %[[VAL_4]], %[[VAL_3]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_6]], %[[VAL_3]] : (tensor, tensor, tensor) -> tensor -// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1xi32> -// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array} : (tensor<2x4x2xi64>) -> tensor<1x2x8xi64> -// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1xi32> -// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> -// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_11]], %[[VAL_12]] {shift = 0 : i8} : (tensor<1x1xi32>, tensor<1xi32>) -> tensor<1x1xi32> -// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32> -// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<1x1xi32>) -> tensor<1x1xi32> -// CHECK: %[[VAL_16:.*]] = tosa.gather %[[VAL_10]], %[[VAL_15]] : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64> -// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<1x1x8xi64>) -> tensor<4x2xi64> -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<4x2xi64> -> !torch.vtensor<[4,2],si64> -// CHECK: return %[[RESULT]] : !torch.vtensor<[4,2],si64> - +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4,2],si64>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> { +// CHECK: %[[VAL_2:.*]] = 
torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4,2],si64> -> tensor<2x4x2xi64> +// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]] : (!torch.vtensor<[],si64>) -> !torch.list +// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[],si64> -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_5]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_9:.*]] = tosa.greater %[[VAL_6]], %[[VAL_5]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_10:.*]] = tosa.select %[[VAL_9]], %[[VAL_8]], %[[VAL_5]] : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor) -> tensor<1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<2x4x2xi64>) -> tensor<1x2x8xi64> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i8} : (tensor<1x1xi32>, tensor<1x1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<1x1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_12]], %[[VAL_18]] : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x1x8xi64>) -> tensor<4x2xi64> +// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<4x2xi64> -> !torch.vtensor<[4,2],si64> +// CHECK: return %[[VAL_21]] : !torch.vtensor<[4,2],si64> +// CHECK: } func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> { %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list -> !torch.vtensor<[4,2],si64> @@ -2236,9 +2314,10 @@ func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si6 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor<4xi64>}> : () -> tensor<4xi64> // CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: %[[VAL_7:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_2]] : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi1> -// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_6]], %[[VAL_3]] : (tensor<4xi1>, tensor, tensor<4xi64>) -> tensor<4xi64> -// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<4xi64> -> !torch.vtensor<[4],si64> -// CHECK: return %[[VAL_9]] : !torch.vtensor<[4],si64> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor) -> tensor<1xi64> +// CHECK: %[[VAL_9:.*]] = tosa.select %[[VAL_7]], %[[VAL_8]], %[[VAL_3]] : (tensor<4xi1>, tensor<1xi64>, tensor<4xi64>) -> tensor<4xi64> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<4xi64> -> 
!torch.vtensor<[4],si64> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[4],si64> // CHECK: } func.func @torch.aten.threshold_backward$basic(%arg0: !torch.vtensor<[4],si64>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> { %int1 = torch.constant.int 1 @@ -2313,14 +2392,15 @@ func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch // CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<25xf32>) -> tensor<1x25x1xf32> // CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<9x1xi32>) -> tensor<9x1xi32> // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> -// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<9x1xi32>, tensor<1xi32>) -> tensor<9x1xi32> -// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32> -// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<9x1xi32>) -> tensor<1x9xi32> -// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x25x1xf32>, tensor<1x9xi32>) -> tensor<1x9x1xf32> -// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<1x9x1xf32>) -> tensor<9xf32> -// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<9xf32>) -> tensor<3x3xf32> -// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32> -// CHECK: return %[[VAL_21]] : !torch.vtensor<[3,3],f32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i8} : (tensor<9x1xi32>, tensor<1x1xi32>) -> tensor<9x1xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<9x1xi32>) -> tensor<1x9xi32> +// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_12]], %[[VAL_18]] : (tensor<1x25x1xf32>, tensor<1x9xi32>) -> tensor<1x9x1xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x9x1xf32>) -> tensor<9xf32> +// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<9xf32>) -> tensor<3x3xf32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32> +// CHECK: return %[[VAL_22]] : !torch.vtensor<[3,3],f32> // CHECK: } func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> { %none = torch.constant.none @@ -2414,17 +2494,23 @@ func.func @torch.aten.avg_pool1d$basic(%arg0: !torch.vtensor<[1,512,10],f32>) -> // CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32> // CHECK: %[[VAL_6:.*]] = torch.constant.none // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<3.40282347E+38> : tensor}> : () -> tensor -// CHECK: %[[VAL_8:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_4]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32> -// CHECK: %[[VAL_9:.*]] = tosa.minimum %[[VAL_8]], %[[VAL_7]] : (tensor<3x5xf32>, tensor) -> tensor<3x5xf32> -// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> -// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<-3.40282347E+38> : tensor}> : () -> tensor -// CHECK: %[[VAL_12:.*]] = tosa.maximum 
%[[VAL_5]], %[[VAL_11]] : (tensor<3x5xf32>, tensor) -> tensor<3x5xf32> -// CHECK: %[[VAL_13:.*]] = tosa.minimum %[[VAL_12]], %[[VAL_3]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32> -// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> -// CHECK: %[[VAL_15:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_4]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32> -// CHECK: %[[VAL_16:.*]] = tosa.minimum %[[VAL_15]], %[[VAL_3]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32> -// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> -// CHECK: return %[[VAL_10]], %[[VAL_14]], %[[VAL_17]] : !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_8]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_11:.*]] = tosa.minimum %[[VAL_10]], %[[VAL_9]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<-3.40282347E+38> : tensor}> : () -> tensor +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_14]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_17:.*]] = tosa.minimum %[[VAL_16]], %[[VAL_15]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_21:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_19]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_22:.*]] = tosa.minimum %[[VAL_21]], %[[VAL_20]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_23:.*]] = torch_c.from_builtin_tensor %[[VAL_22]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: return %[[VAL_12]], %[[VAL_18]], %[[VAL_23]] : !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32> // CHECK: } func.func @torch.aten.clamp.Tensor$basic(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> (!torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>) { %none = torch.constant.none @@ -2639,14 +2725,15 @@ func.func @torch.prims.split_dim$basic(%arg0: !torch.vtensor<[1,8,3,3],si64>) -> // CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<1x1x6xf64>) -> tensor<1x6x1xf64> // CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<1x1x72x3xi32>) -> tensor<72x3xi32> // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[6, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_14]], %[[VAL_15]] {shift = 0 : i8} : 
(tensor<72x3xi32>, tensor<3xi32>) -> tensor<72x3xi32> -// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<72x3xi32>) -> tensor<72x1xi32> -// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<72x1xi32>) -> tensor<1x72xi32> -// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_13]], %[[VAL_18]] : (tensor<1x6x1xf64>, tensor<1x72xi32>) -> tensor<1x72x1xf64> -// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x72x1xf64>) -> tensor<1x1x72xf64> -// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x1x72xf64>) -> tensor<1x1x8x9xf64> -// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x1x8x9xf64> -> !torch.vtensor<[1,1,8,9],f64> -// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,1,8,9],f64> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<3xi32>) -> tensor<1x3xi32> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_14]], %[[VAL_16]] {shift = 0 : i8} : (tensor<72x3xi32>, tensor<1x3xi32>) -> tensor<72x3xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<72x3xi32>) -> tensor<72x1xi32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<72x1xi32>) -> tensor<1x72xi32> +// CHECK: %[[VAL_20:.*]] = tosa.gather %[[VAL_13]], %[[VAL_19]] : (tensor<1x6x1xf64>, tensor<1x72xi32>) -> tensor<1x72x1xf64> +// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x72x1xf64>) -> tensor<1x1x72xf64> +// CHECK: %[[VAL_22:.*]] = tosa.reshape %[[VAL_21]] {new_shape = array} : (tensor<1x1x72xf64>) -> tensor<1x1x8x9xf64> +// CHECK: %[[VAL_23:.*]] = torch_c.from_builtin_tensor %[[VAL_22]] : tensor<1x1x8x9xf64> -> !torch.vtensor<[1,1,8,9],f64> +// CHECK: return %[[VAL_23]] : !torch.vtensor<[1,1,8,9],f64> // CHECK: } func.func @torch.aten.upsample_nearest2d$basic(%arg0: !torch.vtensor<[1,1,2,3],f64>) -> !torch.vtensor<[1,1,8,9],f64> { %float4.000000e00 = torch.constant.float 4.000000e+00 @@ -2676,14 +2763,15 @@ func.func @torch.aten.upsample_nearest2d$basic(%arg0: !torch.vtensor<[1,1,2,3],f // CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x1x20xf32>) -> tensor<1x20x1xf32> // CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<1x1x14x3xi32>) -> tensor<14x3xi32> // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[20, 20, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<14x3xi32>, tensor<3xi32>) -> tensor<14x3xi32> -// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<14x3xi32>) -> tensor<14x1xi32> -// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<14x1xi32>) -> tensor<1x14xi32> -// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x20x1xf32>, tensor<1x14xi32>) -> tensor<1x14x1xf32> -// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<1x14x1xf32>) -> tensor<1x1x14xf32> -// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x1x14xf32>) -> tensor<1x1x2x7xf32> -// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<1x1x2x7xf32> -> !torch.vtensor<[1,1,2,7],f32> -// CHECK: return %[[VAL_21]] : !torch.vtensor<[1,1,2,7],f32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<3xi32>) -> tensor<1x3xi32> +// CHECK: %[[VAL_16:.*]] = tosa.mul 
%[[VAL_13]], %[[VAL_15]] {shift = 0 : i8} : (tensor<14x3xi32>, tensor<1x3xi32>) -> tensor<14x3xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<14x3xi32>) -> tensor<14x1xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<14x1xi32>) -> tensor<1x14xi32> +// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_12]], %[[VAL_18]] : (tensor<1x20x1xf32>, tensor<1x14xi32>) -> tensor<1x14x1xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x14x1xf32>) -> tensor<1x1x14xf32> +// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x1x14xf32>) -> tensor<1x1x2x7xf32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x1x2x7xf32> -> !torch.vtensor<[1,1,2,7],f32> +// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,1,2,7],f32> // CHECK: } func.func @torch.aten.upsample_nearest2d.vec$basic(%arg0: !torch.vtensor<[1,1,4,5],f32>) -> !torch.vtensor<[1,1,2,7],f32> { %none = torch.constant.none @@ -2744,12 +2832,13 @@ func.func @torch.aten.exp$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtens // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+01> : tensor}> : () -> tensor -// CHECK: %[[VAL_3:.*]] = tosa.log %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_2]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_4]] : (tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor) -> tensor<3x4xf32> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_3]] : (tensor<1x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.reciprocal %[[VAL_5]] : (tensor<1x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_6]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.log10$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { %0 = torch.aten.log10 %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> @@ -2763,12 +2852,13 @@ func.func @torch.aten.log10$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vt // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> // CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+01> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_3]] : (tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.reciprocal %[[VAL_5]] : (tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_6]] {shift = 0 : i8} : (tensor<3x4xf32>, 
tensor<f32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
-// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,4],f32>
+// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
+// CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.log %[[VAL_4]] : (tensor<1x1xf32>) -> tensor<1x1xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.reciprocal %[[VAL_6]] : (tensor<1x1xf32>) -> tensor<1x1xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5]], %[[VAL_7]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_9]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func.func @torch.aten.log10$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> {
%0 = torch.aten.log10 %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32>
@@ -2781,10 +2871,11 @@ func.func @torch.aten.log10$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vte
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
-// CHECK: %[[VAL_3:.*]] = tosa.add %[[VAL_1]], %[[VAL_2]] : (tensor<3x4xf32>, tensor<f32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_3]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
-// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
+// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
+// CHECK: %[[VAL_4:.*]] = tosa.add %[[VAL_1]], %[[VAL_3]] : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func.func @torch.aten.log1p$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
%0 = torch.aten.log1p %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
@@ -2798,10 +2889,11 @@ func.func @torch.aten.log1p$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vt
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32>
// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
-// CHECK: %[[VAL_4:.*]] = tosa.add %[[VAL_2]], %[[VAL_3]] : (tensor<3x4xf32>, tensor<f32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
-// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
+// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
+// CHECK: %[[VAL_5:.*]] = tosa.add %[[VAL_2]], %[[VAL_4]] : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.log %[[VAL_5]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func.func @torch.aten.log1p$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> {
%0 = torch.aten.log1p %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32>
@@ -2816,12 +2908,13 @@ func.func @torch.aten.log1p$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vte
// CHECK: %[[VAL_2:.*]] = torch.constant.float 9.9999999999999995E-8
// CHECK: %[[VAL_3:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0.99999988 : f32, max_int = 0 : i64, min_fp = 1.000000e-07 : f32, min_int = 0 : i64} : (tensor<3x4xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
-// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_4]], %[[VAL_3]] : (tensor<f32>, tensor<3x4xf32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_6:.*]] = tosa.reciprocal %[[VAL_5]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_3]], %[[VAL_6]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_8:.*]] = tosa.log %[[VAL_7]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
-// CHECK: return %[[VAL_9]] : !torch.vtensor<[3,4],f32>
+// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.sub %[[VAL_5]], %[[VAL_3]] : (tensor<1x1xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.reciprocal %[[VAL_6]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_3]], %[[VAL_7]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_9:.*]] = tosa.log %[[VAL_8]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_10]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func.func @torch.aten.logit$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
%float9.999990e-08 = torch.constant.float 9.9999999999999995E-8
@@ -2838,12 +2931,13 @@ func.func @torch.aten.logit$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vt
// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_3]] {max_fp = 0.99999988 : f32, max_int = 0 : i64, min_fp = 1.000000e-07 : f32, min_int = 0 : i64} : (tensor<3x4xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
-// CHECK: %[[VAL_6:.*]] = tosa.sub %[[VAL_5]], %[[VAL_4]] : (tensor<f32>, tensor<3x4xf32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_7:.*]] = tosa.reciprocal %[[VAL_6]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_4]], %[[VAL_7]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_9:.*]] = tosa.log %[[VAL_8]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
-// CHECK: return %[[VAL_10]] : !torch.vtensor<[3,4],f32>
+// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_6]], %[[VAL_4]] : (tensor<1x1xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.reciprocal %[[VAL_7]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_4]], %[[VAL_8]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_10:.*]] = tosa.log %[[VAL_9]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_11]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func.func @torch.aten.logit$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> {
%float9.999990e-08 = torch.constant.float 9.9999999999999995E-8
@@ -2907,10 +3001,11 @@ func.func @torch.aten.erf$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtens
// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.100000e+00
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.100000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.1000000238418579> : tensor<f64>}> : () -> tensor<f64>
-// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_1]] : (tensor<4xi64>) -> tensor<4xf64>
-// CHECK: %[[VAL_6:.*]] = tosa.greater %[[VAL_4]], %[[VAL_5]] : (tensor<f64>, tensor<4xf64>) -> tensor<4xi1>
-// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4xi1> -> !torch.vtensor<[4],i1>
-// CHECK: return %[[VAL_7]] : !torch.vtensor<[4],i1>
+// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1>} : (tensor<f64>) -> tensor<1xf64>
+// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_1]] : (tensor<4xi64>) -> tensor<4xf64>
+// CHECK: %[[VAL_7:.*]] = tosa.greater %[[VAL_5]], %[[VAL_6]] : (tensor<1xf64>, tensor<4xf64>) -> tensor<4xi1>
+// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<4xi1> -> !torch.vtensor<[4],i1>
+// CHECK: return %[[VAL_8]] : !torch.vtensor<[4],i1>
// CHECK: }
func.func @torch.aten.lt.Scalar$intfloat(%arg0: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1> {
%float1.100000e00 = torch.constant.float 1.100000e+00
@@ -3014,16 +3109,17 @@ func.func @torch.aten.pow.Tensor_Tensor$intfloat(%arg0: !torch.vtensor<[3,4,5],s
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 24, 1>} : (tensor<6x4xf32>) -> tensor<1x24x1xf32>
// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 24, 2>} : (tensor<6x4x2xi32>) -> tensor<24x2xi32>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[4, 1]> : tensor<2xi32>}> : () -> tensor<2xi32>
-// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_9]], %[[VAL_10]] {shift = 0 : i8} : (tensor<24x2xi32>, tensor<2xi32>) -> tensor<24x2xi32>
-// CHECK: %[[VAL_12:.*]] = tosa.reduce_sum %[[VAL_11]] {axis = 1 : i32} : (tensor<24x2xi32>) -> tensor<24x1xi32>
-// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array<i64: 1, 24>} : (tensor<24x1xi32>) -> tensor<1x24xi32>
-// CHECK: %[[VAL_14:.*]] = tosa.gather %[[VAL_8]], %[[VAL_13]] : (tensor<1x24x1xf32>, tensor<1x24xi32>) -> tensor<1x24x1xf32>
-// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 6, 4>} : (tensor<1x24x1xf32>) -> tensor<6x4xf32>
-// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array<i64: 3, 2, 4>} : (tensor<6x4xf32>) -> tensor<3x2x4xf32>
-// CHECK: %[[VAL_17:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_16]], %[[VAL_17]] : (tensor<3x2x4xf32>, tensor<3xi32>) -> tensor<3x4x2xf32>
-// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<3x4x2xf32> -> !torch.vtensor<[3,4,2],f32>
-// CHECK: return %[[VAL_19]] :
!torch.vtensor<[3,4,2],f32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<2xi32>) -> tensor<1x2xi32> +// CHECK: %[[VAL_12:.*]] = tosa.mul %[[VAL_9]], %[[VAL_11]] {shift = 0 : i8} : (tensor<24x2xi32>, tensor<1x2xi32>) -> tensor<24x2xi32> +// CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_12]] {axis = 1 : i32} : (tensor<24x2xi32>) -> tensor<24x1xi32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> +// CHECK: %[[VAL_15:.*]] = tosa.gather %[[VAL_8]], %[[VAL_14]] : (tensor<1x24x1xf32>, tensor<1x24xi32>) -> tensor<1x24x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<1x24x1xf32>) -> tensor<6x4xf32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<6x4xf32>) -> tensor<3x2x4xf32> +// CHECK: %[[VAL_18:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_19:.*]] = tosa.transpose %[[VAL_17]], %[[VAL_18]] : (tensor<3x2x4xf32>, tensor<3xi32>) -> tensor<3x4x2xf32> +// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<3x4x2xf32> -> !torch.vtensor<[3,4,2],f32> +// CHECK: return %[[VAL_20]] : !torch.vtensor<[3,4,2],f32> // CHECK: } func.func @torch.aten.unfold$basic(%arg0: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> { %int0 = torch.constant.int 0 @@ -3056,10 +3152,11 @@ func.func @torch.aten.unfold$rank_zero(%arg0: !torch.vtensor<[],f32>) -> !torch. // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_3:.*]] = tosa.exp %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_4:.*]] = tosa.sub %[[VAL_3]], %[[VAL_2]] : (tensor<3x4xf32>, tensor) -> tensor<3x4xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_4:.*]] = tosa.exp %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_4]], %[[VAL_3]] : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.expm1$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { %0 = torch.aten.expm1 %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> @@ -3073,10 +3170,11 @@ func.func @torch.aten.expm1$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vt // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> // CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.exp %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_4]], %[[VAL_3]] : (tensor<3x4xf32>, tensor) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// 
CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
+// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
+// CHECK: %[[VAL_5:.*]] = tosa.exp %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.sub %[[VAL_5]], %[[VAL_4]] : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func.func @torch.aten.expm1$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> {
%0 = torch.aten.expm1 %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32>
@@ -3092,9 +3190,9 @@ func.func @torch.aten.expm1$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vte
// CHECK: %[[VAL_3:.*]] = torch.constant.int 0
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]> : tensor<12xi64>}> : () -> tensor<12xi64>
+// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]> : tensor<12xindex>} : () -> !tosa.shape<12>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0xFF800000> : tensor<f32>}> : () -> tensor<f32>
-// CHECK: %[[VAL_8:.*]] = tosa.pad %[[VAL_1]], %[[VAL_6]], %[[VAL_7]] : (tensor<1x1x20x20x4x4xf32>, tensor<12xi64>, tensor<f32>) -> tensor<1x1x20x20x4x5xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.pad %[[VAL_1]], %[[VAL_6]], %[[VAL_7]] : (tensor<1x1x20x20x4x4xf32>, !tosa.shape<12>, tensor<f32>) -> tensor<1x1x20x20x4x5xf32>
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<1x1x20x20x4x5xf32> -> !torch.vtensor<[1,1,20,20,4,5],f32>
// CHECK: return %[[VAL_9]] : !torch.vtensor<[1,1,20,20,4,5],f32>
// CHECK: }