Skip to content

Commit e14db45

Browse files
committed
OpenXLA-specific changes
1 parent 593cd04 commit e14db45

File tree

38 files changed

+3508
-919
lines changed

38 files changed

+3508
-919
lines changed

BUILD

+931
Large diffs are not rendered by default.

lib/Analysis/AxisInfo.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -937,7 +937,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
937937
// Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
938938
lhsDivisibility = 1;
939939
}
940-
return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
940+
return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
941941
}
942942

943943
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
@@ -1011,6 +1011,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
10111011
CastOpAxisInfoVisitor<arith::ExtUIOp>,
10121012
CastOpAxisInfoVisitor<arith::TruncIOp>,
10131013
CastOpAxisInfoVisitor<arith::IndexCastOp>,
1014+
CastOpAxisInfoVisitor<arith::IndexCastUIOp>,
10141015
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
10151016
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
10161017
CastOpAxisInfoVisitor<triton::BitcastOp>>();

lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
5757
addArgumentMaterialization([&](OpBuilder &builder,
5858
RankedTensorType tensorType, ValueRange inputs,
5959
Location loc) -> Value {
60+
// Allows partial TTIR to TTGIR conversion by materializing a conversion for
61+
// remaining arguments that have been converted to a new type.
62+
// We use this to rewrite triton_xla.sparse_dot in a separate pass after
63+
// 'convert-triton-to-tritongpu'.
64+
return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
65+
inputs);
6066
llvm_unreachable("Argument rematerialization should not happen in Triton "
6167
"-> TritonGPU conversion");
6268
return {};
@@ -66,6 +72,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
6672
// convert origValue to newValue
6773
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
6874
ValueRange inputs, Location loc) -> Value {
75+
// Allows partial TTIR to TTGIR conversion by materializing a conversion for
76+
// remaining uses of values that have been converted to a new type.
77+
// We use this to rewrite triton_xla.sparse_dot in a separate pass after
78+
// 'convert-triton-to-tritongpu'.
79+
return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
80+
inputs);
6981
llvm_unreachable("Source rematerialization should not happen in Triton -> "
7082
"TritonGPU Conversion");
7183
return {};

lib/Dialect/TritonGPU/IR/Dialect.cpp

+9-4
Original file line numberDiff line numberDiff line change
@@ -3065,6 +3065,11 @@ struct CanonicalizeConvertFromAlloc
30653065
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
30663066
if (!convert)
30673067
return failure();
3068+
// LocalAllocOp lowering doesn't support going from DotOperandEncoding
3069+
// to SharedEncoding, so we want to keep this layout conversion.
3070+
if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
3071+
convert.getSrc().getType().getEncoding()))
3072+
return failure();
30683073
rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
30693074
op, op->getResult(0).getType(), convert.getSrc());
30703075
return mlir::success();
@@ -3127,13 +3132,13 @@ struct CanonicalizeConvertFromConvert
31273132
// heuristic to accommodate fused attention.
31283133
auto srcType = op.getSrc().getType();
31293134
auto dstType = op.getType();
3130-
if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
3131-
mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
3135+
if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
3136+
mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
31323137
return failure();
31333138

