File tree 1 file changed +4
-2
lines changed
lib/Conversion/TorchToStablehlo
1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -569,11 +569,13 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
569
569
ConversionPatternRewriter &rewriter) const {
570
570
Value input = adaptor.getSelf ();
571
571
auto inputTy = input.getType ().cast <RankedTensorType>();
572
+ auto outTy =
573
+ getTypeConverter ()->convertType (op.getType ()).cast <RankedTensorType>();
574
+ input = hlo::promoteType (rewriter, op.getLoc (), input, outTy);
575
+ inputTy = input.getType ().cast <RankedTensorType>();
572
576
auto inputElemTy = inputTy.getElementType ();
573
577
auto inputRank = inputTy.getRank ();
574
578
auto inputShape = inputTy.getShape ();
575
- auto outTy =
576
- getTypeConverter ()->convertType (op.getType ()).cast <RankedTensorType>();
577
579
578
580
int64_t dim;
579
581
if (!matchPattern (op.getDim (), m_TorchConstantInt (&dim))) {
You can’t perform that action at this time.
0 commit comments