Skip to content

Commit 3946bb1

Browse files
committed
1 parent 4b9b972 commit 3946bb1

File tree

7 files changed

+150
-108
lines changed

7 files changed

+150
-108
lines changed

externals/llvm-project

Submodule llvm-project updated 4692 files

externals/stablehlo

Submodule stablehlo updated 52 files

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

+4-5
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
4545
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
4646
float val);
4747

48+
// Create an 8-bit int constant operator from an int
49+
Value getTosaConstTensorSingleI8(PatternRewriter &rewriter, Operation *op,
50+
int32_t val);
51+
4852
// Create a zero constant tensor of the desired type and shape.
4953
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
5054
Operation *op, Type type);
@@ -127,11 +131,6 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
127131
RankedTensorType weightTy,
128132
RankedTensorType outputTy, TypeAttr &accType);
129133

130-
// Temporary function to get TOSA const shape
131-
// TODO: Remove this function when getTosaConstShape is available in
132-
// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
133-
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
134-
llvm::ArrayRef<int64_t> shape);
135134
} // namespace tosa
136135
} // namespace mlir
137136

lib/Conversion/TorchToTosa/TorchToTosa.cpp

+102-63
Large diffs are not rendered by default.

lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp

+10-6
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,9 @@ tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
119119
int32_t shift) {
120120
lhs = promoteType(rewriter, lhs, outType);
121121
rhs = promoteType(rewriter, rhs, outType);
122-
return tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), outType,
123-
lhs, rhs, shift);
122+
return tosa::CreateOpAndInfer<tosa::MulOp>(
123+
rewriter, op->getLoc(), outType, lhs, rhs,
124+
getTosaConstTensorSingleI8(rewriter, op, shift));
124125
}
125126

126127
template <>
@@ -384,7 +385,8 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
384385
auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
385386
rewriter, op->getLoc(),
386387
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
387-
indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0);
388+
indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(),
389+
getTosaConstTensorSingleI8(rewriter, op, 0));
388390

389391
// Sum up the products of the coefficients and coordinates
390392
// %6 = "tosa.reduce_sum"(%5) {axis = 1 : i64} : (tensor<8x3xi32>) ->
@@ -650,7 +652,8 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
650652
auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
651653
rewriter, op->getLoc(),
652654
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
653-
indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0);
655+
indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(),
656+
getTosaConstTensorSingleI8(rewriter, op, 0));
654657

655658
// Sum up the products of the coefficients and coordinates
656659
// [[4*0 + 1*1], [4*0 + 1*2], [4*0 + 1*3]] = [[1],[2],[3]]
@@ -973,8 +976,9 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
973976

974977
if (!input_is_qtype) {
975978
Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
976-
return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
977-
val.value(), div_const, 0)
979+
return CreateOpAndInfer<tosa::MulOp>(
980+
rewriter, op->getLoc(), output_type, val.value(), div_const,
981+
getTosaConstTensorSingleI8(rewriter, op, 0))
978982
.getResult();
979983
}
980984

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

+28-28
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,18 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
150150
return const_op.getResult();
151151
}
152152

