Update llvm-project to 5d6d982df61d16b6d498e6d59dd91c059679d3d8
Update stablehlo to b62dc66da9946b4c400c0d99c9d5bb8e04edaee6

Co-authored-by: Justin Ngo <[email protected]>

---------

Signed-off-by: Justin Ngo <[email protected]>
Signed-off-by: Praveen G <[email protected]>
Co-authored-by: Justin Ngo <[email protected]>
praveen-g-ctt and justin-ngo-arm authored Jan 31, 2025
1 parent 4b9b972 commit 1690320
Showing 8 changed files with 938 additions and 540 deletions.
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 4066 files
2 changes: 1 addition & 1 deletion externals/stablehlo
Submodule stablehlo updated 52 files
+1 −1 BUILD.bazel
+2 −2 WORKSPACE.bazel
+1 −1 build_tools/llvm_version.txt
+17 −1 docs/awesome.md
+3 −16 stablehlo/conversions/linalg/transforms/Rewriters.h
+39 −38 stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
+23 −0 stablehlo/dialect/AssemblyFormat.cpp
+59 −0 stablehlo/dialect/AssemblyFormat.h
+17 −0 stablehlo/dialect/Base.cpp
+3 −0 stablehlo/dialect/Base.h
+17 −0 stablehlo/dialect/Base.td
+1 −1 stablehlo/dialect/CMakeLists.txt
+15 −0 stablehlo/dialect/StablehloAttrs.td
+79 −7 stablehlo/dialect/StablehloBytecode.cpp
+23 −0 stablehlo/dialect/StablehloEnums.td
+38 −0 stablehlo/dialect/StablehloOps.cpp
+19 −2 stablehlo/dialect/StablehloOps.td
+29 −4 stablehlo/dialect/TypeInference.cpp
+9 −0 stablehlo/dialect/TypeInference.h
+3 −3 stablehlo/dialect/Version.cpp
+1 −1 stablehlo/dialect/Version.h
+24 −11 stablehlo/dialect/VhloAttrs.td
+74 −1 stablehlo/dialect/VhloBytecode.cpp
+1 −0 stablehlo/dialect/VhloDialect.td
+33 −1 stablehlo/dialect/VhloEnums.td
+9 −8 stablehlo/dialect/VhloOps.cpp
+8 −1 stablehlo/dialect/VhloOps.td
+67 −0 stablehlo/integrations/c/StablehloAttributes.cpp
+37 −0 stablehlo/integrations/c/StablehloAttributes.h
+44 −0 stablehlo/integrations/python/StablehloModule.cpp
+21 −0 stablehlo/integrations/python/tests/stablehlo.py
+6 −7 stablehlo/reference/Types.cpp
+40 −0 stablehlo/tests/ops_stablehlo.mlir
+63 −0 stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir
+5 −0 stablehlo/tests/ops_stablehlo_roundtrip.mlir
+13 −0 stablehlo/tests/print_stablehlo.mlir
+11 −0 stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
+2,966 −0 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir.bc
+31 −1 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir
+26 −0 stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir
+24 −0 stablehlo/tests/vhlo/vhlo_to_version_downgrade.1_8_0.mlir
+22 −0 stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_8_0.mlir
+1 −1 stablehlo/transforms/MapStablehloToVhlo.h
+3 −3 stablehlo/transforms/PassUtils.h
+5 −0 stablehlo/transforms/Passes.h
+20 −2 stablehlo/transforms/StablehloAggressiveSimplification.cpp
+6 −3 stablehlo/transforms/StablehloComplexMathExpanderPatterns.td
+24 −0 stablehlo/transforms/StablehloLegalizeToVhlo.cpp
+24 −0 stablehlo/transforms/VhloLegalizeToStablehlo.cpp
+53 −0 stablehlo/transforms/VhloToVersion.cpp
+16 −0 stablehlo/transforms/VhloToVersionPatterns.td
5 changes: 0 additions & 5 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -127,11 +127,6 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
RankedTensorType weightTy,
RankedTensorType outputTy, TypeAttr &accType);

// Temporary function to get TOSA const shape
// TODO: Remove this function when getTosaConstShape is available in
// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
llvm::ArrayRef<int64_t> shape);
} // namespace tosa
} // namespace mlir

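Context for the deletion above: per the TODO being removed, this helper was a temporary copy kept only until getTosaConstShape landed upstream in mlir/Dialect/Tosa/Utils/ConversionUtils.h, which this llvm-project bump provides (the matching definition is deleted from TosaLegalizeUtils.cpp below, and that file gains the ConversionUtils.h include). Assuming the upstream helper keeps the same signature as the removed declaration, call sites need no change beyond the include — a sketch:

#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

// Sketch (assumes the upstream signature matches the removed declaration):
// materialize a !tosa.shape constant from a static shape via the upstream
// helper instead of the deleted local copy.
Value shapeValue =
    mlir::tosa::getTosaConstShape(rewriter, op->getLoc(), /*shape=*/{8, 3});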
518 changes: 395 additions & 123 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp

Large diffs are not rendered by default.

59 changes: 49 additions & 10 deletions lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
@@ -351,7 +351,7 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,

// %3 = "tosa.reshape"(%1) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) ->
// tensor<8x3xi32> Flatten the input indices tensor to an [W, ND] matrix.
auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
Value indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape));
@@ -378,13 +378,18 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
if (!flattenedCoeffValue)
return std::nullopt;

if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), indicesMatrixReshapeOp,
flattenedCoeffValue.value())
.failed())
return std::nullopt;

// Multiply the coefficients by the coordinates
// %5 = "tosa.mul"(%3, %4) {shift = 0 : i32} : (tensor<8x3xi32>,
// tensor<3xi32>) -> tensor<8x3xi32>
auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0);
indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0);

// Sum up the products of the coefficients and coordinates
// %6 = "tosa.reduce_sum"(%5) {axis = 1 : i64} : (tensor<8x3xi32>) ->
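The EqualizeRanks call added above addresses a TOSA constraint: elementwise binary ops such as tosa.mul expect operands of equal rank, so the rank-1 coefficient vector must gain leading unit dimensions before it can be multiplied with the rank-2 index matrix. A minimal sketch of the idiom, assuming EqualizeRanks takes its operands by reference and reshapes the lower-ranked one:

// Sketch: reconcile operand ranks before building an elementwise TOSA op.
// EqualizeRanks is assumed to rewrite the lower-ranked Value in place,
// e.g. tensor<3xi32> -> tensor<1x3xi32> alongside a tensor<8x3xi32> operand.
Value lhs = indicesMatrixReshapeOp;      // rank 2, e.g. tensor<8x3xi32>
Value rhs = flattenedCoeffValue.value(); // rank 1, e.g. tensor<3xi32>
if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, rhs).failed())
  return std::nullopt; // abort this lowering pattern
// lhs and rhs now agree in rank and can feed tosa.mul directly.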
@@ -616,7 +621,7 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
// [[0, 1], [0, 2], [0, 3]] -> [[0, 1], [0, 2], [0, 3]]
// %11 = "tosa.reshape"(%8) {new_shape = array<i64: 3, 2>} : (tensor<3x2xi32>)
// -> tensor<3x2xi32>
auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
Value indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape));
@@ -643,14 +648,19 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
if (!flattenedCoeffValue)
return std::nullopt;

if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), indicesMatrixReshapeOp,
flattenedCoeffValue.value())
.failed())
return std::nullopt;

// Multiply the coefficients by the coordinates.
// [[0, 1], [0, 2], [0, 3]] X [4, 1] -> [[4*0, 1*1], [4*0, 1*2], [4*0, 1*3]]
// %13 = "tosa.mul"(%11, %12) {shift = 0 : i32} : (tensor<3x2xi32>,
// tensor<2xi32>) -> tensor<3x2xi32>
auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0);
indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0);

// Sum up the products of the coefficients and coordinates
// [[4*0 + 1*1], [4*0 + 1*2], [4*0 + 1*3]] = [[1],[2],[3]]
@@ -734,10 +744,20 @@ std::optional<Value> convertReduceOpCommon(
RankedTensorType reduce_type =
RankedTensorType::get(shape_vec, reduce_element_type);

auto reduce_op = CreateOpAndInfer<T>(rewriter, op->getLoc(), reduce_type,
val, axis_attr);
Value reduce_op;
if constexpr (std::is_same<T, tosa::ReduceMinOp>() ||
std::is_same<T, tosa::ReduceMaxOp>()) {
// Use default NaN Propagation mode "PROPAGATE" for tosa.reduce_min
// and tosa.reduce_max
reduce_op = CreateOpAndInfer<T>(
rewriter, op->getLoc(), reduce_type, val, axis_attr,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
} else {
reduce_op = CreateOpAndInfer<T>(rewriter, op->getLoc(), reduce_type,
val, axis_attr);
}

val = reduce_op.getResult();
val = reduce_op;
}

if (is_quantized) {
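The branching above is needed because, in the updated TOSA dialect, tosa.reduce_min and tosa.reduce_max carry a NaN-propagation attribute that the other reduction ops lack, so the extra builder argument may only be supplied for those two op types; since T is a template parameter, if constexpr resolves the choice at compile time. An isolated sketch of the same dispatch (buildReduction is an illustrative wrapper, not part of this patch):

// Sketch: select the builder arity at compile time from the op type T.
// Only tosa.reduce_min / tosa.reduce_max accept the nan_mode attribute.
template <typename T>
Value buildReduction(PatternRewriter &rewriter, Location loc,
                     RankedTensorType type, Value val, IntegerAttr axisAttr) {
  if constexpr (std::is_same_v<T, tosa::ReduceMinOp> ||
                std::is_same_v<T, tosa::ReduceMaxOp>) {
    return CreateOpAndInfer<T>(rewriter, loc, type, val, axisAttr,
                               /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
  } else {
    return CreateOpAndInfer<T>(rewriter, loc, type, val, axisAttr);
  }
}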
@@ -973,6 +993,12 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,

if (!input_is_qtype) {
Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);

if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), val.value(),
div_const)
.failed())
return std::nullopt;

return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
val.value(), div_const, 0)
.getResult();
@@ -1021,6 +1047,11 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
return std::nullopt;
}

Value ordValRank0 = ordVal;
if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input_value, ordVal)
.failed())
return std::nullopt;

if (fabs(ordLiteralFloat) < epsilon ||
fabs(static_cast<double>(ordLiteralInt)) < epsilon) {
op->emitOpError("unimplemented: L0 norm");
@@ -1049,9 +1080,17 @@
rewriter, op, output_type, powVal, axes_elems, keep_dims);
if (!result)
return std::nullopt;
auto reciprocalVal = CreateOpAndInfer<tosa::ReciprocalOp>(
rewriter, op->getLoc(), ordVal.getType(), ordVal)
.getResult();

Value reciprocalVal =
CreateOpAndInfer<tosa::ReciprocalOp>(rewriter, op->getLoc(),
ordValRank0.getType(), ordValRank0)
.getResult();

if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), result.value(),
reciprocalVal)
.failed())
return std::nullopt;

return CreateOpAndInfer<tosa::PowOp>(rewriter, op->getLoc(), output_type,
result.value(), reciprocalVal)
.getResult();
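A subtlety in this hunk: the rank-0 ordVal is copied into ordValRank0 before the first EqualizeRanks call, which suggests EqualizeRanks mutates its Value arguments in place, replacing the lower-ranked one with a reshaped value. The reciprocal exponent 1/ord is then built from the preserved rank-0 copy and rank-equalized against the reduction result only just before the final tosa.pow. A sketch of the save-then-equalize idiom under that assumption:

// Sketch: keep a rank-0 copy before EqualizeRanks rewrites ordVal
// (assumes EqualizeRanks mutates its Value& arguments, as the copy implies).
Value ordValRank0 = ordVal;              // still rank 0 afterwards
if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input_value, ordVal)
        .failed())
  return std::nullopt;                   // ordVal may now be higher-rank
// ordValRank0 later feeds tosa.reciprocal to form the 1/ord exponent.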
52 changes: 23 additions & 29 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -8,7 +8,8 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project

namespace mlir {
@@ -301,31 +302,31 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
(src.isF32() && dest.isInteger(8)) ||
(src.isF32() && dest.isBF16()) ||
(src.isF32() && dest.isF16()) ||
(src.isF32() && dest.isFloat8E4M3()) ||
(src.isF32() && dest.isFloat8E5M2()) ||
(src.isF32() && isa<Float8E4M3Type>(dest)) ||
(src.isF32() && isa<Float8E5M2Type>(dest)) ||
// f16 -> *
(src.isF16() && dest.isInteger(32)) ||
(src.isF16() && dest.isInteger(16)) ||
(src.isF16() && dest.isInteger(8)) ||
(src.isF16() && dest.isBF16()) ||
(src.isF16() && dest.isF32()) ||
(src.isF16() && dest.isFloat8E4M3()) ||
(src.isF16() && dest.isFloat8E5M2()) ||
(src.isF16() && isa<Float8E4M3Type>(dest)) ||
(src.isF16() && isa<Float8E5M2Type>(dest)) ||
// bf16 -> *
(src.isBF16() && dest.isInteger(32)) ||
(src.isBF16() && dest.isInteger(16)) ||
(src.isBF16() && dest.isInteger(8)) ||
(src.isBF16() && dest.isF32()) ||
(src.isBF16() && dest.isFloat8E4M3()) ||
(src.isBF16() && dest.isFloat8E5M2()) ||
(src.isBF16() && isa<Float8E4M3Type>(dest)) ||
(src.isBF16() && isa<Float8E5M2Type>(dest)) ||
// fp8e4m3 -> *
(src.isFloat8E4M3() && dest.isBF16()) ||
(src.isFloat8E4M3() && dest.isF32()) ||
(src.isFloat8E4M3() && dest.isF16()) ||
(isa<Float8E4M3Type>(src) && dest.isBF16()) ||
(isa<Float8E4M3Type>(src) && dest.isF32()) ||
(isa<Float8E4M3Type>(src) && dest.isF16()) ||
// fp8e5m2 -> *
(src.isFloat8E5M2() && dest.isBF16()) ||
(src.isFloat8E5M2() && dest.isF32()) ||
(src.isFloat8E5M2() && dest.isF16())) {
(isa<Float8E5M2Type>(src) && dest.isBF16()) ||
(isa<Float8E5M2Type>(src) && dest.isF32()) ||
(isa<Float8E5M2Type>(src) && dest.isF16())) {
return success();
}
// clang-format on
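The mechanical rewrites in this hunk track an upstream MLIR API change pulled in by the submodule bump: the Type::isFloat8E4M3()-style member helpers are gone, and float-8 variants are now detected with plain isa<> queries against the concrete type classes (the same change appears in getConvOpsAccType below and in Utils.cpp). The idiom, sketched with an illustrative helper that is not part of the patch:

#include "mlir/IR/BuiltinTypes.h"

// Sketch: float-8 detection after the API change.
//   old: elemTy.isFloat8E4M3() || elemTy.isFloat8E5M2()
//   new: isa-based queries on the concrete builtin type classes.
static bool isFp8(mlir::Type elemTy) {
  return mlir::isa<mlir::Float8E4M3Type, mlir::Float8E5M2Type>(elemTy);
}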
@@ -393,6 +394,11 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
auto zeroValue =
tosa::getConstTensor<float>(rewriter, op, 0, {}, srcElemTy).value();

if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue)
.failed())
return rewriter.notifyMatchFailure(
op, "Failed to equalize ranks among operands and result");

auto boolType = srcType.clone(rewriter.getIntegerType(1));
auto isNegative = tosa::CreateOpAndInfer<tosa::GreaterOp>(
rewriter, op->getLoc(), boolType, zeroValue, src);
@@ -488,10 +494,10 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
} else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) &&
outputElemTy.isInteger(48)) {
accType = mlir::TypeAttr::get(rewriter.getIntegerType(48));
} else if ((inputElemTy.isFloat8E4M3() && weightElemTy.isFloat8E4M3() &&
outputElemTy.isF16()) ||
(inputElemTy.isFloat8E5M2() && weightElemTy.isFloat8E5M2() &&
outputElemTy.isF16())) {
} else if ((isa<Float8E4M3Type>(inputElemTy) &&
isa<Float8E4M3Type>(weightElemTy) && outputElemTy.isF16()) ||
(isa<Float8E5M2Type>(inputElemTy) &&
isa<Float8E5M2Type>(weightElemTy) && outputElemTy.isF16())) {
accType = mlir::TypeAttr::get(rewriter.getF16Type());
} else {
accType = mlir::TypeAttr::get(outputElemTy);
@@ -500,17 +506,5 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
return success();
}

// Temporary function to get TOSA const shape
// TODO: Remove this function when getTosaConstShape is available in
// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
llvm::ArrayRef<int64_t> shape) {
auto attr = rewriter.getIndexTensorAttr(shape);
auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
mlir::Operation *mlir_op =
rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
return mlir_op->getResult(0);
}

} // namespace tosa
} // namespace mlir
8 changes: 4 additions & 4 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -652,13 +652,13 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
return rewriter.getF32Type();
if (isa<Float64Type>(inputType))
return rewriter.getF64Type();
if (inputType.isFloat8E5M2())
if (isa<Float8E5M2Type>(inputType))
return rewriter.getF32Type();
if (inputType.isFloat8E4M3FN())
if (isa<Float8E4M3FNType>(inputType))
return rewriter.getF32Type();
if (inputType.isFloat8E5M2FNUZ())
if (isa<Float8E5M2FNUZType>(inputType))
return rewriter.getF32Type();
if (inputType.isFloat8E4M3FNUZ())
if (isa<Float8E4M3FNUZType>(inputType))
return rewriter.getF32Type();
if (inputType.isInteger(8))
// this is an intentional deviation from CUDA (which accumulates i8 to i64)
(Diff for the eighth changed file was not loaded.)