@@ -20,8 +20,6 @@ namespace mlir {
 namespace triton {
 namespace gpu {
 
-namespace {
-
 // Get the highest version supported for the hardware and the dot.
 static int getMMAVersionSafe(int computeCapability, DotOp op) {
   // List supported mma version in order of preference.
@@ -44,8 +42,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
   return 0;
 }
 
-SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
-                                     int numWarps) {
+SmallVector<unsigned>
+warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
@@ -109,10 +107,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
 }
 
 SmallVector<unsigned, 2>
-warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
+warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
                const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
-  mlir::getForwardSlice(dotOp.getResult(), &slices);
+  mlir::getForwardSlice(dotOp->getResult(0), &slices);
   // Contains a chained dot. We prefer to assign warps to one axis
   // to facilitate use cases like flash attention, allowing reductions within
   // the same warp.
@@ -167,11 +165,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding,
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
 SmallVector<unsigned, 3>
-getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
+getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
                 int numWarps, const SmallVector<unsigned, 3> &instrShape) {
   switch (version) {
   case 2:
@@ -184,18 +197,32 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
   }
 }
 
+// Move the anonymous namespace down so that getWarpsPerTile is visible to the
+// sparsity extension.
+namespace {
+
 class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   int computeCapability;
   mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
 
   static bool bwdFilter(Operation *op) {
+    // Dot operand layout assignment to predicates is not currently supported
+    // during lowering from TritonGPU to LLVM in Triton for MMA cases. This
+    // condition limits the visibility of the original bit-width so that
+    // predicates are not considered, hence kWidth can never be 32.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        return false;
+    }
     return op->getNumOperands() == 1 &&
            (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
             isPureUnaryInlineAsm(op) ||
             op->getDialect()->getTypeID() ==
                mlir::TypeID::get<arith::ArithDialect>());
   }
 
+public:
   // Finds the first different bitwidth in the chain of shape-preserving
   // unary ops that x depends on.
   // There are two primary scenarios:
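The anonymous-namespace move in the hunk above is a linkage fix: definitions inside an anonymous namespace have internal linkage, so a sparsity extension compiled as a separate translation unit cannot reference them. A minimal sketch of the distinction, with hypothetical file and function names:

// accelerate.cpp (hypothetical)
namespace mlir::triton::gpu {
namespace {
int hiddenHelper(int x) { return x + 1; } // internal linkage: private to this TU
} // namespace
int exposedHelper(int x) { return x + 1; } // external linkage: visible elsewhere
} // namespace mlir::triton::gpu

// sparse_extension.cpp (hypothetical)
namespace mlir::triton::gpu {
int exposedHelper(int x); // resolves against accelerate.cpp at link time
// int hiddenHelper(int x);  // would fail to link: no external definition exists
} // namespace mlir::triton::gpu

int useHelper() { return mlir::triton::gpu::exposedHelper(41); }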
@@ -805,6 +832,15 @@ class TritonGPUAccelerateMatmulPass
   }
 };
 
+// Expose helper functions from BlockedToMMA to be reused for sparse matmul.
+int computeOrigBitWidth(Value x) {
+  return BlockedToMMA::computeOrigBitWidth(x);
+}
+Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+                             int opIdx, bool allowTranspose) {
+  return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
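For illustration, a hedged sketch of how a separate sparse-matmul translation unit might consume the helpers exposed above. Everything here other than the two redeclared helpers is an assumption: the file name, rewriteSparseOperands, and the choice of operand indices are illustrative, not part of this diff.

// SparseAccelerateMatmul.cpp (hypothetical consumer)
#include "mlir/IR/PatternMatch.h"

namespace mlir::triton::gpu {

// Redeclarations of the helpers given external linkage above.
int computeOrigBitWidth(Value x);
Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
                             int opIdx, bool allowTranspose);

// Sketch: place both operands of a (sparse) dot in shared memory with an
// MMA-friendly layout, mirroring what BlockedToMMA does for dense DotOp.
LogicalResult rewriteSparseOperands(Operation *dotOp,
                                    mlir::PatternRewriter &rewriter) {
  Value a = dotOp->getOperand(0);
  Value b = dotOp->getOperand(1);
  // The original bit-width would feed the kWidth of the MMA encoding.
  int bitWidth = computeOrigBitWidth(a);
  (void)bitWidth;
  Value aShared = getSharedMemMMAOperand(a, rewriter, /*opIdx=*/0,
                                         /*allowTranspose=*/true);
  Value bShared = getSharedMemMMAOperand(b, rewriter, /*opIdx=*/1,
                                         /*allowTranspose=*/true);
  rewriter.modifyOpInPlace(dotOp, [&] {
    dotOp->setOperand(0, aShared);
    dotOp->setOperand(1, bShared);
  });
  return success();
}

} // namespace mlir::triton::gpu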