@@ -19,8 +19,6 @@ namespace mlir {
 namespace triton {
 namespace gpu {
 
-namespace {
-
 // Get the highest version supported for the hardware and the dot.
 static int getMMAVersionSafe(int computeCapability, DotOp op) {
   // List supported mma version in order of preference.
@@ -43,8 +41,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
   return 0;
 }
 
-SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
-                                     int numWarps) {
+SmallVector<unsigned>
+warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
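
Taking an Operation* instead of a DotOp means any dot-like operation can reuse this tiling heuristic. A minimal sketch of a hypothetical caller in a sparsity extension (the helper name and the op it receives are assumptions, not part of this change):

    // Hypothetical: compute warp tiling for any dot-like op, e.g. a sparse dot.
    static SmallVector<unsigned> warpsForSparseDot(Operation *sparseDot,
                                                   int numWarps) {
      auto retShape =
          cast<RankedTensorType>(sparseDot->getResult(0).getType()).getShape();
      return warpsPerTileV2(sparseDot, retShape, numWarps);
    }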
@@ -57,9 +55,8 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
   auto slices = multiRootGetSlice(dotOp, {filter}, {filter});
   bool hasChainedDot = false;
   for (Operation *op : slices) {
-    if (isa<DotOp>(op) && (op != dotOp)) {
-      auto chainedDot = cast<DotOp>(op);
-      auto resTy = chainedDot.getResult().getType();
+    if (dotOp->getName() == op->getName() && op != dotOp) {
+      auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
       if (resTy.getRank() != rank) {
         continue;
       }
@@ -108,12 +105,17 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
 }
 
 SmallVector<unsigned, 2>
-warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
+warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
                const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
-  mlir::getForwardSlice(dotOp.getResult(), &slices);
-  if (llvm::find_if(slices, [](Operation *op) { return isa<DotOp>(op); }) !=
-      slices.end())
+  mlir::getForwardSlice(dotOp->getResult(0), &slices);
+  if (llvm::find_if(slices, [&](Operation *op) {
+        return dotOp->getName() == op->getName() ||
+               // Contains a chained dot. We prefer to assign warps to one axis
+               // to facilitate use cases like flash attention, allowing
+               // reductions within the same warp.
+               op->hasTrait<OpTrait::DotLike>();
+      }) != slices.end())
     return {(unsigned)numWarps, 1};
 
   // For MMAv3, the smallest indivisible unit of warp shape is (4, 1).
@@ -162,11 +164,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
 SmallVector<unsigned, 3>
-getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
+getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
                 int numWarps, const SmallVector<unsigned, 3> &instrShape) {
   switch (version) {
   case 2:
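
With the DotOperandEncoding branch above, the shared-memory staging helper now accepts values that still carry a dot-operand layout and inserts the conversion to the parent blocked layout itself. A hedged usage sketch (the value v and the flag values are illustrative):

    // v may carry a DotOperandEncodingAttr; the helper converts it to the
    // parent blocked layout before creating the LocalAllocOp.
    Value aShared = getSharedMemoryMMAOperand(v, rewriter, /*opIdx=*/0,
                                              /*allowTranspose=*/true);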
@@ -179,18 +196,32 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
   }
 }
 
+// Move the anonymous namespace down so that getWarpsPerTile is visible to the
+// sparsity extension.
+namespace {
+
 class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   int computeCapability;
   mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
 
   static bool bwdFilter(Operation *op) {
+    // Dot operand layout assignment to predicates (i1) is not currently
+    // supported when lowering from TritonGPU to LLVM for MMA cases. This
+    // check hides the original bit-width of predicates from the backward
+    // walk, so kWidth can never become 32.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        return false;
+    }
     return op->getNumOperands() == 1 &&
            (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
             isPureUnaryInlineAsm(op) ||
             op->getDialect()->getTypeID() ==
                 mlir::TypeID::get<arith::ArithDialect>());
   }
 
+public:
   // Finds the first different bitwidth in the chain of shape-preserving
   // unary ops that x depends on.
   // There are two primary scenarios:
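
A comment-form illustration of the chain the new bwdFilter early return guards against (tensor shapes and value names are hypothetical):

    //   %mask = ... : tensor<128x64xi1>                      // i1 predicate
    //   %a    = arith.uitofp %mask : tensor<128x64xi1> to tensor<128x64xf16>
    //   %d    = tt.dot %a, %b, %acc
    // Returning false for the uitofp stops the backward walk at the f16 value,
    // so the bit-width reported for the A operand is presumably 16 rather
    // than 1, keeping kWidth out of the unsupported 32 case.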
@@ -595,6 +626,15 @@ class TritonGPUAccelerateMatmulPass
   }
 };
 
+// Expose helper functions from BlockedToMMA to be reused for sparse matmul.
+int computeOrigBitWidth(Value x) {
+  return BlockedToMMA::computeOrigBitWidth(x);
+}
+Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+                             int opIdx, bool allowTranspose) {
+  return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
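
A hedged sketch of how an out-of-tree sparse matmul pattern might call the newly exposed wrappers (the sparseDot value, its operand order, and the flag values are assumptions for illustration only):

    //   Value a = sparseDot->getOperand(0);  // assume operand 0 is the A matrix
    //   int aWidth = mlir::triton::gpu::computeOrigBitWidth(a);
    //   Value aShared = mlir::triton::gpu::getSharedMemMMAOperand(
    //       a, rewriter, /*opIdx=*/0, /*allowTranspose=*/true);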