diff --git a/tests/unit/compiler/venom/test_loop_invariant_hoisting.py b/tests/unit/compiler/venom/test_loop_invariant_hoisting.py
new file mode 100644
index 0000000000..006fc15073
--- /dev/null
+++ b/tests/unit/compiler/venom/test_loop_invariant_hoisting.py
@@ -0,0 +1,217 @@
+import pytest
+
+from tests.venom_utils import parse_from_basic_block, assert_ctx_eq
+from vyper.venom.analysis.analysis import IRAnalysesCache
+from vyper.venom.analysis.loop_detection import NaturalLoopDetectionAnalysis
+from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable
+from vyper.venom.context import IRContext
+from vyper.venom.function import IRFunction
+from vyper.venom.passes.loop_invariant_hosting import LoopInvariantHoisting
+
+
+def _create_loops(fn, depth, loop_id, body_fn=lambda *_: (), top=True):
+    bb = fn.get_basic_block()
+    cond = IRBasicBlock(IRLabel(f"cond{loop_id}{depth}"), fn)
+    body = IRBasicBlock(IRLabel(f"body{loop_id}{depth}"), fn)
+    if top:
+        exit_block = IRBasicBlock(IRLabel(f"exit_top{loop_id}{depth}"), fn)
+    else:
+        exit_block = IRBasicBlock(IRLabel(f"exit{loop_id}{depth}"), fn)
+    fn.append_basic_block(cond)
+    fn.append_basic_block(body)
+
+    bb.append_instruction("jmp", cond.label)
+
+    cond_var = IRVariable(f"cond_var{loop_id}{depth}")
+    cond.append_instruction("iszero", 0, ret=cond_var)
+    assert isinstance(cond_var, IRVariable)
+    cond.append_instruction("jnz", cond_var, body.label, exit_block.label)
+    body_fn(fn, loop_id, depth)
+    if depth > 1:
+        _create_loops(fn, depth - 1, loop_id, body_fn, top=False)
+    bb = fn.get_basic_block()
+    bb.append_instruction("jmp", cond.label)
+    fn.append_basic_block(exit_block)
+
+
+def _simple_body(fn, loop_id, depth):
+    assert isinstance(fn, IRFunction)
+    bb = fn.get_basic_block()
+    add_var = IRVariable(f"add_var{loop_id}{depth}")
+    bb.append_instruction("add", 1, 2, ret=add_var)
+
+
+def _hoistable_body(fn, loop_id, depth):
+    assert isinstance(fn, IRFunction)
+    bb = fn.get_basic_block()
+    store_var = IRVariable(f"store_var{loop_id}{depth}")
+    add_var_a = IRVariable(f"add_var_a{loop_id}{depth}")
+    bb.append_instruction("store", 1, ret=store_var)
+    bb.append_instruction("add", 1, store_var, ret=add_var_a)
+    add_var_b = IRVariable(f"add_var_b{loop_id}{depth}")
+    bb.append_instruction("add", store_var, add_var_a, ret=add_var_b)
+
+
+def _simple_body_code(loop_id, depth):
+    return f"""
+    %a{depth}{loop_id} = add 1, 2"""
+
+
+def _create_loops_code(depth, loop_id, body=lambda _, _a: "", last: bool = False):
+    if depth <= 0:
+        return ""
+    inner = _create_loops_code(depth - 1, loop_id, body, False)
+
+    res = f"""
+        jmp @cond{depth}{loop_id}
+    cond{depth}{loop_id}:
+        jnz %par, @exit{depth}{loop_id}, @body{depth}{loop_id}
+    body{depth}{loop_id}:
+        {body(loop_id, depth)}
+        {inner}
+        jmp @cond{depth}{loop_id}
+    exit{depth}{loop_id}:
+    """
+
+    if last:
+        res += """
+        stop
+        """
+
+    return res
+
+
+@pytest.mark.parametrize("depth", range(1, 4))
+@pytest.mark.parametrize("count", range(1, 4))
+def test_loop_detection_analysis(depth, count):
+    loops = ""
+    for i in range(count):
+        loops += _create_loops_code(depth, i, _simple_body_code, last=(i == count - 1))
+
+    code = f"""
+    main:
+        %par = param
+        {loops}
+    """
+
+    print(code)
+
+    ctx = parse_from_basic_block(code)
+    assert len(ctx.functions) == 1
+
+    fn = list(ctx.functions.values())[0]
+    ac = IRAnalysesCache(fn)
+    analysis = ac.request_analysis(NaturalLoopDetectionAnalysis)
+    assert isinstance(analysis, NaturalLoopDetectionAnalysis)
+
+    assert len(analysis.loops) == depth * count
+
+
+@pytest.mark.parametrize("depth", range(1, 4))
+@pytest.mark.parametrize("count", range(1, 4))
+def test_loop_invariant_hoisting_simple(depth, count):
+    pre_loops = ""
+    for i in range(count):
+        pre_loops += _create_loops_code(depth, i, _simple_body_code, last=(i == count - 1))
+
+    post_loops = ""
+    for i in range(count):
+        hoisted = ""
+        for d in range(depth):
+            hoisted += _simple_body_code(i, depth - d)
+        post_loops += hoisted
+        post_loops += _create_loops_code(depth, i, last=(i == count - 1))
+
+
+    pre = f"""
+    main:
+        %par = param
+        {pre_loops}
+    """
+
+    post = f"""
+    main:
+        %par = param
+        {post_loops}
+    """
+
+    ctx = parse_from_basic_block(pre)
+    print(ctx)
+
+    for fn in ctx.functions.values():
+        ac = IRAnalysesCache(fn)
+        LoopInvariantHoisting(ac, fn).run_pass()
+
+    post_ctx = parse_from_basic_block(post)
+
+    print(ctx)
+    print(post_ctx)
+
+    assert_ctx_eq(ctx, post_ctx)
+
+@pytest.mark.parametrize("depth", range(1, 4))
+@pytest.mark.parametrize("count", range(1, 4))
+def test_loop_invariant_hoisting_dependant(depth, count):
+    ctx = IRContext()
+    fn = ctx.create_function("_global")
+
+    for c in range(count):
+        _create_loops(fn, depth, c, _hoistable_body)
+
+    bb = fn.get_basic_block()
+    bb.append_instruction("ret")
+
+    ac = IRAnalysesCache(fn)
+    LoopInvariantHoisting(ac, fn).run_pass()
+
+    entry = fn.entry
+    assignments = list(map(lambda x: x.value, entry.get_assignments()))
+    for bb in filter(lambda bb: bb.label.name.startswith("exit_top"), fn.get_basic_blocks()):
+        assignments.extend(map(lambda x: x.value, bb.get_assignments()))
+
+    assert len(assignments) == depth * count * 4
+    for loop_id in range(count):
+        for d in range(1, depth + 1):
+            assert f"%store_var{loop_id}{d}" in assignments, repr(fn)
+            assert f"%add_var_a{loop_id}{d}" in assignments, repr(fn)
+            assert f"%add_var_b{loop_id}{d}" in assignments, repr(fn)
+            assert f"%cond_var{loop_id}{d}" in assignments, repr(fn)
+
+
+def _unhoistable_body(fn, loop_id, depth):
+    assert isinstance(fn, IRFunction)
+    bb = fn.get_basic_block()
+    add_var_a = IRVariable(f"add_var_a{loop_id}{depth}")
+    bb.append_instruction("mload", 64, ret=add_var_a)
+    add_var_b = IRVariable(f"add_var_b{loop_id}{depth}")
+    bb.append_instruction("add", add_var_a, 2, ret=add_var_b)
+    bb.append_instruction("mstore", 10, add_var_b)
+
+
+@pytest.mark.parametrize("depth", range(1, 4))
+@pytest.mark.parametrize("count", range(1, 4))
+def test_loop_invariant_hoisting_unhoistable(depth, count):
+    ctx = IRContext()
+    fn = ctx.create_function("_global")
+
+    for c in range(count):
+        _create_loops(fn, depth, c, _unhoistable_body)
+
+    bb = fn.get_basic_block()
+    bb.append_instruction("ret")
+
+    print(fn)
+
+    ac = IRAnalysesCache(fn)
+    LoopInvariantHoisting(ac, fn).run_pass()
+    print(fn)
+
+    entry = fn.entry
+    assignments = list(map(lambda x: x.value, entry.get_assignments()))
+    for bb in filter(lambda bb: bb.label.name.startswith("exit_top"), fn.get_basic_blocks()):
+        assignments.extend(map(lambda x: x.value, bb.get_assignments()))
+
+    assert len(assignments) == depth * count
+    for loop_id in range(count):
+        for d in range(1, depth + 1):
+            assert f"%cond_var{loop_id}{d}" in assignments, repr(fn)
diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py
index af9a39683e..017a9c3126 100644
--- a/vyper/venom/__init__.py
+++ b/vyper/venom/__init__.py
@@ -27,6 +27,7 @@
     StoreElimination,
     StoreExpansionPass,
 )
