Skip to content

Commit aa50a10

Browse files
abhigunjGoogle-ML-Automation
authored andcommitted
StableHLOAggressiveFolderPass : Don't fold iota op if number of elements is large (> 65536)
PiperOrigin-RevId: 735567318
1 parent b418287 commit aa50a10

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

third_party/stablehlo/temporary.patch

+33
Original file line numberDiff line numberDiff line change
@@ -413,4 +413,37 @@ diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stable
413413

414414
return true;
415415
}
416+
diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
417+
--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
418+
+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp
419+
@@ -23,6 +23,7 @@
420+
#include "llvm/ADT/STLExtras.h"
421+
#include "llvm/ADT/SmallVector.h"
422+
#include "llvm/ADT/StringRef.h"
423+
+#include "llvm/Support/Debug.h"
424+
#include "llvm/Support/ErrorHandling.h"
425+
#include "mlir/Dialect/CommonFolders.h"
426+
#include "mlir/Dialect/Func/IR/FuncOps.h"
427+
@@ -47,6 +48,8 @@
428+
#include "stablehlo/dialect/Base.h"
429+
#include "stablehlo/dialect/StablehloOps.h"
430+
#include "stablehlo/transforms/optimization/Passes.h"
431+
+
432+
+#define DEBUG_TYPE "stablehlo-optimization"
433+
434+
namespace mlir {
435+
namespace stablehlo {
436+
@@ -779,7 +782,12 @@
437+
using OpRewritePattern::OpRewritePattern;
438+
LogicalResult matchAndRewrite(IotaOp op,
439+
PatternRewriter& rewriter) const override {
440+
+ LLVM_DEBUG(llvm::dbgs() << "EvalIotaOpPattern folding: " << op << '\n');
441+
auto resultType = cast<RankedTensorType>(op.getType());
442+
+ size_t numElems = resultType.getNumElements();
443+
+ if (numElems > kFoldOpEltLimit)
444+
+ return rewriter.notifyMatchFailure(op, "too many elements to fold");
445+
+
446+
auto elementType = resultType.getElementType();
447+
448+
if (!elementType.isInteger())
416449

0 commit comments

Comments
 (0)