@@ -1132,6 +1132,33 @@ class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
1132
1132
};
1133
1133
} // namespace
1134
1134
1135
+ namespace {
1136
+ class ConvertAtenDropoutOp : public OpConversionPattern <AtenDropoutOp> {
1137
+ public:
1138
+ using OpConversionPattern::OpConversionPattern;
1139
+ LogicalResult
1140
+ matchAndRewrite (AtenDropoutOp op, OpAdaptor adaptor,
1141
+ ConversionPatternRewriter &rewriter) const override {
1142
+ if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
1143
+ return failure ();
1144
+
1145
+ bool train;
1146
+ if (!matchPattern (op.train (), m_TorchConstantBool (&train)))
1147
+ return rewriter.notifyMatchFailure (op,
1148
+ " Expected train to be constant bool." );
1149
+
1150
+ if (train)
1151
+ return failure ();
1152
+ auto resultType = getTypeConverter ()
1153
+ ->convertType (op->getResult (0 ).getType ())
1154
+ .cast <RankedTensorType>();
1155
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, resultType,
1156
+ adaptor.input ());
1157
+ return success ();
1158
+ }
1159
+ };
1160
+ } // namespace
1161
+
1135
1162
namespace {
1136
1163
// See comments at in convertMmOp and the heading for this section for general
1137
1164
// considerations. This function needs to be auto-generated.
@@ -3035,6 +3062,8 @@ class ConvertTorchToLinalg
3035
3062
patterns.add <ConvertAtenIntTensorOp>(typeConverter, context);
3036
3063
target.addIllegalOp <PrimNumToTensorScalarOp>();
3037
3064
patterns.add <ConvertPrimNumToTensorScalarOp>(typeConverter, context);
3065
+ target.addIllegalOp <AtenDropoutOp>();
3066
+ patterns.add <ConvertAtenDropoutOp>(typeConverter, context);
3038
3067
3039
3068
if (failed (applyPartialConversion (getOperation (), target,
3040
3069
std::move (patterns))))
0 commit comments