+from vyper.venom.passes.loop_invariant_hosting import LoopInvariantHoisting
 from vyper.venom.venom_to_assembly import VenomCompiler
 
 DEFAULT_OPT_LEVEL = OptimizationLevel.default()
@@ -84,6 +85,7 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel, ac: IRAnalysesCache
     AlgebraicOptimizationPass(ac, fn).run_pass()
 
     RemoveUnusedVariablesPass(ac, fn).run_pass()
+    LoopInvariantHoisting(ac, fn).run_pass()
 
     StoreExpansionPass(ac, fn).run_pass()
 
diff --git a/vyper/venom/analysis/loop_detection.py b/vyper/venom/analysis/loop_detection.py
new file mode 100644
index 0000000000..e2cf01084b
--- /dev/null
+++ b/vyper/venom/analysis/loop_detection.py
@@ -0,0 +1,69 @@
+from vyper.utils import OrderedSet
+from vyper.venom.analysis.analysis import IRAnalysis
+from vyper.venom.analysis.cfg import CFGAnalysis
+from vyper.venom.basicblock import IRBasicBlock
+
+
+class NaturalLoopDetectionAnalysis(IRAnalysis):
+    """
+    Detects loops and computes basic blocks
+    and the block which is before the loop
+    """
+
+    # key = block before the loop (the first predecessor of the loop header)
+    # value = all the blocks that the loop contains
+    loops: dict[IRBasicBlock, OrderedSet[IRBasicBlock]]
+
+    def analyze(self):
+        self.analyses_cache.request_analysis(CFGAnalysis)
+        self.loops = self._find_natural_loops(self.function.entry)
+
+    # Could possibly reuse the dominator tree algorithm to find the back edges
+    # if it is already cached it will be faster. Still might need to separate the
+    # various extra information that the dominator analysis provides
+    # (like frontiers and immediate dominators)
+    def _find_back_edges(self, entry: IRBasicBlock) -> list[tuple[IRBasicBlock, IRBasicBlock]]:
+        back_edges = []
+        visited: OrderedSet[IRBasicBlock] = OrderedSet()
+        stack = []
+
+        def dfs(bb: IRBasicBlock):
+            visited.add(bb)
+            stack.append(bb)
+
+            for succ in bb.cfg_out:
+                if succ not in visited:
+                    dfs(succ)
+                elif succ in stack:
+                    back_edges.append((bb, succ))
+
+            stack.pop()
+
+        dfs(entry)
+
+        return back_edges
+
+    def _find_natural_loops(
+        self, entry: IRBasicBlock
+    ) -> dict[IRBasicBlock, OrderedSet[IRBasicBlock]]:
+        back_edges = self._find_back_edges(entry)
+        natural_loops = {}
+
+        for u, v in back_edges:
+            # back edge: u -> v (u == v for a self-loop, so never walk past the header v)
+            loop: OrderedSet[IRBasicBlock] = OrderedSet()
+            stack = [u]
+
+            while stack:
+                bb = stack.pop()
+                if bb in loop or bb == v:
+                    continue
+                loop.add(bb)
+                for pred in bb.cfg_in:
+                    if pred != v:
+                        stack.append(pred)
+
+            loop.add(v)
+            natural_loops[v.cfg_in.first()] = loop
+
+        return natural_loops
diff --git a/vyper/venom/passes/loop_invariant_hosting.py b/vyper/venom/passes/loop_invariant_hosting.py
new file mode 100644
index 0000000000..0479f2af95
--- /dev/null
+++ b/vyper/venom/passes/loop_invariant_hosting.py
@@ -0,0 +1,138 @@
+from vyper.utils import OrderedSet
+from vyper.venom.analysis.cfg import CFGAnalysis
+from vyper.venom.analysis.dfg import DFGAnalysis
+from vyper.venom.analysis.liveness import LivenessAnalysis
+from vyper.venom.analysis.loop_detection import NaturalLoopDetectionAnalysis
+from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRLiteral, IRVariable
+from vyper.venom.effects import EMPTY, Effects
+from vyper.venom.function import IRFunction
+from vyper.venom.passes.base_pass import IRPass
+
+
+def _ignore_instruction(inst: IRInstruction) -> bool:
+    return (
+        inst.is_bb_terminator
+        or inst.opcode == "returndatasize"
+        or inst.opcode == "phi"
+        or (inst.opcode == "add" and isinstance(inst.operands[1], IRLabel))
+        or inst.opcode == "store"
+    )
+
+
+# must check that the operand is a literal because
+# there are cases when the store just moves a value
+# from one variable to another
+def _is_correct_store(inst: IRInstruction) -> bool:
+    return inst.opcode == "store" and isinstance(inst.operands[0], IRLiteral)
+
+
+class LoopInvariantHoisting(IRPass):
+    """
+    This pass detects invariants in loops and hoists them above the loop body.
+    Any VOLATILE_INSTRUCTIONS, BB_TERMINATORS and CFG_ALTERING_INSTRUCTIONS are ignored
+    """
+
+    function: IRFunction
+    loops: dict[IRBasicBlock, OrderedSet[IRBasicBlock]]
+    dfg: DFGAnalysis
+
+    def run_pass(self):
+        self.analyses_cache.request_analysis(CFGAnalysis)
+        self.dfg = self.analyses_cache.request_analysis(DFGAnalysis)  # type: ignore
+        loops = self.analyses_cache.request_analysis(NaturalLoopDetectionAnalysis)
+        assert isinstance(loops, NaturalLoopDetectionAnalysis)
+        self.loops = loops.loops
+        invalidate = False
+        while True:
+            change = False
+            for from_bb, loop in self.loops.items():
+                hoistable: list[IRInstruction] = self._get_hoistable_loop(from_bb, loop)
+                if not hoistable:
+                    continue
+                change = True
+                self._hoist(from_bb, hoistable)
+            if not change:
+                break
+            invalidate = True
+
+        # only need to invalidate if you did some hoisting
+        if invalidate:
+            self.analyses_cache.invalidate_analysis(LivenessAnalysis)
+
+    def _hoist(self, target_bb: IRBasicBlock, hoistable: list[IRInstruction]):
+        for inst in hoistable:
+            bb = inst.parent
+            bb.remove_instruction(inst)
+            target_bb.insert_instruction(inst, index=len(target_bb.instructions) - 1)
+
+    def _get_loop_effects_write(self, loop: OrderedSet[IRBasicBlock]) -> Effects:
+        res: Effects = EMPTY
+        for bb in loop:
+            assert isinstance(bb, IRBasicBlock)  # help mypy
+            for inst in bb.instructions:
+                res |= inst.get_write_effects()
+        return res
+
+    def _get_hoistable_loop(
+        self, from_bb: IRBasicBlock, loop: OrderedSet[IRBasicBlock]
+    ) -> list[IRInstruction]:
+        result: list[IRInstruction] = []
+        loop_effects = self._get_loop_effects_write(loop)
+        for bb in loop:
+            result.extend(self._get_hoistable_bb(bb, from_bb, loop_effects))
+        return result
+
+    def _get_hoistable_bb(
+        self, bb: IRBasicBlock, loop_idx: IRBasicBlock, loop_effects: Effects
+    ) -> list[IRInstruction]:
+        result: list[IRInstruction] = []
+        for inst in bb.instructions:
+            if self._can_hoist_instruction_ignore_stores(inst, self.loops[loop_idx], loop_effects):
+                result.extend(self._store_dependencies(inst, loop_idx))
+                result.append(inst)
+
+        return result
+
+    # query store dependencies of instruction (they are not handled otherwise)
+    def _store_dependencies(
+        self, inst: IRInstruction, loop_idx: IRBasicBlock
+    ) -> list[IRInstruction]:
+        result: list[IRInstruction] = []
+        for var in inst.get_input_variables():
+            source_inst = self.dfg.get_producing_instruction(var)
+            assert isinstance(source_inst, IRInstruction)
+            if not _is_correct_store(source_inst):
+                continue
+            for bb in self.loops[loop_idx]:
+                if source_inst.parent == bb:
+                    result.append(source_inst)
+        return result
+
+    # since the stores are always hoistable this ignores
+    # stores in analysis (they are hoisted if some instruction is dependent on them)
+    def _can_hoist_instruction_ignore_stores(
+        self, inst: IRInstruction, loop: OrderedSet[IRBasicBlock], loop_effects: Effects
+    ) -> bool:
+        if (inst.get_read_effects() & loop_effects) != EMPTY:
+            return False
+        if _ignore_instruction(inst):
+            return False
+        for bb in loop:
+            if self._dependent_in_bb(inst, bb):
+                return False
+        return True
+
+    def _dependent_in_bb(self, inst: IRInstruction, bb: IRBasicBlock):
+        for in_var in inst.get_input_variables():
+            assert isinstance(in_var, IRVariable)
+            source_ins = self.dfg.get_producing_instruction(in_var)
+            assert isinstance(source_ins, IRInstruction)
+
+            # ignores stores since all stores are independent
+            # and can always be hoisted
+            if _is_correct_store(source_ins):
+                continue
+
+            if source_ins.parent == bb:
+                return True
+        return False