Skip to content

Commit 6188869

Browse files
authored
[onnx] Add support for onnx.sinh (#2643)
Adds a lowering from `onnx.sinh` to `aten.sinh`. This includes adding the `aten.sinh` operator.
1 parent b3e9420 commit 6188869

File tree

4 files changed

+68
-1
lines changed

4 files changed

+68
-1
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

+45
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,51 @@ def Torch_AtenSign_Op : Torch_Op<"aten.sign_", [
526526
}];
527527
}
528528

529+
def Torch_AtenSinhOp : Torch_Op<"aten.sinh", [
530+
AllowsTypeRefinement,
531+
HasValueSemantics,
532+
ReadOnly
533+
]> {
534+
let summary = "Generated op for `aten::sinh : (Tensor) -> (Tensor)`";
535+
let arguments = (ins
536+
AnyTorchTensorType:$self
537+
);
538+
let results = (outs
539+
AnyTorchTensorType:$result
540+
);
541+
let hasCustomAssemblyFormat = 1;
542+
let extraClassDefinition = [{
543+
ParseResult AtenSinhOp::parse(OpAsmParser &parser, OperationState &result) {
544+
return parseDefaultTorchOp(parser, result, 1, 1);
545+
}
546+
void AtenSinhOp::print(OpAsmPrinter &printer) {
547+
printDefaultTorchOp(printer, *this, 1, 1);
548+
}
549+
}];
550+
}
551+
552+
def Torch_AtenSinh_Op : Torch_Op<"aten.sinh_", [
553+
IsTrailingUnderscoreInplaceVariant,
554+
AllowsTypeRefinement
555+
]> {
556+
let summary = "Generated op for `aten::sinh_ : (Tensor) -> (Tensor)`";
557+
let arguments = (ins
558+
Torch_NonValueTensorType:$self
559+
);
560+
let results = (outs
561+
Torch_NonValueTensorType:$result
562+
);
563+
let hasCustomAssemblyFormat = 1;
564+
let extraClassDefinition = [{
565+
ParseResult AtenSinh_Op::parse(OpAsmParser &parser, OperationState &result) {
566+
return parseDefaultTorchOp(parser, result, 1, 1);
567+
}
568+
void AtenSinh_Op::print(OpAsmPrinter &printer) {
569+
printDefaultTorchOp(printer, *this, 1, 1);
570+
}
571+
}];
572+
}
573+
529574
def Torch_AtenSgnOp : Torch_Op<"aten.sgn", [
530575
AllowsTypeRefinement,
531576
HasValueSemantics,

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

+13-1
Original file line numberDiff line numberDiff line change
@@ -467,11 +467,23 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
467467
if (binder.tensorOperand(operand) ||
468468
binder.tensorResultType(resultType))
469469
return failure();
470-
471470
rewriter.replaceOpWithNewOp<Torch::Aten_ShapeAsTensorOp>(
472471
binder.op, resultType, operand);
473472
return success();
474473
});
474+
475+
patterns.onOp("Sinh", 9,
476+
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
477+
Torch::ValueTensorType resultType;
478+
Value operand;
479+
if (binder.tensorOperand(operand) ||
480+
binder.tensorResultType(resultType))
481+
return failure();
482+
483+
rewriter.replaceOpWithNewOp<Torch::AtenSinhOp>(
484+
binder.op, resultType, operand);
485+
return success();
486+
});
475487

476488
patterns.onOp(
477489
"Transpose", 13,

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

+1
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def emit_with_mutating_variants(key, **kwargs):
266266
"aten::selu : (Tensor) -> (Tensor)",
267267
"aten::sigmoid : (Tensor) -> (Tensor)",
268268
"aten::sign : (Tensor) -> (Tensor)",
269+
"aten::sinh : (Tensor) -> (Tensor)",
269270
"aten::sgn : (Tensor) -> (Tensor)",
270271
"aten::hardsigmoid : (Tensor) -> (Tensor)",
271272
"aten::hardswish : (Tensor) -> (Tensor)",

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

+9
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,15 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
489489

490490
// -----
491491

492+
// CHECK-LABEL: func.func @test_sinh
493+
func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} {
494+
// CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
495+
%0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
496+
return %0 : !torch.vtensor<[3],f32>
497+
}
498+
499+
// -----
500+
492501
// CHECK-LABEL: func.func @test_transpose_default
493502
func.func @test_transpose_default(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
494503
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0

0 commit comments

Comments
 (0)