
Commit b39a103

OpenXLA-specific changes
1 parent fd02f65 commit b39a103


49 files changed: +2330, -203 lines

BUILD (+904): large diff not rendered.

include/triton/Analysis/AxisInfo.h (+2, -2)

@@ -180,8 +180,8 @@ class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
   for (auto funcOp : llvm::reverse(sortedFuncs)) {
     initialize(funcOp);
     funcOp.walk([&](CallOpInterface callOp) {
-      auto callee =
-          dyn_cast<FunctionOpInterface>(callOp.resolveCallable(&symbolTable));
+      auto callee = dyn_cast<FunctionOpInterface>(
+          callOp.resolveCallableInTable(&symbolTable));
       update(callOp, callee);
     });
   }

include/triton/Analysis/Utility.h (+1, -1)

@@ -316,7 +316,7 @@ template <typename T> class CallGraph {
   moduleOp.walk([&](Operation *op) {
     auto caller = op->getParentOfType<FunctionOpInterface>();
     if (auto callOp = dyn_cast<CallOpInterface>(op)) {
-      auto *callee = callOp.resolveCallable(&symbolTable);
+      auto *callee = callOp.resolveCallableInTable(&symbolTable);
       auto funcOp = dyn_cast_or_null<FunctionOpInterface>(callee);
       if (funcOp) {
         graph[caller].emplace_back(

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td (+4)

@@ -91,6 +91,10 @@ def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods<I
   let results = (outs TT_FpIntTensor:$d);

   let assemblyFormat = "$a`,` $b`,` $c (`,` $useC^)? attr-dict `:` type($a) `*` type($b) `->` type($d)";
+
+  let extraClassDeclaration = [{
+    bool needsPartialAccumulator();
+  }];
 }

 def TTNG_WarpGroupDotWaitOp : TTNG_Op<"warp_group_dot_wait", [DeclareOpInterfaceMethods<InferTypeOpInterface>,

lib/Analysis/Utility.cpp (+5)

@@ -488,6 +488,11 @@ bool supportMMA(triton::DotOp op, int version) {
   if (triton::tools::getBoolEnv("DISABLE_MMA_V3"))
     return false;
   auto retType = op.getType();
+  RankedTensorType typeA = op.getA().getType();
+  int k = typeA.getShape().back();
+  // If k size is smaller than the native mma size, we cannot use MMA.
+  if (k < 256 / aElemTy.getIntOrFloatBitWidth())
+    return false;
   auto retShapePerCTA = getShapePerCTA(retType);
   auto rank = retShapePerCTA.size();
   auto mod = op->getParentOfType<ModuleOp>();
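
Reading the new guard: the A operand's k extent must cover at least 256 bits' worth of elements before the MMA path is used. A small, hypothetical Python sketch of the same arithmetic (the helper names are made up; only the 256-bit constant and the element bit widths come from the check above):

# Hypothetical mirror of the k-size guard added to supportMMA() above.
NATIVE_MMA_K_BITS = 256  # as in: 256 / aElemTy.getIntOrFloatBitWidth()

def min_k_for_mma(elem_bitwidth: int) -> int:
    """Smallest k that the check accepts for a given element bit width."""
    return NATIVE_MMA_K_BITS // elem_bitwidth

def supports_mma_k(k: int, elem_bitwidth: int) -> bool:
    # Mirrors: if (k < 256 / aElemTy.getIntOrFloatBitWidth()) return false;
    return k >= min_k_for_mma(elem_bitwidth)

assert min_k_for_mma(16) == 16    # f16/bf16 A operand: k must be >= 16
assert min_k_for_mma(8) == 32     # fp8/i8 A operand: k must be >= 32
assert not supports_mma_k(8, 16)  # e.g. a k=8 f16 dot falls back to non-MMA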

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp (+5, -5)

@@ -40,7 +40,8 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
   auto ouEltTy = ouTensorTy.getElementType();
   if (inBitWidth == ouBitWidth)
     return values;
-  if (inBitWidth == 16 && ouBitWidth == 32) {
+  if ((inBitWidth == 16 && ouBitWidth == 32) ||
+      (inBitWidth == 32 && ouBitWidth == 16)) {
     SmallVector<Value> ret;
     for (unsigned i = 0; i < values.size(); i += 8) {
       ret.push_back(values[i]);
@@ -610,10 +611,9 @@ struct IndexCastOpLowering
     if (targetBits == sourceBits)
       return {operands[0][0]};
     if (targetBits < sourceBits)
-      return {rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, elemTy,
-                                                         operands[0][0])};
-    return {
-        rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, elemTy, operands[0][0])};
+      return {
+          rewriter.create<LLVM::TruncOp>(op.getLoc(), elemTy, operands[0][0])};
+    return {rewriter.create<LLVM::SExtOp>(op.getLoc(), elemTy, operands[0][0])};
   }
 };

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp (+2, -1)

@@ -87,8 +87,9 @@ struct ArithConstantSplatOpConversion
     // LLVM IR.
     if (type::isFloat8(elemType))
       elemType = rewriter.getIntegerType(8);
-    auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
     auto typeConverter = getTypeConverter();
+    auto constOp = rewriter.create<LLVM::ConstantOp>(
+        loc, typeConverter->convertType(elemType), val);
     auto llStruct = SplatOpConversion::convertSplatLikeOp(
         elemType, op.getType(), constOp, typeConverter, rewriter, loc);
     rewriter.replaceOp(op, llStruct);

lib/Dialect/TritonGPU/IR/Dialect.cpp (+5)

@@ -2717,6 +2717,11 @@ struct CanonicalizeConvertFromAlloc
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
     if (!convert)
       return failure();
+    // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+    // to SharedEncoding, so we want to keep this layout conversion.
+    if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
+            convert.getSrc().getType().getEncoding()))
+      return failure();
     rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
         op, op->getResult(0).getType(), convert.getSrc());
     return mlir::success();

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp (+24)

@@ -153,6 +153,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding,
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }

@@ -162,6 +177,15 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;

   static bool bwdFilter(Operation *op) {
+    // Dot operand layout assignment to predicates is not currently supported
+    // during lowering from TritonGPU to LLVM in Triton for MMA cases. This
+    // condition limits visibility of the original bit-width so that predicates
+    // are not considered; hence, kWidth can never be 32.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        return false;
+    }
     return op->getNumOperands() == 1 &&
            (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
             isPureUnaryInlineAsm(op) ||

lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp (+6, -1)

@@ -14,7 +14,12 @@ namespace gpu {
 namespace {
 bool dotSupportsAccInitFlag(Operation *op) {
   assert(op->hasTrait<OpTrait::DotLike>() && "Expected a dot-like operation");
-  return isa<triton::nvidia_gpu::WarpGroupDotOp>(op);
+  if (auto wgDotOp = dyn_cast<triton::nvidia_gpu::WarpGroupDotOp>(op)) {
+    // Partial accumulation would require a select op to handle the
+    // initialization, which would degrade performance.
+    return !wgDotOp.needsPartialAccumulator();
+  }
+  return false;
 }

 std::pair<Value, Operation *> getAccumulatorUseAndDef(Operation *op) {

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp (+16, -1)

@@ -111,7 +111,8 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
                                 PatternRewriter &rewriter) const override {
     // Only consider conversions to dot operand.
     auto cvtTy = cast<RankedTensorType>(cvt.getType());
-    if (!isa<DotOperandEncodingAttr>(cvtTy.getEncoding()))
+    auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
+    if (!dotOpEnc)
       return failure();

     auto src = cvt.getSrc().getDefiningOp();
@@ -126,6 +127,12 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
                     [](Type ty) { return isa<RankedTensorType>(ty); }))
       return failure();

+    // Quick fix for loading issues: the original-bitwidth computation cannot
+    // see that there is a mixed-precision dot (hence kWidth = 1) and would
+    // otherwise hoist through the type conversion.
+    if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1)
+      return failure();
+
     // Only consider custom conversions or arith ops.
     // TODO(jlebar): Is this too restrictive?
     if (!isa<FpToFpOp, BitcastOp>(src) && !isPureUnaryInlineAsm(src) &&
@@ -138,6 +145,14 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
     if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
       return failure();

+    // Don't hoist through u1 -> fp casts as they aren't supported in
+    // ElementwiseOpToLLVM::reorderValues().
+    if (isa<arith::UIToFPOp>(src)) {
+      Type srcType = getElementTypeOrSelf(src->getOperand(0));
+      if (srcType.isInteger(1))
+        return failure();
+    }
+
     // Check that the conversion is transitively dependent on a load, and all
     // operations between the load and the conversion are layout preserving.
     //
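
For context, the i1 -> float cast that this pattern (and the matching checks in AccelerateMatmul.cpp above and Prefetch.cpp below) refuses to look through typically comes from Triton source where a predicate tile is cast to the dot's element type. The kernel below is an illustrative, hypothetical sketch and not part of this commit; names, shapes, and the exact lowering of the cast are assumptions.

# Hypothetical Triton kernel producing a u1 -> fp cast that feeds a dot operand.
import triton
import triton.language as tl

@triton.jit
def masked_dot_kernel(a_ptr, m_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]
    a = tl.load(a_ptr + idx)          # f16 tile
    pred = tl.load(m_ptr + idx) > 0   # i1 (predicate) tile
    b = pred.to(tl.float16)           # the i1 -> fp cast in question
    acc = tl.dot(a, b)                # cast result feeds a dot operand
    tl.store(out_ptr + idx, acc)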

lib/Dialect/TritonGPU/Transforms/Prefetch.cpp (+16, -1)

@@ -140,8 +140,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
                                    type.getMemorySpace()),
       v, offsetsVal);

+  // We need to assign kwidth to zero in the case where the parent layout is
+  // Blocked, otherwise the verifier emits a failure. The parent layout is
+  // Blocked only when Tensor Cores are disabled.
+  int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
+                   ? 0
+                   : prefetchWidth / 8;
   auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
-      builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
+      builder.getContext(), opIdx, dotEncoding, kwidth);
   Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
       v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
       newSmem);
@@ -187,6 +193,15 @@ LogicalResult Prefetcher::initialize() {
       break;
     if (!op->getResult(0).hasOneUse())
       break;
+    // Similar to issues faced in HoistLayoutConversion pattern in
+    // OptimizeDotOperands.cpp, we can't propagate through type casts from
+    // predicates as they aren't supported in Triton when encoded with dot_op
+    // layout.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        break;
+    }
     rets.push_back(op->getOperand(0));
     if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
       foundConvertFromShared = true;

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp (+12)

@@ -70,6 +70,18 @@ void WarpGroupDotOp::getEffects(
       mlir::triton::gpu::SharedMemory::get());
 }

+bool WarpGroupDotOp::needsPartialAccumulator() {
+  const auto &a = getA();
+  const auto &d = getD();
+  auto aTensorTy = cast<TensorOrMemDesc>(a.getType());
+  auto aElTy = cast<TensorOrMemDesc>(a.getType()).getElementType();
+  bool isFP8 = aElTy.isFloat8E5M2() || aElTy.isFloat8E4M3FN() ||
+               aElTy.isFloat8E5M2FNUZ() || aElTy.isFloat8E4M3FNUZ();
+  bool accFP32 = cast<TensorOrMemDesc>(d.getType()).getElementType().isF32();
+  uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc();
+  return isFP8 && accFP32 && maxNumImpreciseAcc <= aTensorTy.getShape()[1];
+}
+
 // -- WarpGroupDotWaitOp --
 LogicalResult WarpGroupDotWaitOp::inferReturnTypes(
     ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
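
As a reading aid for the predicate above (used by OptimizeAccumulatorInit.cpp earlier in this commit), the condition can be restated in plain Python. This is a hedged mirror of the logic, not an API that exists anywhere; the dtype strings and argument names are made up:

# Hypothetical restatement of WarpGroupDotOp::needsPartialAccumulator().
def needs_partial_accumulator(a_elem_ty: str, acc_elem_ty: str,
                              k: int, max_num_imprecise_acc: int) -> bool:
    """True when an fp8 dot with an fp32 accumulator may accumulate
    imprecisely and therefore needs a partial (split) accumulator."""
    is_fp8 = a_elem_ty in ("fp8e5m2", "fp8e4m3fn", "fp8e5m2fnuz", "fp8e4m3fnuz")
    acc_fp32 = acc_elem_ty == "fp32"
    # aTensorTy.getShape()[1] in the C++ code is the K extent of the A operand.
    return is_fp8 and acc_fp32 and max_num_imprecise_acc <= k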

python/BUILD (new file, +77)

@@ -0,0 +1,77 @@
+# NOTE: Do not depend on any targets from this directory,
+# but use //third_party/py/triton instead.
+
+load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
+
+package(
+    default_applicable_licenses = ["//:license"],
+    default_visibility = [
+        "//third_party/py/triton:__pkg__",
+        "@triton//python:__subpackages__",
+    ],
+)
+
+cc_library(
+    name = "passes",
+    hdrs = ["src/passes.h"],
+    includes = ["src"],
+    visibility = ["@triton//third_party:__subpackages__"],
+)
+
+pybind_extension(
+    name = "libtriton",
+    srcs = [
+        "src/interpreter.cc",
+        "src/ir.cc",
+        "src/llvm.cc",
+        "src/main.cc",
+        "src/passes.cc",
+    ],
+    copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"],
+    deps = [
+        ":passes",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:IPO",
+        "@llvm-project//llvm:IRReader",
+        "@llvm-project//llvm:InstCombine",
+        "@llvm-project//llvm:Linker",
+        "@llvm-project//llvm:MC",
+        "@llvm-project//llvm:Passes",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:Target",
+        "@llvm-project//mlir:BuiltinToLLVMIRTranslation",
+        "@llvm-project//mlir:BytecodeWriter",
+        "@llvm-project//mlir:ControlFlowDialect",
+        "@llvm-project//mlir:ConversionPasses",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:IndexDialect",
+        "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:LLVMIRTransforms",
+        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
+        "@llvm-project//mlir:NVVMToLLVMIRTranslation",
+        "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:ToLLVMIRTranslation",
+        "@llvm-project//mlir:Transforms",
+        "//:TritonAnalysis",
+        "//:TritonDialects",
+        "//:TritonGPUToLLVM",
+        "//:TritonGPUTransforms",
+        "//:TritonHSACO",
+        "//:TritonLLVMIR",
+        "//:TritonNvidiaGPUTransforms",
+        "//:TritonPTX",
+        "//:TritonToTritonGPU",
+        "//:TritonTools",
+        "//:TritonTransforms",
+        "@triton//third_party/nvidia:triton_nvidia",
+    ],
+)
+
+filegroup(
+    name = "files",
+    srcs = glob(
+        include = ["triton/**/*.py"],
+    ),
+)

python/src/llvm.cc (+3, -3)

@@ -1,4 +1,4 @@
-#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp
+#include "mlir/IR/BuiltinOps.h"  // mlir::ModuleOp
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "triton/Tools/Sys/GetEnv.hpp"
@@ -346,8 +346,8 @@ void init_triton_llvm(py::module &&m) {
   // and break the lowering of some target specific intrinsics.
   std::unique_ptr<TargetMachine> targetMachine = nullptr;
   if (!arch.empty() && pluginFile.empty())
-    targetMachine = std::move(
-        createTargetMachine(mod, arch, enable_fp_fusion, features));
+    targetMachine =
+        createTargetMachine(mod, arch, enable_fp_fusion, features);
   PassBuilder pb(/*targetMachine=*/targetMachine.get(), tuningOptions,
                  std::nullopt, instrCbPtr);

python/test/regression/BUILD (new file, +26)

@@ -0,0 +1,26 @@
+load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests")
+
+package(
+    default_applicable_licenses = ["//:license"],
+)
+
+pytest_multi_tests(
+    name = "tests",
+    size = "large",
+    srcs = ["conftest.py"],
+    shard_count = 10,
+    tags = [
+        "config-cuda-only",
+        "requires-gpu-sm80",
+    ],
+    tests = glob(
+        include = ["test_*.py"],
+        exclude = [
+            "test_performance.py",  # TODO(b/321005767): fix failing test
+        ],
+    ),
+    deps = [
+        "//third_party/py/torch:pytorch",
+        "//third_party/py/triton",
+    ],
+)

python/test/regression/conftest.py (new file, +12)

@@ -0,0 +1,12 @@
+# content of conftest.py
+
+import pytest
+
+
+def pytest_addoption(parser):
+    parser.addoption("--device", action="store", default='cuda')
+
+
+@pytest.fixture
+def device(request):
+    return request.config.getoption("--device")
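
The conftest.py above exposes a --device pytest option through a device fixture. A regression test would consume it roughly like the hedged sketch below; the test file name and the kernel are illustrative assumptions, not part of this commit:

# test_example.py -- hypothetical test consuming the `device` fixture above.
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


def test_add(device):
    n = 1024
    x = torch.randn(n, device=device)
    y = torch.randn(n, device=device)
    out = torch.empty_like(x)
    add_kernel[(triton.cdiv(n, 128),)](x, y, out, n, BLOCK=128)
    torch.testing.assert_close(out, x + y)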
