Commit 45408fa

OpenXLA-specific changes

1 parent 856ec67 commit 45408fa

45 files changed: +3556 -980 lines changed

BUILD

+853
Large diffs are not rendered by default.

include/triton/Conversion/MLIRTypes.h

+11-10
@@ -21,22 +21,23 @@ inline Type u1Ty(MLIRContext *ctx) {
 }
 
 // Float types
-inline Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
-inline Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
-inline Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
-inline Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
+inline Type f16Ty(MLIRContext *ctx) { return Float16Type::get(ctx); }
+inline Type f32Ty(MLIRContext *ctx) { return Float32Type::get(ctx); }
+inline Type f64Ty(MLIRContext *ctx) { return Float64Type::get(ctx); }
+inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }
 
 inline bool isFloat(Type type) {
   return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
-         type.isBF16() || type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+         type.isBF16() ||
+         llvm::isa<mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3FNType,
+                   mlir::Float8E4M3FNUZType, mlir::Float8E5M2Type,
+                   mlir::Float8E5M2FNUZType>(type);
 }
 
 inline bool isFloat8(Type type) {
-  return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+  return llvm::isa<mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3FNType,
+                   mlir::Float8E4M3FNUZType, mlir::Float8E5M2Type,
+                   mlir::Float8E5M2FNUZType>(type);
 }
 
 inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
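For context on this hunk: older MLIR exposed one member accessor per float kind on Type, while newer MLIR removed those accessors in favor of concrete type classes checked through llvm::isa. A minimal sketch of the migrated pattern, assuming recent MLIR headers where the Float8* type classes exist:

#include "llvm/Support/Casting.h"
#include "mlir/IR/BuiltinTypes.h"

// Old style (removed in newer MLIR): one member accessor per float kind.
//   return type.isFloat8E5M2() || type.isFloat8E4M3FN();
// New style: a single variadic isa<> over the concrete type classes; adding
// a new FP8 flavor is one more template argument, not one more ||-clause.
static bool isFp8(mlir::Type type) {
  return llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(type);
}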

lib/Analysis/AxisInfo.cpp

+1-1
@@ -937,7 +937,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
       // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
       lhsDivisibility = 1;
     }
-    return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
+    return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
   }
 
   int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
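The one-character-looking fix above addresses a real overflow: the literal 1 is a 32-bit int, so 1 << shift is undefined behavior once shift reaches 31, even though the surrounding division is performed in int64_t. A standalone sketch of the hazard (hypothetical helper, values not taken from the pass):

#include <algorithm>
#include <cstdint>

int64_t dividePow2(int64_t lhsDivisibility, int64_t shift) {
  // Bug: '1' is int, so for shift >= 31 the shift overflows (UB) before
  // the result is ever widened to 64 bits:
  //   return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
  // Fix: perform the shift in 64 bits from the start.
  return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
}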

lib/Analysis/Utility.cpp

+5-4
@@ -732,14 +732,14 @@ bool supportMMA(triton::DotOp op, int version) {
     return false;
   if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
         retShapePerCTA[rank - 1] % 8 == 0 &&
-        (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
+        (llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(aElemTy) ||
          aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
          aElemTy.isF32()))) {
     return false;
   }
   // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
   if (op.getMaxNumImpreciseAcc() < 32 &&
-      (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
+      (llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(aElemTy)) &&
       cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
     return false;
   }
@@ -760,8 +760,9 @@ bool supportMMA(Value value, int version) {
       cast<triton::gpu::TensorOrMemDesc>(value.getType()).getElementType();
   // FP8 is not natively supported on all mma versions but it can always be
   // promoted to fp16 therefore we can always support it.
-  bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
-               elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
+  bool isFP8 =
+      llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType,
+                mlir::Float8E5M2FNUZType, mlir::Float8E4M3FNUZType>(elemTy);
   return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
          (elemTy.isF32() && version >= 2) ||
          (elemTy.isInteger(8) && version >= 2);

lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp

+12
@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   addArgumentMaterialization([&](OpBuilder &builder,
                                  RankedTensorType tensorType, ValueRange inputs,
                                  Location loc) -> Value {
+    // Allows partial TTIR to TTGIR conversion by materializing a conversion for
+    // remaining arguments that have been converted to a new type.
+    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+    // 'convert-triton-to-tritongpu'.
+    return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+                                                        inputs);
     llvm_unreachable("Argument rematerialization should not happen in Triton "
                      "-> TritonGPU conversion");
     return {};
@@ -66,6 +72,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   // convert origValue to newValue
   addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                ValueRange inputs, Location loc) -> Value {
+    // Allows partial TTIR to TTGIR conversion by materializing a conversion for
+    // remaining uses of values that have been converted to a new type.
+    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+    // 'convert-triton-to-tritongpu'.
+    return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+                                                        inputs);
     llvm_unreachable("Source rematerialization should not happen in Triton -> "
                      "TritonGPU Conversion");
     return {};
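For readers unfamiliar with the dialect-conversion API: materialization hooks are callbacks a TypeConverter invokes to bridge values whose types changed during a partial conversion. A rough sketch of the registration shape, using the builtin unrealized_conversion_cast as a placeholder where the commit uses triton::gpu::ConvertLayoutOp:

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

void addPartialConversionHooks(TypeConverter &converter) {
  // Invoked when a use still expects the original type after conversion.
  converter.addSourceMaterialization(
      [](OpBuilder &builder, Type resultType, ValueRange inputs,
         Location loc) -> Value {
        // Placeholder bridge op standing in for ConvertLayoutOp.
        return builder
            .create<UnrealizedConversionCastOp>(loc, TypeRange{resultType},
                                                inputs)
            .getResult(0);
      });
}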

lib/Dialect/TritonGPU/IR/Ops.cpp

+9-4
@@ -150,6 +150,11 @@ struct CanonicalizeConvertFromAlloc
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
     if (!convert)
       return failure();
+    // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+    // to SharedEncoding, so we want to keep this layout conversion.
+    if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
+            convert.getSrc().getType().getEncoding()))
+      return failure();
     rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
         op, op->getResult(0).getType(), convert.getSrc());
     return mlir::success();
@@ -212,13 +217,13 @@ struct CanonicalizeConvertFromConvert
     // heuristic to accommodate fused attention.
     auto srcType = op.getSrc().getType();
     auto dstType = op.getType();
-    if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
-        mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
+    if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
+        mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
       return failure();
 
     // for hopper MMAv3
-    if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
-        mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
+    if (mlir::isa_and_nonnull<SharedEncodingAttr>(dstType.getEncoding()) &&
+        mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
         llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
           return dot->hasTrait<OpTrait::DotLike>();
         })) {
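The isa to isa_and_nonnull switch is defensive: a tensor type with no layout assigned yet returns a null encoding attribute, and mlir::isa asserts on null values. A small sketch of the difference, with StringAttr standing in for a real encoding attribute class:

#include "llvm/Support/Casting.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

bool hasStringEncoding(mlir::RankedTensorType type) {
  mlir::Attribute enc = type.getEncoding();  // null when no encoding attached
  // llvm::isa<StringAttr>(enc) would assert if enc is null;
  // isa_and_nonnull answers "no" instead of crashing.
  return llvm::isa_and_nonnull<mlir::StringAttr>(enc);
}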

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

+45-8
@@ -20,8 +20,6 @@ namespace mlir {
 namespace triton {
 namespace gpu {
 
-namespace {
-
 // Get the highest version supported for the hardware and the dot.
 static int getMMAVersionSafe(int computeCapability, DotOp op) {
   // List supported mma version in order of preference.
@@ -44,8 +42,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
   return 0;
 }
 
-SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
-                                     int numWarps) {
+SmallVector<unsigned>
+warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
@@ -109,10 +107,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
 }
 
 SmallVector<unsigned, 2>
-warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
+warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
                const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
-  mlir::getForwardSlice(dotOp.getResult(), &slices);
+  mlir::getForwardSlice(dotOp->getResult(0), &slices);
   // Contains a chained dot. We prefer to assign warps to one axis
   // to facilitate use cases like flash attention, allowing reductions within
   // the same warp.
@@ -167,11 +165,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
 SmallVector<unsigned, 3>
-getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
+getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
                 int numWarps, const SmallVector<unsigned, 3> &instrShape) {
   switch (version) {
   case 2:
@@ -184,18 +197,32 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
   }
 }
 
+// Move anonymous namespace down, so getWarpsPerTile is visible to the sparsity
+// extension.
+namespace {
+
 class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   int computeCapability;
   mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
 
   static bool bwdFilter(Operation *op) {
+    // Dot operand layout assignment to predicates is not currently supported
+    // during lowering from TritonGPU to LLVM in Triton for MMA cases. This
+    // condition limits visibility of the original bit-width so that predicates
+    // are not considered; hence, kwidth can never be 32.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        return false;
+    }
     return op->getNumOperands() == 1 &&
            (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
             isPureUnaryInlineAsm(op) ||
             op->getDialect()->getTypeID() ==
                 mlir::TypeID::get<arith::ArithDialect>());
   }
 
+public:
   // Finds the first different bitwidth in the chain of shape-preserving
   // unary ops that x depends on.
   // There are two primary scenarios:
@@ -343,7 +370,8 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
     NvidiaMmaEncodingAttr mmaLayout =
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
-      bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
+      bool isNativeFP8 =
+          llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(AElType);
       // promote operands for sm < 89 since fp8 mma is not natively supported
      // promote operands for sm >= 90 when mma is not v3
       if (!isNativeFP8 ||
@@ -718,6 +746,15 @@ class TritonGPUAccelerateMatmulPass
   }
 };
 
+// Expose helper functions from BlockedToMMA to be reused for sparse matmul.
+int computeOrigBitWidth(Value x) {
+  return BlockedToMMA::computeOrigBitWidth(x);
+}
+Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+                             int opIdx, bool allowTranspose) {
+  return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
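The namespace move at the top of this file is a linkage fix: symbols inside an anonymous namespace have internal linkage and cannot be named from another translation unit, so the sparsity extension could not call getWarpsPerTile until it was hoisted out. A generic sketch of the pattern:

namespace lib {

// External linkage: other translation units can declare and call this.
int sharedHelper(int x) { return 2 * x; }

namespace {  // everything below has internal linkage

int fileLocalDetail(int x) { return x + 1; }

}  // namespace

// A thin wrapper re-exports an internal helper under external linkage,
// mirroring computeOrigBitWidth/getSharedMemMMAOperand in this commit.
int exportedDetail(int x) { return fileLocalDetail(x); }

}  // namespace lib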

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

+8-1
@@ -131,6 +131,7 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
 
   Value zero = builder.createWithStage<arith::ConstantIntOp>(
       forOp.getLoc(), stage, clusterId, 0, 32);
+
   // Replace the load with insert/extract slice.
   builder.setInsertionPoint(loadOp);
   Location loc = loadOp.getLoc();
@@ -491,7 +492,8 @@ assignMemoryLayouts(scf::ForOp &forOp,
     });
 
     bool isTMALoad = isa<tt::ExperimentalDescriptorLoadOp>(op);
-    loadsToPipeline.insert(&op);
+    // TODO: b/381421713 - Uncomment this once pipelining is fixed.
+    // loadsToPipeline.insert(&op);
     LoadInfo loadInfo;
     for (auto use : users) {
       if (use->hasTrait<OpTrait::DotLike>()) {
@@ -527,6 +529,11 @@ assignMemoryLayouts(scf::ForOp &forOp,
           getBlockedEncoding(loadOp, axisInfoAnalysis);
       }
     }
+
+    // TODO: b/381421713 - Remove this once pipelining is fixed.
+    if (!loadInfo.sharedEncoding) continue;
+    loadsToPipeline.insert(&op);
+
     loadToInfo[&op] = loadInfo;
   }
   // Make sure all loads in loadsToPipeline are in loadToInfo.

lib/Dialect/TritonGPU/Transforms/Prefetch.cpp

+24-2
@@ -116,7 +116,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
   // opIdx: 0 => a, 1 => b
   auto type = cast<triton::gpu::MemDescType>(v.getType());
   SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
-  SmallVector<int64_t> offset{0, 0};
+  SmallVector<int64_t> offset(shape.size(), 0);
   Type elementType = type.getElementType();
 
   // k => (prefetchWidth, k - prefetchWidth)
@@ -141,8 +141,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
                        type.getMutableMemory(), type.getAllocShape()),
       v, offsetsVal);
 
+  // We need to assign kwidth to zero in the case where the parent layout is
+  // Blocked, otherwise the verifier emits a failure. The parent layout is
+  // Blocked only when Tensor Cores are disabled.
+  int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
+                   ? 0
+                   : prefetchWidth / 8;
   auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
-      builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
+      builder.getContext(), opIdx, dotEncoding, kwidth);
   Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
       v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
       newSmem);
@@ -191,6 +197,22 @@ LogicalResult Prefetcher::initialize() {
       break;
     if (!op->getResult(0).hasOneUse())
       break;
+    // Similar to issues faced in HoistLayoutConversion pattern in
+    // OptimizeDotOperands.cpp, we can't propagate through type casts from
+    // predicates as they aren't supported in Triton when encoded with dot_op
+    // layout.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        break;
+    }
+    // Propagation through ExpandDims is currently not supported. This blindly
+    // replaces the encoding with dot encoding, but ExpandDims requires a
+    // SliceEncoding. This could be rewritten to support it somehow, but I
+    // don't think it's trivial and it's currently crashing.
+    if (isa<ExpandDimsOp>(op)) {
+      break;
+    }
     rets.push_back(op->getOperand(0));
     if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
      foundConvertFromShared = true;
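One detail worth flagging in the offset change above: braced initialization and the count-fill constructor are not interchangeable. SmallVector<int64_t> offset{0, 0} always builds a two-element vector, which under-sizes the offsets for rank-3 memdescs; offset(shape.size(), 0) sizes it to the rank. A two-line illustration:

#include "llvm/ADT/SmallVector.h"

llvm::SmallVector<int64_t> a{0, 0};  // exactly {0, 0}, whatever the rank
llvm::SmallVector<int64_t> b(3, 0);  // rank-many zeros: {0, 0, 0}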

lib/Dialect/TritonGPU/Transforms/Utility.cpp

+22-14
@@ -44,9 +44,9 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
   SmallVector<unsigned> validN;
 
   // MMAv3 with larger instruction shape is preferred.
-  if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() ||
-      eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() ||
-      eltType.isF32()) {
+  if (llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType,
+                mlir::Float8E4M3FNUZType>(eltType) ||
+      eltType.isF16() || eltType.isBF16() || eltType.isF32()) {
     validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176,
                    168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88,
                    80, 72, 64, 56, 48, 40, 32, 24, 16, 8});
@@ -994,18 +994,26 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
     } else {
       if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
         return std::nullopt;
-      auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
-          cast<triton::gpu::TensorOrMemDesc>(user->getResult(0).getType())
-              .getEncoding());
-      if (!dotOpEnc)
+      auto enc =
+          cast<triton::gpu::TensorOrMemDesc>(user->getResult(0).getType())
+              .getEncoding();
+      if (isa<ttg::DotOperandEncodingAttr>(enc)) {
+        auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
+        auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
+        auto order = ttg::getOrder(srcTy.getEncoding());
+        unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
+        tempAttr = ttg::SharedEncodingAttr::get(
+            val.getContext(), cast<ttg::DotOperandEncodingAttr>(enc),
+            srcTy.getShape(), order, CTALayout, bitWidth, /*needTrans=*/false);
+      } else if (enc.getAbstractAttribute().getName().str() ==
+                 "triton.gpu.sparse_dot_meta_encoding") {
+        auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
+        tempAttr = ttg::SharedEncodingAttr::get(
+            val.getContext(), /*vec=*/1, /*perPhase=*/1, /*maxPhase=*/1,
+            ttg::getOrder(srcTy.getEncoding()),
+            ttg::getCTALayout(srcTy.getEncoding()));
+      } else {
         return std::nullopt;
-      auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
-      auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
-      auto order = ttg::getOrder(srcTy.getEncoding());
-      unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
-      tempAttr = ttg::SharedEncodingAttr::get(
-          val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout,
-          bitWidth, /*needTrans=*/false);
+      }
     }
     // Check that the shared encodings needed by the users are compatible.
     if (attr != nullptr && attr != tempAttr) {
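The string comparison against "triton.gpu.sparse_dot_meta_encoding" is deliberate: it lets this file recognize an attribute defined by an out-of-tree extension without taking a build dependency on the dialect that registers it. A minimal sketch of the technique (hypothetical helper name):

#include "mlir/IR/Attributes.h"

// Recognize an attribute purely by its registered name. Dependency-free,
// but fragile if the extension ever renames the attribute.
bool isSparseDotMetaEncoding(mlir::Attribute enc) {
  return enc && enc.getAbstractAttribute().getName() ==
                    "triton.gpu.sparse_dot_meta_encoding";
}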

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

+3-2
@@ -77,8 +77,9 @@ bool WarpGroupDotOp::needsPartialAccumulator() {
   const auto &d = getD();
   auto aTensorTy = cast<triton::gpu::TensorOrMemDesc>(a.getType());
   auto aElTy = cast<triton::gpu::TensorOrMemDesc>(a.getType()).getElementType();
-  bool isFP8 = aElTy.isFloat8E5M2() || aElTy.isFloat8E4M3FN() ||
-               aElTy.isFloat8E5M2FNUZ() || aElTy.isFloat8E4M3FNUZ();
+  bool isFP8 =
+      llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType,
+                mlir::Float8E5M2FNUZType, mlir::Float8E4M3FNUZType>(aElTy);
   bool accFP32 =
       cast<triton::gpu::TensorOrMemDesc>(d.getType()).getElementType().isF32();
   uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc();

lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp

+1-1
@@ -44,7 +44,7 @@ struct FenceInsertionPass
       return;
     ModuleOp mod = getOperation();
     mod.walk([&](Operation *op) {
-      if (!isa<ttng::WarpGroupDotOp>(op))
+      if (!op->hasTrait<OpTrait::DotLike>())
        return WalkResult::advance();
      OpBuilder builder(op);
      auto a = op->getOperand(0);
