@@ -372,6 +372,54 @@ static FailureOr<SmallVector<Value>> createTMTensorTopkOp(
372
372
return SmallVector<Value>(topkOp.getResults ());
373
373
}
374
374
375
+ static FailureOr<Value>
376
+ repeatTensorElementsForDim (Operation *op, ConversionPatternRewriter &rewriter,
377
+ Type resType, Value self, int64_t repeats,
378
+ int64_t dim) {
379
+ Location loc = op->getLoc ();
380
+ auto context = op->getContext ();
381
+ auto selfTy = cast<BaseTensorType>(self.getType ());
382
+
383
+ int64_t inputRank = selfTy.getSizes ().size ();
384
+ dim = toPositiveDim (dim, inputRank);
385
+ Value dimValue =
386
+ rewriter.create <ConstantIntOp>(loc, rewriter.getI64IntegerAttr (dim));
387
+ Value dimValuePlusOne =
388
+ rewriter.create <ConstantIntOp>(loc, rewriter.getI64IntegerAttr (dim + 1 ));
389
+
390
+ auto unsqueezedInfo = unsqueezeTensor (rewriter, op, self, dimValuePlusOne);
391
+ if (failed (unsqueezedInfo))
392
+ return rewriter.notifyMatchFailure (op,
393
+ " cannot generate unsqueeze tensor op" );
394
+ self = *unsqueezedInfo;
395
+
396
+ Value constMinusOne =
397
+ rewriter.create <ConstantIntOp>(loc, rewriter.getI64IntegerAttr (-1 ));
398
+ SmallVector<Value> expandShapeValueList (inputRank + 1 , constMinusOne);
399
+ expandShapeValueList[dim + 1 ] =
400
+ rewriter.create <ConstantIntOp>(loc, rewriter.getI64IntegerAttr (repeats));
401
+ Value expandShapeList = rewriter.create <PrimListConstructOp>(
402
+ loc, ListType::get (IntType::get (context)), expandShapeValueList);
403
+
404
+ SmallVector<int64_t > expandShape (inputRank + 1 );
405
+ for (int64_t i = 0 ; i <= dim; i++) {
406
+ expandShape[i] = selfTy.getSizes ()[i];
407
+ }
408
+ expandShape[dim + 1 ] = repeats;
409
+ for (int64_t i = dim + 1 ; i < inputRank; i++) {
410
+ expandShape[i + 1 ] = selfTy.getSizes ()[i];
411
+ }
412
+
413
+ BaseTensorType expandTy =
414
+ rewriter.getType <ValueTensorType>(expandShape, selfTy.getOptionalDtype ());
415
+ Value expandSelf =
416
+ rewriter.create <AtenBroadcastToOp>(loc, expandTy, self, expandShapeList);
417
+
418
+ Value result = rewriter.create <PrimsCollapseOp>(loc, resType, expandSelf,
419
+ dimValue, dimValuePlusOne);
420
+ return result;
421
+ }
422
+
375
423
namespace {
376
424
template <typename AtenOpT>
377
425
class ConvertAtenScatterOp : public OpConversionPattern <AtenOpT> {
@@ -1651,6 +1699,65 @@ class ConvertAtenScaledDotProductAttentionOp
1651
1699
: public OpConversionPattern<AtenScaledDotProductAttentionOp> {
1652
1700
public:
1653
1701
using OpConversionPattern::OpConversionPattern;
1702
+
1703
+ static LogicalResult
1704
+ preProcessGroupQueryAttentionInput (AtenScaledDotProductAttentionOp op,
1705
+ ConversionPatternRewriter &rewriter,
1706
+ const TypeConverter *typeConverter,
1707
+ Value query, Value &key, Value &value) {
1708
+ auto queryTy = cast<ShapedType>(query.getType ());
1709
+ auto valueTy = cast<ShapedType>(value.getType ());
1710
+ auto keyTy = cast<ShapedType>(key.getType ());
1711
+
1712
+ int64_t rank = queryTy.getRank ();
1713
+
1714
+ int64_t qNumHeads = queryTy.getDimSize (rank - 3 );
1715
+ int64_t kNumHeads = valueTy.getDimSize (rank - 3 );
1716
+ int64_t vNumHeads = keyTy.getDimSize (rank - 3 );
1717
+
1718
+ if (llvm::any_of (llvm::ArrayRef<int64_t >{qNumHeads, kNumHeads , vNumHeads},
1719
+ [](int64_t d) { return d == Torch::kUnknownSize ; })) {
1720
+ return llvm::failure ();
1721
+ }
1722
+
1723
+ if (llvm::all_equal (
1724
+ llvm::ArrayRef<int64_t >{qNumHeads, kNumHeads , vNumHeads}))
1725
+ return llvm::success ();
1726
+
1727
+ if ((qNumHeads % kNumHeads ) && (qNumHeads % vNumHeads))
1728
+ return llvm::failure ();
1729
+
1730
+ int64_t repeatKeyShape = qNumHeads / kNumHeads ;
1731
+ int64_t repeatValueShape = qNumHeads / vNumHeads;
1732
+
1733
+ Location loc = op.getLoc ();
1734
+ FailureOr<Value> keyRepeated = repeatTensorElementsForDim (
1735
+ op.getOperation (), rewriter, /* resType=*/ op.getQuery ().getType (),
1736
+ op.getKey (),
1737
+ /* repeats=*/ repeatKeyShape, /* dim=*/ rank - 3 );
1738
+ if (failed (keyRepeated))
1739
+ return rewriter.notifyMatchFailure (
1740
+ loc, " Failed to repeat the tensor elements for key." );
1741
+
1742
+ FailureOr<Value> valueRepeated = repeatTensorElementsForDim (
1743
+ op.getOperation (), rewriter, /* resType=*/ op.getQuery ().getType (),
1744
+ op.getValue (),
1745
+ /* repeats=*/ repeatValueShape, /* dim=*/ rank - 3 );
1746
+ if (failed (valueRepeated))
1747
+ return rewriter.notifyMatchFailure (
1748
+ loc, " Failed to repeat the tensor elements for value." );
1749
+
1750
+ key = typeConverter->materializeTargetConversion (
1751
+ rewriter, loc,
1752
+ typeConverter->convertType (keyRepeated.value ().getType ()),
1753
+ keyRepeated.value ());
1754
+ value = typeConverter->materializeTargetConversion (
1755
+ rewriter, loc,
1756
+ typeConverter->convertType (valueRepeated.value ().getType ()),
1757
+ valueRepeated.value ());
1758
+ return success ();
1759
+ }
1760
+
1654
1761
LogicalResult
1655
1762
matchAndRewrite (AtenScaledDotProductAttentionOp op, OpAdaptor adaptor,
1656
1763
ConversionPatternRewriter &rewriter) const override {
@@ -1795,11 +1902,6 @@ class ConvertAtenScaledDotProductAttentionOp
1795
1902
scaleFloat != 1.0 )
1796
1903
return rewriter.notifyMatchFailure (loc, " only default scale supported" );
1797
1904
}
1798
- bool isGQAEnabled;
1799
- if (!matchPattern (enableGQA, m_TorchConstantBool (&isGQAEnabled)) ||
1800
- isGQAEnabled)
1801
- return rewriter.notifyMatchFailure (
1802
- loc, " grouped query attention not supported" );
1803
1905
1804
1906
if (queryTy.getRank () != valueTy.getRank () ||
1805
1907
queryTy.getRank () != keyTy.getRank ())
@@ -1808,6 +1910,22 @@ class ConvertAtenScaledDotProductAttentionOp
1808
1910
if (queryTy.getRank () < 3 )
1809
1911
return rewriter.notifyMatchFailure (op, " missing batch dimension" );
1810
1912
1913
+ bool isGQAEnabled;
1914
+ if (!matchPattern (enableGQA, m_TorchConstantBool (&isGQAEnabled)))
1915
+ return rewriter.notifyMatchFailure (
1916
+ loc, " Expected enable_gqa flag to be constant bool" );
1917
+
1918
+ // For the cases when `enable_gqa` flag is set to true, we have to
1919
+ // pre-process the inputs specifically key and value by repeating the
1920
+ // elements for the head dim.
1921
+ // The reference code is available here:
1922
+ // https://github.com/pytorch/pytorch/pull/132689/files#diff-e726853e9795dfb6c74ab1e10945f5d5f24540eb7bc633e5c76f69bc258f24d6R612
1923
+ if (enableGQA) {
1924
+ if (failed (preProcessGroupQueryAttentionInput (
1925
+ op, rewriter, getTypeConverter (), query, key, value)))
1926
+ return failure ();
1927
+ }
1928
+
1811
1929
llvm::SmallVector<ReassociationIndices, 3 > reassociation (3 );
1812
1930
for (int i = 0 , s = valueTy.getRank () - 2 ; i < s; ++i)
1813
1931
reassociation.front ().push_back (i);
0 commit comments