Skip to content

Commit 7558fc1

Browse files
committed
OpenXLA-specific changes
1 parent 9732c04 commit 7558fc1

File tree

40 files changed

+3428
-914
lines changed

40 files changed

+3428
-914
lines changed

BUILD

+911
Large diffs are not rendered by default.

lib/Analysis/AxisInfo.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -937,7 +937,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
937937
// Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
938938
lhsDivisibility = 1;
939939
}
940-
return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
940+
return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
941941
}
942942

943943
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
@@ -1011,6 +1011,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
10111011
CastOpAxisInfoVisitor<arith::ExtUIOp>,
10121012
CastOpAxisInfoVisitor<arith::TruncIOp>,
10131013
CastOpAxisInfoVisitor<arith::IndexCastOp>,
1014+
CastOpAxisInfoVisitor<arith::IndexCastUIOp>,
10141015
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
10151016
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
10161017
CastOpAxisInfoVisitor<triton::BitcastOp>>();

lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
5757
addArgumentMaterialization([&](OpBuilder &builder,
5858
RankedTensorType tensorType, ValueRange inputs,
5959
Location loc) -> Value {
60+
// Allows partial TTIR to TTGIR conversion by materializing a conversion for
61+
// remaining arguments that have been converted to a new type.
62+
// We use this to rewrite triton_xla.sparse_dot in a separate pass after
63+
// 'convert-triton-to-tritongpu'.
64+
return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
65+
inputs);
6066
llvm_unreachable("Argument rematerialization should not happen in Triton "
6167
"-> TritonGPU conversion");
6268
return {};
@@ -66,6 +72,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
6672
// convert origValue to newValue
6773
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
6874
ValueRange inputs, Location loc) -> Value {
75+
// Allows partial TTIR to TTGIR conversion by materializing a conversion for
76+
// remaining uses of values that have been converted to a new type.
77+
// We use this to rewrite triton_xla.sparse_dot in a separate pass after
78+
// 'convert-triton-to-tritongpu'.
79+
return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
80+
inputs);
6981
llvm_unreachable("Source rematerialization should not happen in Triton -> "
7082
"TritonGPU Conversion");
7183
return {};

lib/Dialect/Triton/Transforms/Combine.td

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def CombineDotAddIPattern : Pat<
1717
[(Constraint<CPred<"isZero($0)">> $c),
1818
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
1919
def CombineDotAddFPattern : Pat<
20-
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath),
20+
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath, $denorm),
2121
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
2222
[(Constraint<CPred<"isZero($0)">> $c),
2323
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
@@ -29,7 +29,7 @@ def CombineDotAddIRevPattern : Pat<
2929
[(Constraint<CPred<"isZero($0)">> $c),
3030
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
3131
def CombineDotAddFRevPattern : Pat<
32-
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath),
32+
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath, $denorm),
3333
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
3434
[(Constraint<CPred<"isZero($0)">> $c),
3535
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),

lib/Dialect/TritonGPU/IR/Dialect.cpp

+9-4
Original file line numberDiff line numberDiff line change
@@ -2667,6 +2667,11 @@ struct CanonicalizeConvertFromAlloc
26672667
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
26682668
if (!convert)
26692669
return failure();
2670+
// LocalAllocOp lowering doesn't support going from DotOperandEncoding
2671+
// to SharedEncoding, so we want to keep this layout conversion.
2672+
if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
2673+
convert.getSrc().getType().getEncoding()))
2674+
return failure();
26702675
rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
26712676
op, op->getResult(0).getType(), convert.getSrc());
26722677
return mlir::success();
@@ -2729,13 +2734,13 @@ struct CanonicalizeConvertFromConvert
27292734
// heuristic to accommodate fused attention.
27302735
auto srcType = op.getSrc().getType();
27312736
auto dstType = op.getType();
2732-
if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
2733-
mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
2737+
if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
2738+
mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
27342739
return failure();
27352740

