Commit 25aa0c6

[MLIR][TORCH] Add support for enable_gqa flag in SDPA op (#3950)
Signed-off-by: Vivek Khandelwal <[email protected]>
Parent: 7cea07c

File tree: 3 files changed (+153, -5 lines)

lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp (+123, -5)
@@ -372,6 +372,54 @@ static FailureOr<SmallVector<Value>> createTMTensorTopkOp(
  return SmallVector<Value>(topkOp.getResults());
}

+static FailureOr<Value>
+repeatTensorElementsForDim(Operation *op, ConversionPatternRewriter &rewriter,
+                           Type resType, Value self, int64_t repeats,
+                           int64_t dim) {
+  Location loc = op->getLoc();
+  auto context = op->getContext();
+  auto selfTy = cast<BaseTensorType>(self.getType());
+
+  int64_t inputRank = selfTy.getSizes().size();
+  dim = toPositiveDim(dim, inputRank);
+  Value dimValue =
+      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
+  Value dimValuePlusOne =
+      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim + 1));
+
+  auto unsqueezedInfo = unsqueezeTensor(rewriter, op, self, dimValuePlusOne);
+  if (failed(unsqueezedInfo))
+    return rewriter.notifyMatchFailure(op,
+                                       "cannot generate unsqueeze tensor op");
+  self = *unsqueezedInfo;
+
+  Value constMinusOne =
+      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
+  SmallVector<Value> expandShapeValueList(inputRank + 1, constMinusOne);
+  expandShapeValueList[dim + 1] =
+      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(repeats));
+  Value expandShapeList = rewriter.create<PrimListConstructOp>(
+      loc, ListType::get(IntType::get(context)), expandShapeValueList);
+
+  SmallVector<int64_t> expandShape(inputRank + 1);
+  for (int64_t i = 0; i <= dim; i++) {
+    expandShape[i] = selfTy.getSizes()[i];
+  }
+  expandShape[dim + 1] = repeats;
+  for (int64_t i = dim + 1; i < inputRank; i++) {
+    expandShape[i + 1] = selfTy.getSizes()[i];
+  }
+
+  BaseTensorType expandTy =
+      rewriter.getType<ValueTensorType>(expandShape, selfTy.getOptionalDtype());
+  Value expandSelf =
+      rewriter.create<AtenBroadcastToOp>(loc, expandTy, self, expandShapeList);
+
+  Value result = rewriter.create<PrimsCollapseOp>(loc, resType, expandSelf,
+                                                  dimValue, dimValuePlusOne);
+  return result;
+}
+
namespace {
template <typename AtenOpT>
class ConvertAtenScatterOp : public OpConversionPattern<AtenOpT> {
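
For intuition: the new repeatTensorElementsForDim helper builds the same result as torch.repeat_interleave along dim, via an unsqueeze -> broadcast -> collapse chain. A minimal PyTorch-level sketch of that chain (illustrative only; the function name and shapes here are assumptions, not part of the commit):

import torch

def repeat_elements_for_dim(self: torch.Tensor, repeats: int, dim: int) -> torch.Tensor:
    # Mirrors the lowering: toPositiveDim, unsqueezeTensor, AtenBroadcastToOp,
    # then PrimsCollapseOp over [dim, dim + 1].
    dim = dim % self.dim()
    x = self.unsqueeze(dim + 1)
    shape = list(self.shape)
    shape.insert(dim + 1, repeats)
    x = x.expand(*shape)
    return x.flatten(dim, dim + 1)

x = torch.arange(6).reshape(2, 3)
assert torch.equal(repeat_elements_for_dim(x, 2, 0),
                   torch.repeat_interleave(x, 2, dim=0))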
@@ -1651,6 +1699,65 @@ class ConvertAtenScaledDotProductAttentionOp
    : public OpConversionPattern<AtenScaledDotProductAttentionOp> {
public:
  using OpConversionPattern::OpConversionPattern;
+
+  static LogicalResult
+  preProcessGroupQueryAttentionInput(AtenScaledDotProductAttentionOp op,
+                                     ConversionPatternRewriter &rewriter,
+                                     const TypeConverter *typeConverter,
+                                     Value query, Value &key, Value &value) {
+    auto queryTy = cast<ShapedType>(query.getType());
+    auto valueTy = cast<ShapedType>(value.getType());
+    auto keyTy = cast<ShapedType>(key.getType());
+
+    int64_t rank = queryTy.getRank();
+
+    int64_t qNumHeads = queryTy.getDimSize(rank - 3);
+    int64_t kNumHeads = keyTy.getDimSize(rank - 3);
+    int64_t vNumHeads = valueTy.getDimSize(rank - 3);
+
+    if (llvm::any_of(llvm::ArrayRef<int64_t>{qNumHeads, kNumHeads, vNumHeads},
+                     [](int64_t d) { return d == Torch::kUnknownSize; })) {
+      return llvm::failure();
+    }
+
+    if (llvm::all_equal(
+            llvm::ArrayRef<int64_t>{qNumHeads, kNumHeads, vNumHeads}))
+      return llvm::success();
+
+    if ((qNumHeads % kNumHeads) || (qNumHeads % vNumHeads))
+      return llvm::failure();
+
+    int64_t repeatKeyShape = qNumHeads / kNumHeads;
+    int64_t repeatValueShape = qNumHeads / vNumHeads;
+
+    Location loc = op.getLoc();
+    FailureOr<Value> keyRepeated = repeatTensorElementsForDim(
+        op.getOperation(), rewriter, /*resType=*/op.getQuery().getType(),
+        op.getKey(),
+        /*repeats=*/repeatKeyShape, /*dim=*/rank - 3);
+    if (failed(keyRepeated))
+      return rewriter.notifyMatchFailure(
+          loc, "Failed to repeat the tensor elements for key.");
+
+    FailureOr<Value> valueRepeated = repeatTensorElementsForDim(
+        op.getOperation(), rewriter, /*resType=*/op.getQuery().getType(),
+        op.getValue(),
+        /*repeats=*/repeatValueShape, /*dim=*/rank - 3);
+    if (failed(valueRepeated))
+      return rewriter.notifyMatchFailure(
+          loc, "Failed to repeat the tensor elements for value.");
+
+    key = typeConverter->materializeTargetConversion(
+        rewriter, loc,
+        typeConverter->convertType(keyRepeated.value().getType()),
+        keyRepeated.value());
+    value = typeConverter->materializeTargetConversion(
+        rewriter, loc,
+        typeConverter->convertType(valueRepeated.value().getType()),
+        valueRepeated.value());
+    return success();
+  }
+
  LogicalResult
  matchAndRewrite(AtenScaledDotProductAttentionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
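
Semantically, this pre-processing reduces grouped-query attention to plain multi-head attention: each key/value head is repeated until the head counts match the query's. A hedged equivalence check in eager PyTorch (assumes a PyTorch build with the enable_gqa flag, i.e. >= 2.5; shapes are taken from the new e2e test below):

import torch
import torch.nn.functional as F

q = torch.randn(4, 32, 3, 8)  # (batch, q_heads, seq_len, head_dim)
k = torch.randn(4, 8, 3, 8)   # (batch, kv_heads, seq_len, head_dim)
v = torch.randn(4, 8, 3, 8)

# Repeat each K/V head (q_heads // kv_heads) times along the head dim,
# matching what repeatTensorElementsForDim produces at dim = rank - 3.
k_rep = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
v_rep = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)

out_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
out_mha = F.scaled_dot_product_attention(q, k_rep, v_rep)
torch.testing.assert_close(out_gqa, out_mha)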
@@ -1795,11 +1902,6 @@ class ConvertAtenScaledDotProductAttentionOp
          scaleFloat != 1.0)
        return rewriter.notifyMatchFailure(loc, "only default scale supported");
    }
-    bool isGQAEnabled;
-    if (!matchPattern(enableGQA, m_TorchConstantBool(&isGQAEnabled)) ||
-        isGQAEnabled)
-      return rewriter.notifyMatchFailure(
-          loc, "grouped query attention not supported");

    if (queryTy.getRank() != valueTy.getRank() ||
        queryTy.getRank() != keyTy.getRank())
@@ -1808,6 +1910,22 @@ class ConvertAtenScaledDotProductAttentionOp
    if (queryTy.getRank() < 3)
      return rewriter.notifyMatchFailure(op, "missing batch dimension");

+    bool isGQAEnabled;
+    if (!matchPattern(enableGQA, m_TorchConstantBool(&isGQAEnabled)))
+      return rewriter.notifyMatchFailure(
+          loc, "expected enable_gqa flag to be a constant bool");
+
+    // When the `enable_gqa` flag is set to true, the key and value inputs
+    // must be pre-processed by repeating their elements along the head
+    // dimension. The reference code is available here:
+    // https://github.com/pytorch/pytorch/pull/132689/files#diff-e726853e9795dfb6c74ab1e10945f5d5f24540eb7bc633e5c76f69bc258f24d6R612
+    if (isGQAEnabled) {
+      if (failed(preProcessGroupQueryAttentionInput(
+              op, rewriter, getTypeConverter(), query, key, value)))
+        return failure();
+    }
+
    llvm::SmallVector<ReassociationIndices, 3> reassociation(3);
    for (int i = 0, s = valueTy.getRank() - 2; i < s; ++i)
      reassociation.front().push_back(i);
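
Note that preProcessGroupQueryAttentionInput only fires for statically known head counts where the query's head count is an exact multiple of the key's and the value's; everything else ends in a match failure. A small sketch of that guard (hypothetical helper name, not from the commit):

def gqa_repeat_factors(q_heads: int, k_heads: int, v_heads: int):
    """Return (key_repeats, value_repeats), or None when the lowering bails."""
    if q_heads % k_heads or q_heads % v_heads:
        return None
    return q_heads // k_heads, q_heads // v_heads

assert gqa_repeat_factors(32, 8, 8) == (4, 4)   # the e2e test's shapes
assert gqa_repeat_factors(32, 12, 12) is None   # uneven head counts rejected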

projects/pt1/e2e_testing/xfail_sets.py (+3)

@@ -940,6 +940,7 @@
    "BernoulliFloatModule_basic",
    "UniformModule_basic",
    "UniformStaticShapeModule_basic",
+   "ScaledDotProductAttentionGQAModule_basic",
}

FX_IMPORTER_STABLEHLO_CRASHING_SET = {

@@ -3252,6 +3253,7 @@
    "Aten_TrilinearModuleVaryingRanks_basic",
    "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
    "Aten_TrilinearModuleZerodDimBug_basic",
+   "ScaledDotProductAttentionGQAModule_basic",
}

if torch_version_for_comparison() < version.parse("2.3.0.dev"):

@@ -3764,6 +3766,7 @@
    "ScaledDotProductAttentionSameCausalModule_basic",
    "ScaledDotProductAttentionSameDynamicModule_basic",
    "ScaledDotProductAttentionSameModule_basic",
+   "ScaledDotProductAttentionGQAModule_basic",
}

ONNX_TOSA_CRASHING_SET = {

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py (+27)

@@ -5742,6 +5742,33 @@ def ScaledDotProductAttentionBoolMaskModule_basic(module, tu: TestUtils):
    module.forward(query, key, value, mask)


+class ScaledDotProductAttentionGQAModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([4, 32, 3, 8], torch.float32, True),
+            ([4, 8, 3, 8], torch.float32, True),
+            ([4, 8, 3, 8], torch.float32, True),
+        ]
+    )
+    def forward(self, query, key, value):
+        return torch.ops.aten.scaled_dot_product_attention(
+            query, key, value, enable_gqa=True
+        )
+
+
+@register_test_case(module_factory=lambda: ScaledDotProductAttentionGQAModule())
+def ScaledDotProductAttentionGQAModule_basic(module, tu: TestUtils):
+    query = torch.randn(4, 32, 3, 8, dtype=torch.float32)
+    key = torch.randn(4, 8, 3, 8, dtype=torch.float32)
+    value = torch.randn(4, 8, 3, 8, dtype=torch.float32)
+    module.forward(query, key, value)
+
+
# ==============================================================================
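
As a usage note: the new test should be runnable through the standard torch-mlir e2e harness, with a command along the lines of `python -m e2e_testing.main --config=linalg -f ScaledDotProductAttentionGQAModule` from projects/pt1 (the exact flags are an assumption, not stated in this commit), while the three xfail_sets entries above exclude it from backend configs that do not support it yet.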