
Commit 2aa35b6

OpenXLA-specific changes

1 parent 88c704e

44 files changed: +3772 −1099 lines


BUILD: +931 (large diff not rendered by default)

include/triton/Tools/LinearLayout.h: +7

@@ -681,6 +681,13 @@ class LinearLayout {
   // (i.e. every input bit affects the output).
   llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks() const;
 
+  // Increase an input dimension without affecting the output dimension. The
+  // added free variables are mapped to 0, ensuring that the new input
+  // dimensions correspond directly to the existing output space. The function
+  // errors out if `newInDimSize` is less than the current size or the new size
+  // is not a power of 2.
+  LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const;
+
   std::string toString() const;
 
   friend bool operator==(LinearLayout lhs, LinearLayout rhs);
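Note: a minimal sketch of the new method's contract, using only the API shown above; the dimension name, helper, and sizes are illustrative assumptions, not taken from this commit.

    // Pad the "register" input dimension of a layout up to `newSize`. The
    // padded inputs are free variables mapped to 0, so applying the layout at
    // any padded index yields the same output as some existing index.
    LinearLayout padRegisters(mlir::MLIRContext *ctx,
                              const LinearLayout &layout, int32_t newSize) {
      auto kRegister = mlir::StringAttr::get(ctx, "register");
      // Per the comment above, newSize must be a power of 2 and at least the
      // current size; otherwise resize() errors out.
      return layout.resize(kRegister, newSize);
    }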

lib/Analysis/AxisInfo.cpp: +2 −1

@@ -937,7 +937,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
       // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
       lhsDivisibility = 1;
     }
-    return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
+    return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
   }
 
   int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
@@ -1011,6 +1011,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
                  CastOpAxisInfoVisitor<arith::ExtUIOp>,
                  CastOpAxisInfoVisitor<arith::TruncIOp>,
                  CastOpAxisInfoVisitor<arith::IndexCastOp>,
+                 CastOpAxisInfoVisitor<arith::IndexCastUIOp>,
                  CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
                  CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
                  CastOpAxisInfoVisitor<triton::BitcastOp>>();
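Note: the `int64_t(1) << shift` change avoids a 32-bit shift overflow; a standalone illustration (independent of Triton) of the behavior being fixed, with assumed values:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
      int64_t lhsDivisibility = int64_t(1) << 40; // divisibility may exceed 2^31
      int shift = 40;                             // shifts on i64 can reach 63
      // With a plain `1 << shift`, the shift is performed in 32-bit int and
      // overflows (undefined behavior); promoting the literal to int64_t keeps
      // the whole expression in 64-bit arithmetic.
      int64_t r = std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
      std::cout << r << "\n"; // prints 1
    }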

lib/Analysis/Utility.cpp: +36 −2

@@ -683,8 +683,42 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
   StringAttr kLane = StringAttr::get(ctx, "lane");
   StringAttr kWarp = StringAttr::get(ctx, "warp");
   StringAttr kBlock = StringAttr::get(ctx, "block");
-
-  auto comp = dstLayout->invertAndCompose(*srcLayout);
+  auto numSrcRegs = srcLayout->getInDimSize(kRegister);
+  auto numDstRegs = dstLayout->getInDimSize(kRegister);
+  // The `invertAndCompose` function will generate a layout that is injective
+  // by assigning new output dimensions to free variables. For instance,
+  // consider a scenario where `srcLayout` has a free variable in the lane
+  // dimension, while `dstLayout` has two free variables in the lane
+  // dimension and also a larger number of registers.
+  // The injective form of `srcLayout` will add only a single additional row
+  // to the transformation matrix, whereas the injective form of `dstLayout`
+  // will add two additional rows. This discrepancy causes misleading results
+  // because the matrices end up with a different number of rows.
+  //
+  // Take `dstLayout ⋅ srcLayout^-1` as an example:
+  //
+  //   - `injective(dstLayout)`: [n, m] → [n + 2, m]
+  //   - `injective(srcLayout)`: [n, m] → [n + 1, m]
+  //   - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
+  //   - `injective(dstLayout) ⋅ injective(srcLayout)^-1`:
+  //     [n + 2, m] ⋅ [m, n + 1] → [n + 2, n + 1]
+  //
+  // Here, the `(n + 1)`-th row added by `dstLayout` represents the free
+  // variable in registers, and the `(n + 2)`-th row represents the free
+  // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
+  // represents the free variable in lanes. As a result, the `(n + 1)`-th rows
+  // of the two layouts do not correspond to the same free variable.
+  //
+  // To address this issue, we pad the free variables in `srcLayout` and
+  // `dstLayout` to ensure they have the same number of registers. This
+  // guarantees that the resulting matrices have the same number of rows,
+  // ensuring consistency in the composition process.
+  auto numRegs = std::max(numSrcRegs, numDstRegs);
+  auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs);
+  auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs);
+  // comp describes the layout function to create dst from src.
+  LinearLayout comp =
+      dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs);
   // We try to quotient by the largest subspace first
   auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
   for (auto dim : dims) {
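Note: the padding step described in the comment reduces to a few lines; a sketch assuming both layouts are already dereferenced `LinearLayout` values and the function name is hypothetical:

    // Equalize the "register" input dimension before inverting and composing,
    // so the injective forms of both layouts gain the same number of rows.
    LinearLayout composeWithPaddedRegs(const LinearLayout &src,
                                       const LinearLayout &dst,
                                       mlir::StringAttr kRegister) {
      int32_t numRegs =
          std::max(src.getInDimSize(kRegister), dst.getInDimSize(kRegister));
      // The result maps src register indices to dst register indices.
      return dst.resize(kRegister, numRegs)
          .invertAndCompose(src.resize(kRegister, numRegs));
    }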

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp: +18 −7

@@ -315,10 +315,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       // TODO(Keren): implement warp shuffle instead of using the general
       // approach that uses shared memory
       return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
-    } else if (llvm::is_contained(dims, kRegister)) {
+    } else if (llvm::is_contained(dims, kRegister) ||
+               dstLayout.getInDimSize(kRegister) !=
+                   srcLayout.getInDimSize(kRegister)) {
       // Case 4. Transfer between values in the same thread, in which case we
       // simply reorder the elements of adaptor.getSrc().
-      return transferWithinThread(op, *conversion, adaptor, rewriter);
+      return transferWithinThread(
+          op, dstLayout.getFreeVariableMasks()[kRegister],
+          dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter);
     } else {
       // Case 5. The two layouts are equivalent. We should probably remove
       // these in RemoveLayoutConversion.
@@ -328,8 +332,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
   }
 
   LogicalResult
-  transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion,
-                       OpAdaptor adaptor,
+  transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs,
+                       const LinearLayout &conversion, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter) const {
     MLIRContext *ctx = op.getContext();
     auto loc = op.getLoc();
@@ -339,9 +343,16 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getType();
     auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    SmallVector<Value> outVals(conversion.getInDimSize(kRegister));
-    for (int i = 0; i < outVals.size(); i++) {
-      auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second;
+    SmallVector<Value> outVals(numRegs);
+    for (int i = 0; i < numRegs; i++) {
+      // Remove free masks from the register index.
+      // For example, if idx = 0b00111 and masks = 0b00100, then we get
+      // 0b00011. It means that register 7 (0b111) has the same value as
+      // register 3 (0b011).
+      auto idx = i & (~regMasks);
+      auto srcIdx = conversion.hasInDim(kRegister)
+                        ? conversion.apply({{kRegister, idx}}).begin()->second
+                        : idx;
       outVals[i] = inVals[srcIdx];
     }
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
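Note: the mask trick is easy to check in isolation; a runnable sketch with an assumed free-variable mask:

    #include <cstdint>
    #include <iostream>

    int main() {
      int32_t regMasks = 0b00100; // assume bit 2 of the register index is free
      for (int32_t i = 0; i < 8; ++i) {
        int32_t idx = i & ~regMasks; // clear the free bits
        // e.g. register 7 (0b111) reads the same source value as 3 (0b011)
        std::cout << "register " << i << " reads from " << idx << "\n";
      }
    }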

lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp: +12

@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   addArgumentMaterialization([&](OpBuilder &builder,
                                  RankedTensorType tensorType, ValueRange inputs,
                                  Location loc) -> Value {
+    // Allows partial TTIR to TTGIR conversion by materializing a conversion
+    // for remaining arguments that have been converted to a new type.
+    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+    // 'convert-triton-to-tritongpu'.
+    return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+                                                        inputs);
     llvm_unreachable("Argument rematerialization should not happen in Triton "
                      "-> TritonGPU conversion");
     return {};
@@ -66,6 +72,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   // convert origValue to newValue
   addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                ValueRange inputs, Location loc) -> Value {
+    // Allows partial TTIR to TTGIR conversion by materializing a conversion
+    // for remaining uses of values that have been converted to a new type.
+    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+    // 'convert-triton-to-tritongpu'.
+    return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+                                                        inputs);
     llvm_unreachable("Source rematerialization should not happen in Triton -> "
                      "TritonGPU Conversion");
     return {};
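Note: both hooks follow the same MLIR type-conversion pattern; a condensed sketch of the shape they share (the `typeConverter` variable is an assumption for illustration):

    // When a partial conversion leaves a value whose type was changed, bridge
    // it with an explicit layout conversion instead of aborting.
    typeConverter.addSourceMaterialization(
        [](mlir::OpBuilder &builder, mlir::RankedTensorType tensorType,
           mlir::ValueRange inputs, mlir::Location loc) -> mlir::Value {
          return builder.create<mlir::triton::gpu::ConvertLayoutOp>(
              loc, tensorType, inputs);
        });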

lib/Dialect/TritonGPU/IR/Dialect.cpp: +9 −4

@@ -3127,6 +3127,11 @@ struct CanonicalizeConvertFromAlloc
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
     if (!convert)
       return failure();
+    // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+    // to SharedEncoding, so we want to keep this layout conversion.
+    if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
+            convert.getSrc().getType().getEncoding()))
+      return failure();
     rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
         op, op->getResult(0).getType(), convert.getSrc());
     return mlir::success();
@@ -3189,13 +3194,13 @@ struct CanonicalizeConvertFromConvert
   // heuristic to accommodate fused attention.
   auto srcType = op.getSrc().getType();
   auto dstType = op.getType();
-  if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
-      mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
+  if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
+      mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
     return failure();
 
   // for hopper MMAv3
-  if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
-      mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
+  if (mlir::isa_and_nonnull<SharedEncodingAttr>(dstType.getEncoding()) &&
+      mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
       llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
         return dot->hasTrait<OpTrait::DotLike>();
       })) {
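Note: the switch to `isa_and_nonnull` matters because a tensor's encoding attribute may be absent; a minimal contrast, with `ty` assumed to be a `RankedTensorType`:

    // isa<T>() asserts when handed a null attribute, while isa_and_nonnull<T>()
    // simply returns false, which is the desired behavior for tensors that
    // carry no encoding at all.
    mlir::Attribute enc = ty.getEncoding(); // may be null
    bool isDotOperand = mlir::isa_and_nonnull<DotOperandEncodingAttr>(enc);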

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp: +43 −7

@@ -20,8 +20,6 @@ namespace mlir {
 namespace triton {
 namespace gpu {
 
-namespace {
-
 // Get the highest version supported for the hardware and the dot.
 static int getMMAVersionSafe(int computeCapability, DotOp op) {
   // List supported mma version in order of preference.
@@ -44,8 +42,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
   return 0;
 }
 
-SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
-                                     int numWarps) {
+SmallVector<unsigned>
+warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
@@ -109,10 +107,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
 }
 
 SmallVector<unsigned, 2>
-warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
+warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
                const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
-  mlir::getForwardSlice(dotOp.getResult(), &slices);
+  mlir::getForwardSlice(dotOp->getResult(0), &slices);
   // Contains a chained dot. We prefer to assign warps to one axis
   // to facilitate use cases like flash attention, allowing reductions within
   // the same warp.
@@ -167,11 +165,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding,
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
 SmallVector<unsigned, 3>
-getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
+getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
                 int numWarps, const SmallVector<unsigned, 3> &instrShape) {
   switch (version) {
   case 2:
@@ -184,18 +197,32 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
   }
 }
 
+// Move the anonymous namespace down, so getWarpsPerTile is visible to the
+// sparsity extension.
+namespace {
+
 class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   int computeCapability;
   mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
 
   static bool bwdFilter(Operation *op) {
+    // Dot operand layout assignment to predicates is not currently supported
+    // during lowering from TritonGPU to LLVM in Triton for MMA cases. This
+    // condition limits visibility of the original bit-width so that
+    // predicates are not considered, and hence kWidth can never be 32.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        return false;
+    }
     return op->getNumOperands() == 1 &&
            (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
             isPureUnaryInlineAsm(op) ||
             op->getDialect()->getTypeID() ==
                 mlir::TypeID::get<arith::ArithDialect>());
   }
 
+public:
   // Finds the first different bitwidth in the chain of shape-preserving
   // unary ops that x depends on.
   // There are two primary scenarios:
@@ -806,6 +833,15 @@ class TritonGPUAccelerateMatmulPass
   }
 };
 
+// Expose helper functions from BlockedToMMA to be reused for sparse matmul.
+int computeOrigBitWidth(Value x) {
+  return BlockedToMMA::computeOrigBitWidth(x);
+}
+Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+                             int opIdx, bool allowTranspose) {
+  return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
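Note: a sketch of how a downstream sparsity pass might call the newly exported helpers; the function and heuristic here are assumptions for illustration, not part of this commit:

    // Hypothetical caller in a separate sparse-dot pass.
    mlir::Value prepareSparseOperand(mlir::Value v,
                                     mlir::PatternRewriter &rewriter) {
      int bitWidth = mlir::triton::gpu::computeOrigBitWidth(v);
      // Illustrative heuristic: only allow a transposed shared-memory layout
      // for 16-bit-or-wider element types.
      bool allowTranspose = bitWidth >= 16;
      return mlir::triton::gpu::getSharedMemMMAOperand(
          v, rewriter, /*opIdx=*/0, allowTranspose);
    }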

lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp: +22 −4

@@ -17,9 +17,11 @@ namespace {
 // dot(a, b, inputPrecision="tf32x3") ->
 //   let aBig = f32ToTF32(a), aSmall = a - aBig;
 //   let bBig = f32ToTF32(b), bSmall = b - bBig;
-//   dot(aSmall, bBig, inputPrecision="tf32") +
-//   dot(aBig, bSmall, inputPrecision="tf32") +
-//   dot(aBig, bBig, inputPrecision="tf32")
+//   let small = dot(aSmall, bBig, inputPrecision="tf32") +
+//               dot(aBig, bSmall, inputPrecision="tf32")
+//   let masked_nans = replaceNansWithZeros(small)
+//   let big = dot(aBig, bBig, inputPrecision="tf32")
+//   return big + masked_nans;
 class TF32x3 : public OpRewritePattern<DotOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -62,6 +64,13 @@ class TF32x3 : public OpRewritePattern<DotOp> {
                                    InputPrecision::TF32,
                                    dotOp.getMaxNumImpreciseAcc());
     };
+    auto replaceNansWithZeros = [&](Value value) -> Value {
+      auto nans = rewriter.create<arith::CmpFOp>(
+          dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value);
+      auto zero = zeroLike(value);
+      return rewriter.create<arith::SelectOp>(dotOp->getLoc(), nans, zero,
+                                              value);
+    };
 
     auto aBig = f32ToTF32(dotOp.getA());
     auto aSmall = sub(dotOp.getA(), aBig);
@@ -73,7 +82,16 @@ class TF32x3 : public OpRewritePattern<DotOp> {
 
     auto dot1 = dot(aSmall, bBig, zero);
     auto dot2 = dot(aBig, bSmall, dot1);
-    auto dot3 = dot(aBig, bBig, dot2);
+
+    // If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0.
+    // If rhs is +infinity, we will have:
+    //   +infinity * 1.0 = +infinity
+    //   +infinity * 0.0 = NaN
+    // We would get the wrong result if we sum these partial products. Instead,
+    // we must override any accumulated result if the last partial product is
+    // non-finite.
+    auto dot2withZeroedNans = replaceNansWithZeros(dot2);
+    auto dot3 = dot(aBig, bBig, dot2withZeroedNans);
 
     auto sum = add(dot3, dotOp.getC());

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp: +8 −1

@@ -111,6 +111,7 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
 
   Value zero = builder.createWithStage<arith::ConstantIntOp>(
       forOp.getLoc(), stage, clusterId, 0, 32);
+
   // Replace the load with insert/extract slice.
   builder.setInsertionPoint(loadOp);
   Location loc = loadOp.getLoc();
@@ -468,7 +469,8 @@ assignMemoryLayouts(scf::ForOp &forOp,
       }
     });
 
-    loadsToPipeline.insert(&op);
+    // TODO: b/381421713 - Uncomment this once pipelining is fixed.
+    // loadsToPipeline.insert(&op);
     LoadInfo loadInfo;
     for (auto use : users) {
       if (use->hasTrait<OpTrait::DotLike>()) {
@@ -508,6 +510,11 @@ assignMemoryLayouts(scf::ForOp &forOp,
             getBlockedEncoding(loadOp, axisInfoAnalysis);
       }
     }
+
+    // TODO: b/381421713 - Remove this once pipelining is fixed.
+    if (!loadInfo.sharedEncoding) continue;
+    loadsToPipeline.insert(&op);
+
     loadToInfo[&op] = loadInfo;
   }
   // Make sure all loads in loadsToPipeline are in loadToInfo.