153+
// Create an 8-bit int constant operator from an int.
// Builds an 8-bit integer element type and a rank-1 tensor type of shape {1},
// wraps `val` in a DenseElementsAttr of that type, and creates a
// tosa::ConstOp at `op`'s location. The created constant is returned as a
// Value.
// NOTE(review): `val` is declared int32_t while the element type is only
// 8 bits wide; presumably callers pass values that fit in 8 bits — confirm.
154+
Value getTosaConstTensorSingleI8(PatternRewriter &rewriter, Operation *op,
155+
int32_t val) {
156+
auto shiftElementType = IntegerType::get(rewriter.getContext(), 8);
157+
auto shiftType = RankedTensorType::get({1}, shiftElementType);
158+
auto shiftZeroAttr = DenseElementsAttr::get(
159+
shiftType, rewriter.getIntegerAttr(shiftElementType, val));
160+
Value constVal =
161+
rewriter.create<tosa::ConstOp>(op->getLoc(), shiftType, shiftZeroAttr);
162+
return constVal;
163+
}
164+
153165
// Create a zero constant tensor of the desired type and shape.
154166
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
155167
Operation *op, Type type) {
@@ -301,31 +313,31 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
301313
(src.isF32() && dest.isInteger(8)) ||
302314
(src.isF32() && dest.isBF16()) ||
303315
(src.isF32() && dest.isF16()) ||
304-
(src.isF32() && dest.isFloat8E4M3()) ||
305-
(src.isF32() && dest.isFloat8E5M2()) ||
316+
(src.isF32() && isa<Float8E4M3Type>(dest)) ||
317+
(src.isF32() && isa<Float8E5M2Type>(dest)) ||
306318
// f16 -> *
307319
(src.isF16() && dest.isInteger(32)) ||
308320
(src.isF16() && dest.isInteger(16)) ||
309321
(src.isF16() && dest.isInteger(8)) ||
310322
(src.isF16() && dest.isBF16()) ||
311323
(src.isF16() && dest.isF32()) ||
312-
(src.isF16() && dest.isFloat8E4M3()) ||
313-
(src.isF16() && dest.isFloat8E5M2()) ||
324+
(src.isF16() && isa<Float8E4M3Type>(dest)) ||
325+
(src.isF16() && isa<Float8E5M2Type>(dest)) ||
314326
// bf16 -> *
315327
(src.isBF16() && dest.isInteger(32)) ||
316328
(src.isBF16() && dest.isInteger(16)) ||
317329
(src.isBF16() && dest.isInteger(8)) ||
318330
(src.isBF16() && dest.isF32()) ||
319-
(src.isBF16() && dest.isFloat8E4M3()) ||
320-
(src.isBF16() && dest.isFloat8E5M2()) ||
331+
(src.isBF16() && isa<Float8E4M3Type>(dest)) ||
332+
(src.isBF16() && isa<Float8E5M2Type>(dest)) ||
321333
// fp8e4m3 -> *
322-
(src.isFloat8E4M3() && dest.isBF16()) ||
323-
(src.isFloat8E4M3() && dest.isF32()) ||
324-
(src.isFloat8E4M3() && dest.isF16()) ||
334+
(isa<Float8E4M3Type>(src) && dest.isBF16()) ||
335+
(isa<Float8E4M3Type>(src) && dest.isF32()) ||
336+
(isa<Float8E4M3Type>(src) && dest.isF16()) ||
325337
// fp8e5m2 -> *
326-
(src.isFloat8E5M2() && dest.isBF16()) ||
327-
(src.isFloat8E5M2() && dest.isF32()) ||
328-
(src.isFloat8E5M2() && dest.isF16())) {
338+
(isa<Float8E5M2Type>(src) && dest.isBF16()) ||
339+
(isa<Float8E5M2Type>(src) && dest.isF32()) ||
340+
(isa<Float8E5M2Type>(src) && dest.isF16())) {
329341
return success();
330342
}
331343
// clang-format on
@@ -488,10 +500,10 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
488500
} else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) &&
489501
outputElemTy.isInteger(48)) {
490502
accType = mlir::TypeAttr::get(rewriter.getIntegerType(48));
491-
} else if ((inputElemTy.isFloat8E4M3() && weightElemTy.isFloat8E4M3() &&
492-
outputElemTy.isF16()) ||
493-
(inputElemTy.isFloat8E5M2() && weightElemTy.isFloat8E5M2() &&
494-
outputElemTy.isF16())) {
503+
} else if ((isa<Float8E4M3Type>(inputElemTy) &&
504+
isa<Float8E4M3Type>(weightElemTy) && outputElemTy.isF16()) ||
505+
(isa<Float8E5M2Type>(inputElemTy) &&
506+
isa<Float8E5M2Type>(weightElemTy) && outputElemTy.isF16())) {
495507
accType = mlir::TypeAttr::get(rewriter.getF16Type());
496508
} else {
497509
accType = mlir::TypeAttr::get(outputElemTy);
@@ -500,17 +512,5 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
500512
return success();
501513
}
502514

503-
// Temporary function to get TOSA const shape
504-
// TODO: Remove this function when getTosaConstShape is available in
505-
// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
506-
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
507-
llvm::ArrayRef<int64_t> shape) {
508-
auto attr = rewriter.getIndexTensorAttr(shape);
509-
auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
510-
mlir::Operation *mlir_op =
511-
rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
512-
return mlir_op->getResult(0);
513-
}
514-
515515
} // namespace tosa
516516
} // namespace mlir

lib/Dialect/Torch/Utils/Utils.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -652,13 +652,13 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
652652
return rewriter.getF32Type();
653653
if (isa<Float64Type>(inputType))
654654
return rewriter.getF64Type();
655-
if (inputType.isFloat8E5M2())
655+
if (isa<Float8E5M2Type>(inputType))
656656
return rewriter.getF32Type();
657-
if (inputType.isFloat8E4M3FN())
657+
if (isa<Float8E4M3FNType>(inputType))
658658
return rewriter.getF32Type();
659-
if (inputType.isFloat8E5M2FNUZ())
659+
if (isa<Float8E5M2FNUZType>(inputType))
660660
return rewriter.getF32Type();
661-
if (inputType.isFloat8E4M3FNUZ())
661+
if (isa<Float8E4M3FNUZType>(inputType))
662662
return rewriter.getF32Type();
663663
if (inputType.isInteger(8))
664664
// this is an intentional deviation from CUDA (which accumulates i8 to i64)

0 commit comments

Comments
 (0)