Skip to content

Commit 539511c

Browse files
authored
Add dropout op (#436)
Co-authored-by: dan <[email protected]>
1 parent 03fdf56 commit 539511c

File tree

5 files changed

+67
-1
lines changed

5 files changed

+67
-1
lines changed

e2e_testing/torchscript/basic.py

+20
Original file line numberDiff line numberDiff line change
@@ -738,3 +738,23 @@ def forward(self, input, tensor1, tensor2):
738738
@register_test_case(module_factory=lambda: AddCDivModule())
739739
def AddCDivModule_basic(module, tu: TestUtils):
740740
module.forward(tu.rand(1,3), tu.rand(1,3), tu.rand(1,3))
741+
742+
743+
# ==============================================================================
744+
745+
class DropoutModule(torch.nn.Module):
    """E2E test module exercising aten::dropout in inference mode.

    With p=0.0 and train=False, dropout is the identity function, so the
    module's output should equal its input.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        # p=0.0, train=False: no elements are zeroed and no scaling occurs.
        result = torch.dropout(x, 0.0, False)
        return result
756+
757+
758+
@register_test_case(module_factory=lambda: DropoutModule())
def DropoutModule_basic(module, tu: TestUtils):
    # A small 2-D float tensor matching the ([-1, -1], float32) annotation.
    inp = tu.rand(3, 4)
    module.forward(inp)

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

+16
Original file line numberDiff line numberDiff line change
@@ -2199,6 +2199,22 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
21992199
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
22002200
}
22012201

2202+
// Op definition for `aten::dropout : (Tensor, float, bool) -> (Tensor)`.
// NOTE: this record is auto-generated (emitted by
// build_tools/torch_ods_gen.py) — regenerate rather than editing by hand.
def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::dropout : (Tensor, float, bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$input,
    Torch_FloatType:$p,
    Torch_BoolType:$train
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$input `,` $p `,` $train attr-dict `:` type($input) `,` type($p) `,` type($train) `->` type($result)";
}
2217+
22022218
def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
22032219
AllowsTypeRefinement,
22042220
HasValueSemantics

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -1132,6 +1132,33 @@ class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
11321132
};
11331133
} // namespace
11341134

1135+
namespace {
1136+
class ConvertAtenDropoutOp : public OpConversionPattern<AtenDropoutOp> {
1137+
public:
1138+
using OpConversionPattern::OpConversionPattern;
1139+
LogicalResult
1140+
matchAndRewrite(AtenDropoutOp op, OpAdaptor adaptor,
1141+
ConversionPatternRewriter &rewriter) const override {
1142+
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
1143+
return failure();
1144+
1145+
bool train;
1146+
if (!matchPattern(op.train(), m_TorchConstantBool(&train)))
1147+
return rewriter.notifyMatchFailure(op,
1148+
"Expected train to be constant bool.");
1149+
1150+
if (train)
1151+
return failure();
1152+
auto resultType = getTypeConverter()
1153+
->convertType(op->getResult(0).getType())
1154+
.cast<RankedTensorType>();
1155+
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
1156+
adaptor.input());
1157+
return success();
1158+
}
1159+
};
1160+
} // namespace
1161+
11351162
namespace {
11361163
// See comments at in convertMmOp and the heading for this section for general
11371164
// considerations. This function needs to be auto-generated.
@@ -3035,6 +3062,8 @@ class ConvertTorchToLinalg
30353062
patterns.add<ConvertAtenIntTensorOp>(typeConverter, context);
30363063
target.addIllegalOp<PrimNumToTensorScalarOp>();
30373064
patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
3065+
target.addIllegalOp<AtenDropoutOp>();
3066+
patterns.add<ConvertAtenDropoutOp>(typeConverter, context);
30383067

30393068
if (failed(applyPartialConversion(getOperation(), target,
30403069
std::move(patterns))))

lib/Dialect/Torch/Transforms/RefineTypes.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
231231
AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp,
232232
AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp,
233233
AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
234-
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
234+
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
235235
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp>(
236236
op)) {
237237
return getLatticeElement(op->getResult(0)).join(*operands[0]);

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

+1
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,7 @@ def emit_with_mutating_variants(key, **kwargs):
569569
emit("aten::IntImplicit : (Tensor) -> (int)")
570570
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
571571
emit("aten::Int.Tensor : (Tensor) -> (int)")
572+
emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
572573

573574
# Dict ops.
574575
emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True)

0 commit comments

Comments
 (0)