27362741
// for hopper MMAv3
2737-
if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
2738-
mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
2742+
if (mlir::isa_and_nonnull<SharedEncodingAttr>(dstType.getEncoding()) &&
2743+
mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
27392744
llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
27402745
return dot->hasTrait<OpTrait::DotLike>();
27412746
})) {

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

+52-12
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ namespace mlir {
1919
namespace triton {
2020
namespace gpu {
2121

22-
namespace {
23-
2422
// Get the highest version supported for the hardware and the dot.
2523
static int getMMAVersionSafe(int computeCapability, DotOp op) {
2624
// List supported mma version in order of preference.
@@ -43,8 +41,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
4341
return 0;
4442
}
4543

46-
SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
47-
int numWarps) {
44+
SmallVector<unsigned>
45+
warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
4846
auto rank = shape.size();
4947
// Early exit for batched matmul
5048
if (rank == 3)
@@ -57,9 +55,8 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
5755
auto slices = multiRootGetSlice(dotOp, {filter}, {filter});
5856
bool hasChainedDot = false;
5957
for (Operation *op : slices) {
60-
if (isa<DotOp>(op) && (op != dotOp)) {
61-
auto chainedDot = cast<DotOp>(op);
62-
auto resTy = chainedDot.getResult().getType();
58+
if (dotOp->getName() == op->getName() && op != dotOp) {
59+
auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
6360
if (resTy.getRank() != rank) {
6461
continue;
6562
}
@@ -108,12 +105,17 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
108105
}
109106

110107
SmallVector<unsigned, 2>
111-
warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
108+
warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
112109
const SmallVector<unsigned, 3> &instrShape) {
113110
SetVector<Operation *> slices;
114-
mlir::getForwardSlice(dotOp.getResult(), &slices);
115-
if (llvm::find_if(slices, [](Operation *op) { return isa<DotOp>(op); }) !=
116-
slices.end())
111+
mlir::getForwardSlice(dotOp->getResult(0), &slices);
112+
if (llvm::find_if(slices, [&](Operation *op) {
113+
return dotOp->getName() == op->getName() ||
114+
// Contains a chained dot. We prefer to assign warps to one axis
115+
// to facilitate use cases like flash attention, allowing reductions
116+
// within the same warp.
117+
op->hasTrait<OpTrait::DotLike>();
118+
}) != slices.end())
117119
return {(unsigned)numWarps, 1};
118120

119121
// For MMAv3, the smallest indivisible unit of warp shape is (4, 1).
@@ -162,11 +164,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
162164
auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
163165
newLayout, SharedMemorySpace);
164166
rewriter.setInsertionPointAfterValue(arg);
167+
168+
// LocalAllocOp lowering doesn't support going from DotOperandEncoding
169+
// to SharedEncoding.
170+
if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
171+
argType.getEncoding())) {
172+
// Create a layout conversion from DotOperandEncoding to BlockedEncoding
173+
// then pass it to the LocalAllocOp.
174+
auto newArgType = RankedTensorType::get(
175+
argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
176+
auto dotOperandToBlockedCvt =
177+
rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
178+
return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
179+
dotOperandToBlockedCvt);
180+
}
181+
165182
return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
166183
}
167184

168185
SmallVector<unsigned, 3>
169-
getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
186+
getWarpsPerTile(Operation* dotOp, const ArrayRef<int64_t> shape, int version,
170187
int numWarps, const SmallVector<unsigned, 3> &instrShape) {
171188
switch (version) {
172189
case 2:
@@ -179,18 +196,32 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
179196
}
180197
}
181198

199+
// Move anonymous namespace down, so getWarpsPerTile is visible to the sparsity
200+
// extension.
201+
namespace {
202+
182203
class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
183204
int computeCapability;
184205
mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
185206

186207
static bool bwdFilter(Operation *op) {
208+
// Dot operand layout assignment to Predicates are not currently supported
209+
// during lowering from TritonGPU to LLVM in Triton for MMA cases. This
210+
// condition limits visibility of the original bit-width so that predicates
211+
// are not considered, hence, kwidth can never be = 32.
212+
if (isa<arith::UIToFPOp>(op)) {
213+
Type srcType = getElementTypeOrSelf(op->getOperand(0));
214+
if (srcType.isInteger(1))
215+
return false;
216+
}
187217
return op->getNumOperands() == 1 &&
188218
(isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
189219
isPureUnaryInlineAsm(op) ||
190220
op->getDialect()->getTypeID() ==
191221
mlir::TypeID::get<arith::ArithDialect>());
192222
}
193223

224+
public:
194225
// Finds the first different bitwidth in the chain of shape-preserving
195226
// unary ops that x depends on.
196227
// There are two primary scenarios:
@@ -595,6 +626,15 @@ class TritonGPUAccelerateMatmulPass
595626
}
596627
};
597628

629+
// Expose helper functions from BlockedToMMA to be reused for sparse matmul.
630+
int computeOrigBitWidth(Value x) {
631+
return BlockedToMMA::computeOrigBitWidth(x);
632+
}
633+
Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
634+
int opIdx, bool allowTranspose) {
635+
return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
636+
}
637+
598638
} // namespace gpu
599639
} // namespace triton
600640
} // namespace mlir

lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp

+22-4
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@ namespace {
1717
// dot(a, b, inputPrecision="tf32x3") ->
1818
// let aBig = f32ToTF32(a), aSmall = a - aBig;
1919
// let bBig = f32ToTF32(b), bSmall = b - bBig;
20-
// dot(aSmall, bBig, inputPrecision="tf32") +
21-
// dot(aBig, bSmall, inputPrecision="tf32") +
22-
// dot(aBig, bBig, inputPrecision="tf32")
20+
// let small = dot(aSmall, bBig, inputPrecision="tf32") +
21+
// dot(aBig, bSmall, inputPrecision="tf32")
22+
// let masked_nans = replaceNansWithZeros(small)
23+
// let big = dot(aBig, bBig, inputPrecision="tf32")
24+
// return big + masked_nans;
2325
class TF32x3 : public OpRewritePattern<DotOp> {
2426
public:
2527
using OpRewritePattern::OpRewritePattern;
@@ -62,6 +64,13 @@ class TF32x3 : public OpRewritePattern<DotOp> {
6264
InputPrecision::TF32,
6365
dotOp.getMaxNumImpreciseAcc());
6466
};
67+
auto replaceNansWithZeros = [&](Value value) -> Value {
68+
auto nans = rewriter.create<arith::CmpFOp>(
69+
dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value);
70+
auto zero = zeroLike(value);
71+
return rewriter.create<arith::SelectOp>(dotOp->getLoc(), nans, zero,
72+
value);
73+
};
6574

6675
auto aBig = f32ToTF32(dotOp.getA());
6776
auto aSmall = sub(dotOp.getA(), aBig);
@@ -73,7 +82,16 @@ class TF32x3 : public OpRewritePattern<DotOp> {
7382

7483
auto dot1 = dot(aSmall, bBig, zero);
7584
auto dot2 = dot(aBig, bSmall, dot1);
76-
auto dot3 = dot(aBig, bBig, dot2);
85+
86+
// If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0.
87+
// If rhs is +infinity, we will have:
88+
// +infinity * 1.0 = +infinity
89+
// +infinity * 0.0 = NaN
90+
// We would get the wrong result if we sum these partial products. Instead,
91+
// we must override any accumulated result if the last partial product is
92+
// non-finite.
93+
auto dot2withZeroedNans = replaceNansWithZeros(dot2);
94+
auto dot3 = dot(aBig, bBig, dot2withZeroedNans);
7795

7896
auto sum = add(dot3, dotOp.getC());
7997

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
111111
PatternRewriter &rewriter) const override {
112112
// Only consider conversions to dot operand.
113113
auto cvtTy = cast<RankedTensorType>(cvt.getType());
114-
if (!isa<DotOperandEncodingAttr>(cvtTy.getEncoding()))
114+
auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
115+
if (!dotOpEnc)
115116
return failure();
116117

117118
auto src = cvt.getSrc().getDefiningOp();
@@ -126,6 +127,12 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
126127
[](Type ty) { return isa<RankedTensorType>(ty); }))
127128
return failure();
128129

