@@ -413,4 +413,37 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
413
413
414
414
return true;
415
415
}
416
+ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
417
+ --- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
418
+ +++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
419
+ @@ -23,6 +23,7 @@
420
+ #include "llvm/ADT/STLExtras.h"
421
+ #include "llvm/ADT/SmallVector.h"
422
+ #include "llvm/ADT/StringRef.h"
423
+ + #include "llvm/Support/Debug.h"
424
+ #include "llvm/Support/ErrorHandling.h"
425
+ #include "mlir/Dialect/CommonFolders.h"
426
+ #include "mlir/Dialect/Func/IR/FuncOps.h"
427
+ @@ -47,6 +48,8 @@
428
+ #include "stablehlo/dialect/Base.h"
429
+ #include "stablehlo/dialect/StablehloOps.h"
430
+ #include "stablehlo/transforms/optimization/Passes.h"
431
+ +
432
+ + #define DEBUG_TYPE "stablehlo-optimization"
433
+
434
+ namespace mlir {
435
+ namespace stablehlo {
436
+ @@ -779,7 +782,12 @@
437
+ using OpRewritePattern::OpRewritePattern;
438
+ LogicalResult matchAndRewrite(IotaOp op,
439
+ PatternRewriter& rewriter) const override {
440
+ + LLVM_DEBUG(llvm::dbgs() << "EvalIotaOpPattern folding: " << op << '\n');
441
+ auto resultType = cast<RankedTensorType>(op.getType());
442
+ + size_t numElems = resultType.getNumElements();
443
+ + if (numElems > kFoldOpEltLimit)
444
+ + return rewriter.notifyMatchFailure(op, "too many elements to fold");
445
+ +
446
+ auto elementType = resultType.getElementType();
447
+
448
+ if (!elementType.isInteger())
416
449
0 commit comments