[Torch]Decompose AtenScaledDotProductAttentionOp #3461

Open · wants to merge 1 commit into base: main
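The new pattern rewrites aten.scaled_dot_product_attention as matmul(softmax(matmul(Q, transpose(K, -2, -1)) * scale + mask), V), rejecting dropout and defaulting the scale to 1/sqrt(E), where E is the last dimension of Q. Below is a minimal PyTorch-level sketch of the computation the decomposition produces, for orientation only; the helper name sdpa_decomposition_sketch is illustrative and not part of this patch.

import math
import torch

def sdpa_decomposition_sketch(q, k, v, attn_mask=None, dropout_p=0.0,
                              is_causal=False, scale=None):
    # The pattern rejects dropout (dropout_p must be a constant 0).
    assert dropout_p == 0.0
    if scale is None:
        # Default scale: 1/sqrt(E), with E the last dimension of Q.
        scale = 1.0 / math.sqrt(q.size(-1))
    qk = (q @ k.transpose(-2, -1)) * scale
    if is_causal:
        # is_causal builds a lower-triangular boolean mask over (L, S).
        attn_mask = torch.ones(qk.size(-2), qk.size(-1),
                               dtype=torch.bool).tril(diagonal=0)
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Boolean masks become additive masks of 0 / -inf.
            bias = torch.zeros(qk.size(-2), qk.size(-1), dtype=qk.dtype)
            attn_mask = bias.masked_fill(attn_mask.logical_not(),
                                         float("-inf"))
        qk = qk + attn_mask
    return torch.softmax(qk, dim=-1) @ v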
151 changes: 151 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -21,6 +21,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include <cstdint>
#include <optional>
#include <set>

using namespace mlir;
@@ -8006,6 +8007,153 @@ class DecomposeAtenFakeQuantizePerTensorAffineOp
};
} // namespace

namespace {
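// Decompose aten.scaled_dot_product_attention into
//   matmul(softmax(matmul(Q, transpose(K, -2, -1)) * scale + attn_mask), V)
// Dropout is not supported (dropout_p must be a constant 0).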
class DecomposeAtenScaledDotProductAttentionOp
: public OpRewritePattern<AtenScaledDotProductAttentionOp> {
public:
using OpRewritePattern<AtenScaledDotProductAttentionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();

Value query = op.getQuery();
auto queryTy = cast<BaseTensorType>(query.getType());
Value key = op.getKey();
auto keyTy = cast<BaseTensorType>(key.getType());
Value val = op.getValue();
auto valTy = cast<BaseTensorType>(val.getType());

auto resTy = cast<BaseTensorType>(op.getType());
if (!queryTy.hasDtype() || !keyTy.hasDtype() || !valTy.hasDtype() ||
!resTy.hasDtype())
return op.emitError("Types of Q, K, V and result "
"are expected to have dtype.");

if (!isa<mlir::FloatType>(queryTy.getDtype()) ||
!isa<mlir::FloatType>(keyTy.getDtype()) ||
!isa<mlir::FloatType>(valTy.getDtype()) ||
!isa<mlir::FloatType>(resTy.getDtype()))
return op.emitError("Q, K, V and result "
"are expected to have float dtype.");

bool isCausal = false;
if (!matchPattern(op.getIsCausal(), m_TorchConstantBool(&isCausal)))
return op.emitError("is_causal must be a Scalar constant");

double dropoutP;
if (!matchPattern(op.getDropoutP(), m_TorchConstantFloat(&dropoutP)))
return op.emitError("dropout_p must be a Scalar constant");

if (dropoutP != 0.0)
return op.emitError("dropout is not supported");

Value mask = op.getAttnMask();
auto maskTy = dyn_cast<BaseTensorType>(mask.getType());
if (maskTy) {
auto maskDty = maskTy.getOptionalDtype();
if (!maskTy.hasDtype() ||
(!isa<mlir::FloatType>(maskDty) && !maskDty.isSignlessInteger(1))) {
return op.emitError("attn_mask must be a tensor of "
"boolean or float");
}

if (isCausal)
return op.emitError("attn_mask and is_causal must be set exclusively.");
}

if (!keyTy.hasSizes())
return op.emitError("K must be a ranked tensor.");

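// Swap the last two dimensions of K to form K^T, which is matmul'd with Q.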
SmallVector<int64_t, 6> transShape(keyTy.getSizes());
int64_t tmp = transShape.end()[-2];
transShape.end()[-2] = transShape.end()[-1];
transShape.end()[-1] = tmp;
auto transTy = keyTy.getWithSizesAndDtype(llvm::ArrayRef(transShape),
keyTy.getOptionalDtype());

Value minusOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(-1));
Value minusTwo = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(-2));
Value transTensor = rewriter.create<AtenTransposeIntOp>(loc, transTy, key,
minusOne, minusTwo);
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));

if (!queryTy.hasSizes())
return op.emitError("Q must be a ranked tensor.");
SmallVector<int64_t, 6> qkShape(queryTy.getSizes());
qkShape.end()[-1] = transShape.end()[-1];
auto qkTy = queryTy.getWithSizesAndDtype(llvm::ArrayRef(qkShape),
queryTy.getDtype());
Value qkTensor =
rewriter.create<AtenMatmulOp>(loc, qkTy, query, transTensor);

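// If no explicit scale is given (scale == None), default to 1/sqrt(E),
// where E is the size of the last dimension of Q.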
Value scale = op.getScale();
if (isa<Torch::NoneType>(scale.getType())) {
Value lastDimSizeOfQ =
rewriter.create<AtenSizeIntOp>(loc, query, minusOne);
Value sqrtVal = rewriter.create<AtenSqrtIntOp>(loc, lastDimSizeOfQ);
scale = rewriter.create<AtenDivOp>(loc, one, sqrtVal);
}

auto noneSizeFloatType =
queryTy.getWithSizesAndDtype(std::nullopt, queryTy.getDtype());
Value scaledQKTensor =
rewriter.create<AtenMulScalarOp>(loc, qkTy, qkTensor, scale);

Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value zero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value maskedTensor = scaledQKTensor;
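// A float attn_mask is added to the scaled Q @ K^T directly; a boolean
// attn_mask (or is_causal) is first converted into an additive float mask
// of 0 / -inf below.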
if (maskTy && isa<mlir::FloatType>(maskTy.getDtype())) {
maskedTensor = rewriter.create<AtenAddTensorOp>(loc, qkTy, scaledQKTensor,
mask, one);
} else if (maskTy || isCausal) {
Value firstDimSizeOfMask =
rewriter.create<AtenSizeIntOp>(loc, qkTensor, minusTwo);
Value secondDimSizeOfMask =
rewriter.create<AtenSizeIntOp>(loc, qkTensor, minusOne);
Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(firstDimSizeOfMask.getType()),
ValueRange({firstDimSizeOfMask, secondDimSizeOfMask}));
Value dtypeInt = rewriter.create<PrimDtypeOp>(loc, scaledQKTensor);
Value zeros = rewriter.create<AtenZerosOp>(
loc, noneSizeFloatType, dimList, dtypeInt, none, none, none);
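// For is_causal, materialize a lower-triangular boolean mask of shape (L, S).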
if (isCausal) {
auto noneSizeBoolType =
queryTy.getWithSizesAndDtype(std::nullopt, rewriter.getI1Type());
Value ones = rewriter.create<AtenOnesOp>(
loc, noneSizeBoolType, dimList,
getDtypeIntValueForType(rewriter, loc, rewriter.getI1Type()), none,
none, none);
mask = rewriter.create<AtenTrilOp>(loc, ones.getType(), ones, zero);
}

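// Turn the boolean mask into an additive float mask: keep 0 where the mask
// is true and fill -inf where it is false.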
Value notMask =
rewriter.create<AtenLogicalNotOp>(loc, mask.getType(), mask);
auto dType = cast<mlir::FloatType>(queryTy.getDtype());
Value minimalVal = rewriter.create<Torch::ConstantFloatOp>(
loc, llvm::APFloat::getInf(dType.getFloatSemantics(), true));

maskTy = cast<BaseTensorType>(zeros.getType());

mask = rewriter.create<AtenMaskedFill_ScalarOp>(loc, maskTy, zeros,
notMask, minimalVal);
maskedTensor = rewriter.create<AtenAddTensorOp>(loc, qkTy, scaledQKTensor,
mask, one);
}

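// Softmax over the last dimension, then weight the values:
// result = matmul(softmax(scaled Q @ K^T + mask), V).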
Value softmaxTensor = rewriter.create<AtenSoftmaxIntOp>(
loc, qkTy, maskedTensor, minusOne, none);

rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, resTy, softmaxTensor, val);
return success();
}
};
} // namespace

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -8240,6 +8388,9 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenConv2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenConv3dOp>(patterns);

addPatternIfTargetOpIsIllegal<DecomposeAtenScaledDotProductAttentionOp>(
patterns);

GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.maxIterations = GreedyRewriteConfig::kNoLimit;
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -540,6 +540,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenReshapeAsOp>();
target.addIllegalOp<AtenTriuOp>();
target.addIllegalOp<AtenLinalgNormOp>();
target.addIllegalOp<AtenScaledDotProductAttentionOp>();
for (auto &opName : backendLegalOpsSet) {
target.addLegalOp(
OperationName(kTorchOpPrefix + opName.first().str(), context));
3 changes: 2 additions & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
@@ -1470,6 +1470,7 @@
"ElementwiseLogSigmoidModule_basic",
"ElementwiseHardshrinkStaticModule_basic",
"ElementwiseSoftshrinkStaticModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
}

STABLEHLO_CRASHING_SET = set()
@@ -1943,6 +1944,7 @@
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
}

MAKE_FX_TOSA_PASS_SET = (
@@ -1975,7 +1977,6 @@
"ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic",
"ViewSizeDimLedByCollapsedOnesModule_basic",
"ViewSizeFromOtherTensor_basic",
"ScaledDotProductAttentionDifferentModule_basic",
}
) - {
### Test failing in make_fx_tosa but not in tosa
1 change: 1 addition & 0 deletions projects/pt1/python/torch_mlir/torchscript.py
@@ -211,6 +211,7 @@ def _get_for_tracing(
"aten.adaptive_avg_pool1d",
"aten.adaptive_avg_pool2d",
"aten.unflatten.int",
"aten.scaled_dot_product_attention",
],
OutputType.STABLEHLO: [
"aten.amax",