130+
// Quick handling to fix loading issues when computing the original
131+
// bitwidth is unable to realize that there is a mixed-precision dot
132+
// (hence kWidth = 1) but wants to hoist through the type conversion.
133+
if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1)
134+
return failure();
135+
129136
// Only consider custom conversions or arith ops.
130137
// TODO(jlebar): Is this too restrictive?
131138
if (!isa<FpToFpOp, BitcastOp>(src) && !isPureUnaryInlineAsm(src) &&
@@ -138,6 +145,14 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
138145
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
139146
return failure();
140147

148+
// Don't hoist through u1 -> fp casts as they aren't supported in
149+
// ElementwiseOpToLLVM::reorderValues().
150+
if (isa<arith::UIToFPOp>(src)) {
151+
Type srcType = getElementTypeOrSelf(src->getOperand(0));
152+
if (srcType.isInteger(1))
153+
return failure();
154+
}
155+
141156
// Check that the conversion is transitively dependent on a load, and all
142157
// operations between the load and the conversion are layout preserving.
143158
//

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,8 @@ assignMemoryLayouts(scf::ForOp &forOp,
464464
}
465465
});
466466

467-
loadsToPipeline.insert(&op);
467+
// TODO: b/381421713 - Uncomment this once pipelining is fixed.
468+
// loadsToPipeline.insert(&op);
468469
LoadInfo loadInfo;
469470
for (auto use : users) {
470471
if (use->hasTrait<OpTrait::DotLike>()) {
@@ -497,6 +498,11 @@ assignMemoryLayouts(scf::ForOp &forOp,
497498
getBlockedEncoding(loadOp, axisInfoAnalysis);
498499
}
499500
}
501+
502+
// TODO: b/381421713 - Remove this once pipelining is fixed.
503+
if (!loadInfo.sharedEncoding) continue;
504+
loadsToPipeline.insert(&op);
505+
500506
loadToInfo[&op] = loadInfo;
501507
}
502508
// Make sure all loads in loadsToPipeline are in loadToInfo.

lib/Dialect/TritonGPU/Transforms/Prefetch.cpp

+24-2
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
116116
// opIdx: 0 => a, 1 => b
117117
auto type = cast<triton::MemDescType>(v.getType());
118118
SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
119-
SmallVector<int64_t> offset{0, 0};
119+
SmallVector<int64_t> offset(shape.size(), 0);
120120
Type elementType = type.getElementType();
121121

122122
// k => (prefetchWidth, k - prefetchWidth)
@@ -140,8 +140,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
140140
type.getMemorySpace()),
141141
v, offsetsVal);
142142

143+
// We need to assign kwidth to zero in the case where the parent layout is
144+
// Blocked, otherwise the verifier emits a failure. The parent layout is
145+
// Blocked only when Tensor Cores are disabled.
146+
int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
147+
? 0
148+
: prefetchWidth / 8;
143149
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
144-
builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
150+
builder.getContext(), opIdx, dotEncoding, kwidth);
145151
Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
146152
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
147153
newSmem);
@@ -190,6 +196,22 @@ LogicalResult Prefetcher::initialize() {
190196
break;
191197
if (!op->getResult(0).hasOneUse())
192198
break;
199+
// Similar to issues faced in HoistLayoutConversion pattern in
200+
// OptimizeDotOperands.cpp, we can't propagate through type casts from
201+
// predicates as they aren't supported in Triton when encoded with dot_op
202+
// layout.
203+
if (isa<arith::UIToFPOp>(op)) {
204+
Type srcType = getElementTypeOrSelf(op->getOperand(0));
205+
if (srcType.isInteger(1))
206+
break;
207+
}
208+
// Propagation through ExpandDims is currently not supported. This blindly
209+
// replaces the encoding with dot encoding, but ExpandDims requires a
210+
// SliceEncoding. This could be rewritten to support it somehow, but I
211+
// don't think it's trivial & it's currently crashing.
212+
if (isa<ExpandDimsOp>(op)) {
213+
break;
214+
}
193215
rets.push_back(op->getOperand(0));
194216
if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
195217
foundConvertFromShared = true;

0 commit comments

Comments
 (0)