diff --git a/include/circt/Dialect/HW/HWPasses.h b/include/circt/Dialect/HW/HWPasses.h index d693a237508a..a851ea31d868 100644 --- a/include/circt/Dialect/HW/HWPasses.h +++ b/include/circt/Dialect/HW/HWPasses.h @@ -34,6 +34,8 @@ std::unique_ptr createVerifyInnerRefNamespacePass(); std::unique_ptr createFlattenModulesPass(); std::unique_ptr createFooWiresPass(); std::unique_ptr createHWAggregateToCombPass(); +std::unique_ptr createHWExpungeModulePass(); +std::unique_ptr createHWTreeShakePass(); /// Generate the code for registering passes. #define GEN_PASS_REGISTRATION diff --git a/include/circt/Dialect/HW/Passes.td b/include/circt/Dialect/HW/Passes.td index 270fb5c00bcd..0053e6413f1f 100644 --- a/include/circt/Dialect/HW/Passes.td +++ b/include/circt/Dialect/HW/Passes.td @@ -75,6 +75,38 @@ def VerifyInnerRefNamespace : Pass<"hw-verify-irn"> { let constructor = "circt::hw::createVerifyInnerRefNamespacePass()"; } +def HWExpungeModule : Pass<"hw-expunge-module", "mlir::ModuleOp"> { + let summary = "Remove module from the hierarchy, and recursively expose their ports to upper level."; + let description = [{ + This pass removes a list of modules from the hierarchy on-by-one, recursively exposing their ports to upper level. + The newly generated ports are by default named as __. During a naming conflict, an warning would be genreated, + and an random suffix would be added to the part. + + For each given (transitive) parent module, the prefix can alternatively be specified by option instead of using the instance path. + }]; + let constructor = "circt::hw::createHWExpungeModulePass()"; + + let options = [ + ListOption<"modules", "modules", "std::string", + "Comma separated list of module names to be removed from the hierarchy.">, + ListOption<"portPrefixes", "port-prefixes", "std::string", + "Specify the prefix for ports of a given parent module's expunged childen. Each specification is formatted as :=. Only affect the top-most level module of the instance path.">, + ]; +} + +def HWTreeShake : Pass<"hw-tree-shake", "mlir::ModuleOp"> { + let summary = "Remove unused modules."; + let description = [{ + This pass removes all modules besides a specified list of modules and their transitive dependencies. + }]; + let constructor = "circt::hw::createHWTreeShakePass()"; + + let options = [ + ListOption<"keep", "keep", "std::string", + "Comma separated list of module names to be kept.">, + ]; +} + /** * Tutorial Pass, doesn't do anything interesting */ diff --git a/lib/Dialect/HW/Transforms/CMakeLists.txt b/lib/Dialect/HW/Transforms/CMakeLists.txt index f06b5ba5bb13..11d536cd757f 100644 --- a/lib/Dialect/HW/Transforms/CMakeLists.txt +++ b/lib/Dialect/HW/Transforms/CMakeLists.txt @@ -7,6 +7,8 @@ add_circt_dialect_library(CIRCTHWTransforms VerifyInnerRefNamespace.cpp FlattenModules.cpp FooWires.cpp + HWExpungeModule.cpp + HWTreeShake.cpp DEPENDS CIRCTHWTransformsIncGen diff --git a/lib/Dialect/HW/Transforms/HWExpungeModule.cpp b/lib/Dialect/HW/Transforms/HWExpungeModule.cpp new file mode 100644 index 000000000000..3351ff0e7277 --- /dev/null +++ b/lib/Dialect/HW/Transforms/HWExpungeModule.cpp @@ -0,0 +1,315 @@ +#include "circt/Dialect/HW/HWInstanceGraph.h" +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/HW/HWPasses.h" +#include "circt/Dialect/HW/HWTypes.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/ImmutableList.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/Support/Regex.h" +#include + +namespace circt { +namespace hw { +#define GEN_PASS_DEF_HWEXPUNGEMODULE +#include "circt/Dialect/HW/Passes.h.inc" +} // namespace hw +} // namespace circt + +namespace { +struct HWExpungeModulePass + : circt::hw::impl::HWExpungeModuleBase { + void runOnOperation() override; +}; + +struct InstPathSeg { + llvm::StringRef seg; + + InstPathSeg(llvm::StringRef seg) : seg(seg) {} + const llvm::StringRef &getSeg() const { return seg; } + operator llvm::StringRef() const { return seg; } + + void Profile(llvm::FoldingSetNodeID &ID) const { ID.AddString(seg); } +}; +using InstPath = llvm::ImmutableList; +std::string defaultPrefix(InstPath path) { + std::string accum; + while (!path.isEmpty()) { + accum += path.getHead().getSeg(); + accum += "_"; + path = path.getTail(); + } + accum += "_"; + return std::move(accum); +} + +// The regex for port prefix specification +// "^([@#a-zA-Z0-9_]+):([a-zA-Z0-9_]+)(\\.[a-zA-Z0-9_]+)*=([a-zA-Z0-9_]+)$" +// Unfortunately, the LLVM Regex cannot capture repeating capture groups, so +// manually parse the spec This parser may accept identifiers with invalid +// characters + +std::variant, + std::string> +parsePrefixSpec(llvm::StringRef in, InstPath::Factory &listFac) { + auto [l, r] = in.split("="); + if (r == "") + return "No '=' found in input"; + auto [ll, lr] = l.split(":"); + if (lr == "") + return "No ':' found before '='"; + llvm::SmallVector segs; + while (lr != "") { + auto [seg, rest] = lr.split("."); + segs.push_back(seg); + lr = rest; + } + InstPath path; + for (auto &seg : llvm::reverse(segs)) + path = listFac.add(seg, path); + return std::make_tuple(ll, path, r); +} +} // namespace + +void HWExpungeModulePass::runOnOperation() { + auto root = getOperation(); + llvm::DenseMap allModules; + root.walk( + [&](circt::hw::HWModuleLike mod) { allModules[mod.getName()] = mod; }); + + // The instance graph. We only use this graph to traverse the hierarchy in + // post order. The order does not change throught out the operation, onlygets + // weakened, but still valid. So we keep this cached instance graph throughout + // the pass. + auto &instanceGraph = getAnalysis(); + + // Instance path. + InstPath::Factory pathFactory; + + // Process port prefix specifications + // (Module name, Instance path) -> Prefix + llvm::DenseMap, mlir::StringRef> + designatedPrefixes; + bool containsFailure = false; + for (const auto &raw : portPrefixes) { + auto matched = parsePrefixSpec(raw, pathFactory); + if (std::holds_alternative(matched)) { + llvm::errs() << "Invalid port prefix specification: " << raw << "\n"; + llvm::errs() << "Error: " << std::get(matched) << "\n"; + containsFailure = true; + continue; + } + + auto [module, path, prefix] = + std::get>( + matched); + if (!allModules.contains(module)) { + llvm::errs() << "Module not found in port prefix specification: " + << module << "\n"; + llvm::errs() << "From specification: " << raw << "\n"; + containsFailure = true; + continue; + } + + // Skip checking instance paths' existence. Non-existent paths are ignored + designatedPrefixes.insert({{module, path}, prefix}); + } + + if (containsFailure) + return signalPassFailure(); + + // Instance path * prefix name + using ReplacedDescendent = std::pair; + // This map holds the expunged descendents of a module + llvm::DenseMap> + expungedDescendents; + for (auto &expunging : this->modules) { + // Clear expungedDescendents + for (auto &it : expungedDescendents) + it.getSecond().clear(); + + auto expungingMod = allModules.lookup(expunging); + if (!expungingMod) + continue; // Ignored missing modules + auto expungingModTy = expungingMod.getHWModuleType(); + auto expungingModPorts = expungingModTy.getPorts(); + + auto createPortsOn = [&expungingModPorts](circt::hw::HWModuleOp mod, + const std::string &prefix, + auto genOutput, auto emitInput) { + mlir::OpBuilder builder(mod); + // Create ports using *REVERSE* direction of their definitions + for (auto &port : expungingModPorts) { + auto defaultName = prefix + port.name.getValue(); + auto finalName = defaultName; + if (port.dir == circt::hw::PortInfo::Input) { + auto val = genOutput(port); + assert(val.getType() == port.type); + mod.appendOutput(finalName, val); + } else if (port.dir == circt::hw::PortInfo::Output) { + auto [_, arg] = mod.appendInput(finalName, port.type); + emitInput(port, arg); + } + } + }; + + for (auto &instGraphNode : llvm::post_order(&instanceGraph)) { + // Skip extmodule and intmodule because they cannot contain anything + circt::hw::HWModuleOp processing = + llvm::dyn_cast_if_present( + instGraphNode->getModule().getOperation()); + if (!processing) + continue; + + std::optional + outerExpDescHold = {}; + auto getOuterExpDesc = [&]() -> decltype(**outerExpDescHold) { + if (!outerExpDescHold.has_value()) + outerExpDescHold = { + &expungedDescendents.insert({processing.getName(), {}}) + .first->getSecond()}; + return **outerExpDescHold; + }; + + mlir::OpBuilder outerBuilder(processing); + + processing.walk([&](circt::hw::InstanceOp inst) { + auto instName = inst.getInstanceName(); + auto instMod = allModules.lookup(inst.getModuleName()); + + if (instMod.getOutputNames().size() != inst.getResults().size() || + instMod.getNumInputPorts() != inst.getInputs().size()) { + // Module have been modified during this pass, create new instances + assert(instMod.getNumOutputPorts() >= inst.getResults().size()); + assert(instMod.getNumInputPorts() >= inst.getInputs().size()); + + auto instModInTypes = instMod.getInputTypes(); + + llvm::SmallVector newInputs; + newInputs.reserve(instMod.getNumInputPorts()); + + outerBuilder.setInsertionPointAfter(inst); + + // Appended inputs are at the end of the input list + for (size_t i = 0; i < instMod.getNumInputPorts(); ++i) { + mlir::Value input; + if (i < inst.getNumInputPorts()) { + input = inst.getInputs()[i]; + if (auto existingName = inst.getInputName(i)) + assert(existingName == instMod.getInputName(i)); + } else { + input = + outerBuilder + .create( + inst.getLoc(), instModInTypes[i], mlir::ValueRange{}) + .getResult(0); + } + newInputs.push_back(input); + } + + auto newInst = outerBuilder.create( + inst.getLoc(), instMod, inst.getInstanceNameAttr(), newInputs, + inst.getParameters(), + inst.getInnerSym().value_or({})); + + for (size_t i = 0; i < inst.getNumResults(); ++i) + assert(inst.getOutputName(i) == instMod.getOutputName(i)); + inst.replaceAllUsesWith( + newInst.getResults().slice(0, inst.getNumResults())); + inst.erase(); + inst = newInst; + } + + llvm::StringMap instOMap; + llvm::StringMap instIMap; + assert(instMod.getOutputNames().size() == inst.getResults().size()); + for (auto [oname, oval] : + llvm::zip(instMod.getOutputNames(), inst.getResults())) + instOMap[llvm::cast(oname).getValue()] = oval; + assert(instMod.getInputNames().size() == inst.getInputs().size()); + for (auto [iname, ival] : + llvm::zip(instMod.getInputNames(), inst.getInputs())) + instIMap[llvm::cast(iname).getValue()] = ival; + + // Get outer expunged descendent first because it may modify the map and + // invalidate iterators. + auto &outerExpDesc = getOuterExpDesc(); + auto instExpDesc = expungedDescendents.find(inst.getModuleName()); + + if (inst.getModuleName() == expunging) { + // Handle the directly expunged module + // input maps also useful for directly expunged instance + + auto singletonPath = pathFactory.create(instName); + + auto designatedPrefix = + designatedPrefixes.find({processing.getName(), singletonPath}); + std::string prefix = designatedPrefix != designatedPrefixes.end() + ? designatedPrefix->getSecond().str() + : (instName + "__").str(); + + // Port name collision is still possible, but current relying on MLIR + // to automatically rename input arguments. + // TODO: name collision detect + + createPortsOn( + processing, prefix, + [&](circt::hw::ModulePort port) { + // Generate output for outer module, so input for us + return instIMap.at(port.name); + }, + [&](circt::hw::ModulePort port, mlir::Value val) { + // Generated input for outer module, replace inst results + assert(instOMap.contains(port.name)); + instOMap[port.name].replaceAllUsesWith(val); + }); + + outerExpDesc.emplace_back(singletonPath, prefix); + + assert(instExpDesc == expungedDescendents.end() || + instExpDesc->getSecond().size() == 0); + inst.erase(); + } else if (instExpDesc != expungedDescendents.end()) { + // Handle all transitive descendents + if (instExpDesc->second.size() == 0) + return; + llvm::DenseMap newInputs; + for (const auto &exp : instExpDesc->second) { + auto newPath = pathFactory.add(instName, exp.first); + auto designatedPrefix = + designatedPrefixes.find({processing.getName(), newPath}); + std::string prefix = designatedPrefix != designatedPrefixes.end() + ? designatedPrefix->getSecond().str() + : defaultPrefix(newPath); + + // TODO: name collision detect + + createPortsOn( + processing, prefix, + [&](circt::hw::ModulePort port) { + // Generate output for outer module, directly forward from + // inner inst + return instOMap.at((exp.second + port.name.getValue()).str()); + }, + [&](circt::hw::ModulePort port, mlir::Value val) { + // Generated input for outer module, replace inst results. + // The operand in question has to be an backedge + auto in = + instIMap.at((exp.second + port.name.getValue()).str()); + auto inDef = in.getDefiningOp(); + assert(llvm::isa(inDef)); + in.replaceAllUsesWith(val); + inDef->erase(); + }); + + outerExpDesc.emplace_back(newPath, prefix); + } + } + }); + } + } +} + +std::unique_ptr circt::hw::createHWExpungeModulePass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/lib/Dialect/HW/Transforms/HWTreeShake.cpp b/lib/Dialect/HW/Transforms/HWTreeShake.cpp new file mode 100644 index 000000000000..14473e78ce71 --- /dev/null +++ b/lib/Dialect/HW/Transforms/HWTreeShake.cpp @@ -0,0 +1,53 @@ +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/HW/HWPasses.h" +#include "circt/Dialect/HW/HWTypes.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Regex.h" + +namespace circt { +namespace hw { +#define GEN_PASS_DEF_HWTREESHAKE +#include "circt/Dialect/HW/Passes.h.inc" +} // namespace hw +} // namespace circt + +struct HWTreeShakePass : circt::hw::impl::HWTreeShakeBase { + void runOnOperation() override; +}; + +void HWTreeShakePass::runOnOperation() { + auto root = getOperation(); + llvm::DenseMap allModules; + root.walk( + [&](circt::hw::HWModuleLike mod) { allModules[mod.getName()] = mod; }); + + llvm::DenseSet visited; + auto visit = [&allModules, &visited](auto &self, + circt::hw::HWModuleLike mod) -> void { + if (visited.contains(mod)) + return; + visited.insert(mod); + mod.walk([&](circt::hw::InstanceOp inst) { + auto modName = inst.getModuleName(); + self(self, allModules.at(modName)); + }); + }; + + for (const auto &kept : keep) { + auto lookup = allModules.find(kept); + if (lookup == allModules.end()) + continue; // Silently ignore missing modules + visit(visit, lookup->getSecond()); + } + + for (auto &mod : allModules) { + if (!visited.contains(mod.getSecond())) { + mod.getSecond()->remove(); + } + } +} + +std::unique_ptr circt::hw::createHWTreeShakePass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/test/Dialect/HW/expunge-module.mlir b/test/Dialect/HW/expunge-module.mlir new file mode 100644 index 000000000000..11ea30608082 --- /dev/null +++ b/test/Dialect/HW/expunge-module.mlir @@ -0,0 +1,34 @@ +// RUN: circt-opt --pass-pipeline="builtin.module(hw-expunge-module{modules={baz,b} port-prefixes={foo:bar2.baz2=meow_,bar:baz1=nya_}})" %s | FileCheck %s --check-prefixes FOO,BAR,BAZ,COMMON +// RUN: circt-opt --pass-pipeline="builtin.module(hw-expunge-module{modules={baz,b} port-prefixes={foo:bar2.baz2=meow_,bar:baz1=nya_}},hw-tree-shake{keep=foo})" %s | FileCheck %s --check-prefixes FOO,BAR,COMMON +// RUN: circt-opt --pass-pipeline="builtin.module(hw-expunge-module{modules={baz,b} port-prefixes={foo:bar2.baz2=meow_,bar:baz1=nya_}},hw-tree-shake{keep=baz})" %s | FileCheck %s --check-prefixes BAZ,COMMON + +module { + hw.module @foo(in %bar1_baz1__out : i2, out test : i1) { + %bar1.self_out = hw.instance "bar1" @bar(self_in: %0: i1) -> (self_out: i1) + %bar2.self_out = hw.instance "bar2" @bar(self_in: %bar1.self_out: i1) -> (self_out: i1) + %0 = comb.extract %bar1_baz1__out from 0 : (i2) -> i1 + hw.output %bar2.self_out : i1 + } + hw.module private @bar(in %self_in : i1, out self_out : i1) { + %baz1.out = hw.instance "baz1" @baz(in: %self_in: i1) -> (out: i1) + %baz2.out = hw.instance "baz2" @baz(in: %baz1.out: i1) -> (out: i1) + hw.output %baz2.out : i1 + } + hw.module private @baz(in %in : i1, out out : i1) { + hw.output %in : i1 + } +} + +// COMMON: module { +// FOO-NEXT: hw.module @foo(in %bar1_baz1__out : i2, in %bar1_baz1__out_0 : i1, in %bar1_baz2__out : i1, in %bar2_baz1__out : i1, in %meow_out : i1, out test : i1, out bar1_baz1__in : i1, out bar1_baz2__in : i1, out bar2_baz1__in : i1, out meow_in : i1) { +// FOO-NEXT: %bar1.self_out, %bar1.nya_in, %bar1.baz2__in = hw.instance "bar1" @bar(self_in: %0: i1, nya_out: %bar1_baz1__out_0: i1, baz2__out: %bar1_baz2__out: i1) -> (self_out: i1, nya_in: i1, baz2__in: i1) +// FOO-NEXT: %bar2.self_out, %bar2.nya_in, %bar2.baz2__in = hw.instance "bar2" @bar(self_in: %bar1.self_out: i1, nya_out: %bar2_baz1__out: i1, baz2__out: %meow_out: i1) -> (self_out: i1, nya_in: i1, baz2__in: i1) +// FOO-NEXT: %0 = comb.extract %bar1_baz1__out from 0 : (i2) -> i1 +// FOO-NEXT: hw.output %bar2.self_out, %bar1.nya_in, %bar1.baz2__in, %bar2.nya_in, %bar2.baz2__in : i1, i1, i1, i1, i1 +// FOO-NEXT: } +// BAR-NEXT: hw.module private @bar(in %self_in : i1, in %nya_out : i1, in %baz2__out : i1, out self_out : i1, out nya_in : i1, out baz2__in : i1) { +// BAR-NEXT: hw.output %baz2__out, %self_in, %nya_out : i1, i1, i1 +// BAR-NEXT: } +// BAZ-NEXT: hw.module private @baz(in %in : i1, out out : i1) { +// BAZ-NEXT: hw.output %in : i1 +// BAZ-NEXT: }