31343139
// for hopper MMAv3
3135-
if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
3136-
mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
3140+
if (mlir::isa_and_nonnull<SharedEncodingAttr>(dstType.getEncoding()) &&
3141+
mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
31373142
llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
31383143
return dot->hasTrait<OpTrait::DotLike>();
31393144
})) {

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

+43-7
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ namespace mlir {
2020
namespace triton {
2121
namespace gpu {
2222

23-
namespace {
24-
2523
// Get the highest version supported for the hardware and the dot.
2624
static int getMMAVersionSafe(int computeCapability, DotOp op) {
2725
// List supported mma version in order of preference.
@@ -44,8 +42,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
4442
return 0;
4543
}
4644

47-
SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
48-
int numWarps) {
45+
SmallVector<unsigned>
46+
warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
4947
auto rank = shape.size();
5048
// Early exit for batched matmul
5149
if (rank == 3)
@@ -109,10 +107,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
109107
}
110108

111109
SmallVector<unsigned, 2>
112-
warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
110+
warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
113111
const SmallVector<unsigned, 3> &instrShape) {
114112
SetVector<Operation *> slices;
115-
mlir::getForwardSlice(dotOp.getResult(), &slices);
113+
mlir::getForwardSlice(dotOp->getResult(0), &slices);
116114
// Contains a chained dot. We prefer to assign warps to one axis
117115
// to facilitate use cases like flash attention, allowing reductions within
118116
// the same warp.
@@ -167,11 +165,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
167165
auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
168166
newLayout, SharedMemorySpace);
169167
rewriter.setInsertionPointAfterValue(arg);
168+
169+
// LocalAllocOp lowering doesn't support going from DotOperandEncoding
170+
// to SharedEncoding.
171+
if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
172+
argType.getEncoding())) {
173+
// Create a layout conversion from DotOperandEncoding to BlockedEncoding
174+
// then pass it to the LocalAllocOp.
175+
auto newArgType = RankedTensorType::get(
176+
argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
177+
auto dotOperandToBlockedCvt =
178+
rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
179+
return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
180+
dotOperandToBlockedCvt);
181+
}
182+
170183
return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
171184
}
172185

173186
SmallVector<unsigned, 3>
174-
getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
187+
getWarpsPerTile(Operation* dotOp, const ArrayRef<int64_t> shape, int version,
175188
int numWarps, const SmallVector<unsigned, 3> &instrShape) {
176189
switch (version) {
177190
case 2:
@@ -184,18 +197,32 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
184197
}
185198
}
186199

200+
// Move anonymous namespace down, so getWarpsPerTile is visible to the sparsity
201+
// extension.
202+
namespace {
203+
187204
class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
188205
int computeCapability;
189206
mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
190207

191208
static bool bwdFilter(Operation *op) {
209+
// Dot operand layout assignment to Predicates are not currently supported
210+
// during lowering from TritonGPU to LLVM in Triton for MMA cases. This
211+
// condition limits visibility of the original bit-width so that predicate
212+
// are not considered, hence, kwidth can never be = 32.
213+
if (isa<arith::UIToFPOp>(op)) {
214+
Type srcType = getElementTypeOrSelf(op->getOperand(0));
215+
if (srcType.isInteger(1))
216+
return false;
217+
}
192218
return op->getNumOperands() == 1 &&
193219
(isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
194220
isPureUnaryInlineAsm(op) ||
195221
op->getDialect()->getTypeID() ==
196222
mlir::TypeID::get<arith::ArithDialect>());
197223
}
198224

225+
public:
199226
// Finds the first different bitwidth in the chain of shape-preserving
200227
// unary ops that x depends on.
201228
// There are two primary scenarios:
@@ -805,6 +832,15 @@ class TritonGPUAccelerateMatmulPass
805832
}
806833
};
807834

835+
// Expose helper functions from BlockedToMMA to be reused for sparse matmul.
836+
int computeOrigBitWidth(Value x) {
837+
return BlockedToMMA::computeOrigBitWidth(x);
838+
}
839+
Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
840+
int opIdx, bool allowTranspose) {
841+
return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
842+
}
843+
808844
} // namespace gpu
809845
} // namespace triton
810846
} // namespace mlir

lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp

+22-4
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@ namespace {
1717
// dot(a, b, inputPrecision="tf32x3") ->
1818
// let aBig = f32ToTF32(a), aSmall = a - aBig;
1919
// let bBig = f32ToTF32(b), bSmall = b - bBig;
20-
// dot(aSmall, bBig, inputPrecision="tf32") +
21-
// dot(aBig, bSmall, inputPrecision="tf32") +
22-
// dot(aBig, bBig, inputPrecision="tf32")
20+
// let small = dot(aSmall, bBig, inputPrecision="tf32") +
21+
// dot(aBig, bSmall, inputPrecision="tf32")
22+
// let masked_nans = replaceNansWithZeros(small)
23+
// let big = dot(aBig, bBig, inputPrecision="tf32")
24+
// return big + masked_nans;
2325
class TF32x3 : public OpRewritePattern<DotOp> {
2426
public:
2527
using OpRewritePattern::OpRewritePattern;
@@ -62,6 +64,13 @@ class TF32x3 : public OpRewritePattern<DotOp> {
6264
InputPrecision::TF32,
6365
dotOp.getMaxNumImpreciseAcc());
6466
};
67+
auto replaceNansWithZeros = [&](Value value) -> Value {
68+
auto nans = rewriter.create<arith::CmpFOp>(
69+
dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value);
70+
auto zero = zeroLike(value);
71+
return rewriter.create<arith::SelectOp>(dotOp->getLoc(), nans, zero,
72+
value);
73+
};
6574

6675
auto aBig = f32ToTF32(dotOp.getA());
6776
auto aSmall = sub(dotOp.getA(), aBig);
@@ -73,7 +82,16 @@ class TF32x3 : public OpRewritePattern<DotOp> {
7382

7483
auto dot1 = dot(aSmall, bBig, zero);
7584
auto dot2 = dot(aBig, bSmall, dot1);
76-
auto dot3 = dot(aBig, bBig, dot2);
85+
86+
// If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0.
87+
// If rhs is +infinity, we will have:
88+
// +infinity * 1.0 = +infinity
89+
// +infinity * 0.0 = NaN
90+
// We would get the wrong result if we sum these partial products. Instead,
91+
// we must override any accumulated result if the last partial product is
92+
// non-finite.
93+
auto dot2withZeroedNans = replaceNansWithZeros(dot2);
94+
auto dot3 = dot(aBig, bBig, dot2withZeroedNans);
7795

7896
auto sum = add(dot3, dotOp.getC());
7997

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

+8-1
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
111111

112112
Value zero = builder.createWithStage<arith::ConstantIntOp>(
113113
forOp.getLoc(), stage, clusterId, 0, 32);
114+
114115
// Replace the load with insert/extract slice.
115116
builder.setInsertionPoint(loadOp);
116117
Location loc = loadOp.getLoc();
@@ -467,7 +468,8 @@ assignMemoryLayouts(scf::ForOp &forOp,
467468
}
468469
});
469470

470-
loadsToPipeline.insert(&op);
471+
// TODO: b/381421713 - Uncomment this once pipelining is fixed.
472+
// loadsToPipeline.insert(&op);
471473
LoadInfo loadInfo;
472474
for (auto use : users) {
473475
if (use->hasTrait<OpTrait::DotLike>()) {
@@ -507,6 +509,11 @@ assignMemoryLayouts(scf::ForOp &forOp,
507509
getBlockedEncoding(loadOp, axisInfoAnalysis);
508510
}
509511
}
512+
513+
// TODO: b/381421713 - Remove this once pipelining is fixed.
514+
if (!loadInfo.sharedEncoding) continue;
515+
loadsToPipeline.insert(&op);
516+
510517
loadToInfo[&op] = loadInfo;
511518
}
512519
// Make sure all loads in loadsToPipeline are in loadToInfo.

lib/Dialect/TritonGPU/Transforms/Prefetch.cpp

+24-2
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
116116
// opIdx: 0 => a, 1 => b
117117
auto type = cast<triton::gpu::MemDescType>(v.getType());
118118
SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
119-
SmallVector<int64_t> offset{0, 0};
119+
SmallVector<int64_t> offset(shape.size(), 0);
120120
Type elementType = type.getElementType();
121121

122122
// k => (prefetchWidth, k - prefetchWidth)
@@ -141,8 +141,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
141141
type.getMutableMemory(), type.getAllocShape()),
142142
v, offsetsVal);
143143

144+
// We need to assign kwidth to zero in the case where the parent layout is
145+
// Blocked, otherwise the verifier emits a failure. The parent layout is
146+
// Blocked only when Tensor Cores are disabled.
147+
int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
148+
? 0
149+
: prefetchWidth / 8;
144150
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
145-
builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
151+
builder.getContext(), opIdx, dotEncoding, kwidth);
146152
Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
147153
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
148154
newSmem);
@@ -191,6 +197,22 @@ LogicalResult Prefetcher::initialize() {
191197
break;
192198
if (!op->getResult(0).hasOneUse())
193199
break;
200+
// Similar to issues faced in HoistLayoutConversion pattern in
201+
// OptimizeDotOperands.cpp, we can't propagate through type casts from
202+
// predicates as they aren't supported in Triton when encoded with dot_op
203+
// layout.
204+
if (isa<arith::UIToFPOp>(op)) {
205+
Type srcType = getElementTypeOrSelf(op->getOperand(0));
206+
if (srcType.isInteger(1))
207+
break;
208+
}
209+
// Propagation through ExpandDims is currently not supported. This blindly
210+
// replaces the encoding with dot encoding & but ExpandDims requires a
211+
// SliceEncoding. This could be rewritten to support it somehow, but I
212+
// don't think it's trivial & it's currently crashing.
213+
if (isa<ExpandDimsOp>(op)) {
214+
break;
215+
}
194216
rets.push_back(op->getOperand(0));
195217
if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
196218
foundConvertFromShared = true;

lib/Dialect/TritonGPU/Transforms/Utility.cpp

+19-11
Original file line numberDiff line numberDiff line change
@@ -958,18 +958,26 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
958958
} else {
959959
if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
960960
return std::nullopt;
961-
auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
962-
cast<triton::gpu::TensorOrMemDesc>(user->getResult(0).getType())
963-
.getEncoding());
964-
if (!dotOpEnc)
961+
auto enc =
962+
cast<triton::gpu::TensorOrMemDesc>(user->getResult(0).getType()).getEncoding();
963+
if (isa<ttg::DotOperandEncodingAttr>(enc)) {
964+
auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
965+
auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
966+
auto order = ttg::getOrder(srcTy.getEncoding());
967+
unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
968+
tempAttr = ttg::SharedEncodingAttr::get(
969+
val.getContext(), cast<ttg::DotOperandEncodingAttr>(enc),
970+
srcTy.getShape(), order, CTALayout, bitWidth, /*needTrans=*/false);
971+
} else if (enc.getAbstractAttribute().getName().str() ==
972+
"triton.gpu.sparse_dot_meta_encoding") {
973+
auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
974+
tempAttr = ttg::SharedEncodingAttr::get(
975+
val.getContext(), /*vec=*/1, /*perPhase=*/1, /*maxPhase=*/1,
976+
ttg::getOrder(srcTy.getEncoding()),
977+
ttg::getCTALayout(srcTy.getEncoding()));
978+
} else {
965979
return std::nullopt;
966-
auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
967-
auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
968-
auto order = ttg::getOrder(srcTy.getEncoding());
969-
unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
970-
tempAttr = ttg::SharedEncodingAttr::get(
971-
val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout,
972-
bitWidth, /*needTrans=*/false);
980+
}
973981
}
974982
// Check that the shared encodings needed by the users are compatible.
975983
if (attr != nullptr && attr != tempAttr) {

lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ struct FenceInsertionPass
4444
return;
4545
ModuleOp mod = getOperation();
4646
mod.walk([&](Operation *op) {
47-
if (!isa<ttng::WarpGroupDotOp>(op))
47+
if (!op->hasTrait<OpTrait::DotLike>())
4848
return WalkResult::advance();
4949
OpBuilder builder(op);
5050
auto a = op->getOperand(0);

0 commit comments

Comments
 (0)