From 74003cee37968c79d8f88edaf660c328dc55c18f Mon Sep 17 00:00:00 2001
From: Evan Hubinger
Date: Mon, 14 Oct 2024 21:26:18 -0700
Subject: [PATCH] Add test implementation of computation graph cache

---
 Makefile                     |  2 +-
 coconut/compiler/compiler.py | 12 +++++--
 coconut/compiler/grammar.py  |  2 --
 coconut/compiler/util.py     | 67 +++++++++++++++++++++++++++++++++---
 coconut/constants.py         |  2 +-
 coconut/terminal.py          | 16 ++++++---
 6 files changed, 86 insertions(+), 15 deletions(-)

diff --git a/Makefile b/Makefile
index f60e01a9..a1959766 100644
--- a/Makefile
+++ b/Makefile
@@ -170,7 +170,7 @@ test-pyright: clean
 	python ./coconut/tests/dest/extras.py
 
 # same as test-univ but includes verbose output for better debugging
-# regex for getting non-timing lines: ^(?!'|\s*(Time|Packrat|Loaded|Saving|Adaptive|Errorless|Grammar|Failed|Incremental|Pruned|Compiled)\s)[^\n]*\n*
+# regex for getting non-informational lines to delete: ^(?!'|\s*(Time|Packrat|Loaded|Saving|Adaptive|Errorless|Grammar|Failed|Incremental|Pruned|Compiled|Computation)\s)[^\n]*\n*
 .PHONY: test-verbose
 test-verbose: export COCONUT_USE_COLOR=TRUE
 test-verbose: clean
diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py
index 54a92092..79975f5f 100644
--- a/coconut/compiler/compiler.py
+++ b/coconut/compiler/compiler.py
@@ -156,6 +156,7 @@
     match_in,
     transform,
     parse,
+    cached_parse,
     get_target_info_smart,
     split_leading_comments,
     compile_regex,
@@ -405,6 +406,7 @@ class Compiler(Grammar, pickleable_obj):
     """The Coconut compiler."""
     lock = Lock()
     current_compiler = None
+    computation_graph_caches = defaultdict(dict)
 
     preprocs = [
         lambda self: self.prepare,
@@ -1372,14 +1374,20 @@ def parse_line_by_line(self, init_parser, line_parser, original):
         while cur_loc < len(original):
             self.remaining_original = original[cur_loc:]
             ComputationNode.add_to_loc = cur_loc
-            results = parse(init_parser if init else line_parser, self.remaining_original, inner=False)
+            parser = init_parser if init else line_parser
+            results = cached_parse(
+                self.computation_graph_caches[("line_by_line", parser)],
+                parser,
+                self.remaining_original,
+                inner=False,
+            )
             if len(results) == 1:
                 got_loc, = results
             else:
                 got, got_loc = results
                 out_parts.append(got)
             got_loc = int(got_loc)
-            internal_assert(got_loc >= cur_loc and (init or got_loc > cur_loc), "invalid line by line parse", (cur_loc, results), extra=lambda: "in: " + repr(self.remaining_original.split("\n", 1)[0]))
+            internal_assert(got_loc >= cur_loc and (init or got_loc > cur_loc), "invalid line by line parse", (cur_loc, got_loc, results), extra=lambda: "in: " + repr(self.remaining_original.split("\n", 1)[0]))
             cur_loc = got_loc
             init = False
         return "".join(out_parts)
diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py
index f7947e9e..97cf5716 100644
--- a/coconut/compiler/grammar.py
+++ b/coconut/compiler/grammar.py
@@ -2781,8 +2781,6 @@ class Grammar(object):
             rparen,
         ) + end_marker,
         tco_return_handle,
-        # this is the root in what it's used for, so might as well evaluate greedily
-        greedy=True,
     ))
 
     rest_of_lambda = Forward()
diff --git a/coconut/compiler/util.py b/coconut/compiler/util.py
index d733a76b..962d89a1 100644
--- a/coconut/compiler/util.py
+++ b/coconut/compiler/util.py
@@ -199,6 +199,7 @@ def evaluate_tokens(tokens, **kwargs):
 
     if not USE_COMPUTATION_GRAPH:
         return tokens
+    final_evaluate_tokens.enabled = True  # special variable used by cached_parse
 
     if isinstance(tokens, ParseResults):
 
@@ -374,6 +375,8 @@ def evaluate(self):
         # note that this should never cache, since if a greedy Wrap that doesn't add to the packrat context
         # hits the cache, it'll get the same ComputationNode object, but since it's greedy that object needs
         # to actually be reevaluated
+        if logger.tracing and not final_evaluate_tokens.enabled:
+            logger.log_tag("cached_parse invalidated by", self)
         evaluated_toks = evaluate_tokens(self.tokens)
         if logger.tracing:  # avoid the overhead of the call if not tracing
             logger.log_trace(self.name, self.original, self.loc, evaluated_toks, self.tokens)
@@ -523,12 +526,17 @@ def attach(item, action, ignore_no_tokens=None, ignore_one_token=None, ignore_ar
 
 def final_evaluate_tokens(tokens):
     """Same as evaluate_tokens but should only be used once a parse is assured."""
+    if not final_evaluate_tokens.enabled:  # handled by cached_parse
+        return tokens
     result = evaluate_tokens(tokens, is_final=True)
     # clear packrat cache after evaluating tokens so error creation gets to see the cache
     clear_packrat_cache()
     return result
 
 
+final_evaluate_tokens.enabled = True
+
+
 def final(item):
     """Collapse the computation graph upon parsing the given item."""
     # evaluate_tokens expects a computation graph, so we just call add_action directly
@@ -674,17 +682,66 @@ def parse(grammar, text, inner=None, eval_parse_tree=True):
     return result
 
 
-def try_parse(grammar, text, inner=None, eval_parse_tree=True):
+def cached_parse(computation_graph_cache, grammar, text, inner=None, eval_parse_tree=True):
+    """Version of parse that caches the result when it's a pure ComputationNode."""
+    if not CPYPARSING:  # caching is only supported on cPyparsing
+        return parse(grammar, text, inner, eval_parse_tree)
+
+    for (prefix, is_at_end), tokens in computation_graph_cache.items():
+        # the assumption here is that if the prior parse didn't make it to the end,
+        # then we can freely change the text after the end of where it made it,
+        # but if it did make it to the end, then we can't add more text after that
+        if (
+            is_at_end and text == prefix
+            or not is_at_end and text.startswith(prefix)
+        ):
+            if DEVELOP:
+                logger.record_stat("cached_parse", True)
+                logger.log_tag("cached_parse hit", (prefix, text[len(prefix):], tokens))
+            break
+    else:  # no break
+        # disable token evaluation by final() to allow us to get a ComputationNode;
+        # this makes long parses very slow, however, so once a greedy parse action
+        # is hit such that evaluate_tokens gets called, evaluate_tokens will set
+        # final_evaluate_tokens.enabled back to True, which speeds up the rest of the
+        # parse and tells us that something greedy happened so we can't cache
+        final_evaluate_tokens.enabled = False
+        try:
+            with parsing_context(inner):
+                loc, tokens = prep_grammar(grammar, for_scan=False).parseString(text, returnLoc=True)
+                if not final_evaluate_tokens.enabled:
+                    prefix = text[:loc + 1]
+                    is_at_end = loc >= len(text)
+                    computation_graph_cache[(prefix, is_at_end)] = tokens
+        finally:
+            if DEVELOP:
+                logger.record_stat("cached_parse", False)
+                logger.log_tag(
+                    "cached_parse miss " + ("-> stored" if not final_evaluate_tokens.enabled else "(not stored)"),
+                    text,
+                    multiline=True,
+                )
+            final_evaluate_tokens.enabled = True
+
+    if eval_parse_tree:
+        tokens = unpack(tokens)
+    return tokens
+
+
+def try_parse(grammar, text, inner=None, eval_parse_tree=True, computation_graph_cache=None):
     """Attempt to parse text using grammar else None."""
     try:
-        return parse(grammar, text, inner, eval_parse_tree)
+        if computation_graph_cache is None:
+            return parse(grammar, text, inner, eval_parse_tree)
+        else:
+            return cached_parse(computation_graph_cache, grammar, text, inner, eval_parse_tree)
     except ParseBaseException:
         return None
 
 
-def does_parse(grammar, text, inner=None):
+def does_parse(grammar, text, inner=None, **kwargs):
     """Determine if text can be parsed using grammar."""
-    return try_parse(grammar, text, inner, eval_parse_tree=False)
+    return try_parse(grammar, text, inner, eval_parse_tree=False, **kwargs)
 
 
 def all_matches(grammar, text, inner=None, eval_parse_tree=True):
@@ -1370,6 +1427,8 @@ def parseImpl(self, original, loc, *args, **kwargs):
         with self.wrapped_context():
             parse_loc, tokens = super(Wrap, self).parseImpl(original, loc, *args, **kwargs)
             if self.greedy:
+                if logger.tracing and not final_evaluate_tokens.enabled:
+                    logger.log_tag("cached_parse invalidated by", self)
                 tokens = evaluate_tokens(tokens)
         if reparse and parse_loc is None:
             raise CoconutInternalException("illegal double reparse in", self)
diff --git a/coconut/constants.py b/coconut/constants.py
index 243bca2a..29941a54 100644
--- a/coconut/constants.py
+++ b/coconut/constants.py
@@ -1050,7 +1050,7 @@ def get_path_env_var(env_var, default):
 
 # min versions are inclusive
 unpinned_min_versions = {
-    "cPyparsing": (2, 4, 7, 2, 4, 0),
+    "cPyparsing": (2, 4, 7, 2, 4, 1),
     ("pre-commit", "py3"): (3,),
     ("psutil", "py>=27"): (6,),
     "jupyter": (1, 1),
diff --git a/coconut/terminal.py b/coconut/terminal.py
index 3ff7d432..ed94ad4c 100644
--- a/coconut/terminal.py
+++ b/coconut/terminal.py
@@ -561,7 +561,10 @@ def trace(self, item):
         return item
 
     def record_stat(self, stat_name, stat_bool):
-        """Record the given boolean statistic for the given stat_name."""
+        """Record the given boolean statistic for the given stat_name.
+
+        All stats recorded here must have some printing logic added to gather_parsing_stats or log_compiler_stats.
+        Printed stats should also be added to the regex in the Makefile for getting non-informational lines."""
         self.recorded_stats[stat_name][stat_bool] += 1
 
     @contextmanager
@@ -569,6 +572,7 @@ def gather_parsing_stats(self):
         """Times parsing if --verbose."""
         if self.verbose:
             self.recorded_stats.pop("adaptive", None)
+            self.recorded_stats.pop("cached_parse", None)
             start_time = get_clock_time()
             try:
                 yield
@@ -584,6 +588,9 @@
                 if "adaptive" in self.recorded_stats:
                     failures, successes = self.recorded_stats["adaptive"]
                     self.printlog("\tAdaptive parsing stats:", successes, "successes;", failures, "failures")
+                if "cached_parse" in self.recorded_stats:
+                    misses, hits = self.recorded_stats["cached_parse"]
+                    self.printlog("\tComputation graph cache stats:", hits, "hits;", misses, "misses")
                 if maybe_make_safe is not None:
                     hits, misses = maybe_make_safe.stats
                     self.printlog("\tErrorless parsing stats:", hits, "errorless;", misses, "with errors")
@@ -595,10 +602,9 @@ def log_compiler_stats(self, comp):
         if self.verbose:
             self.log("Grammar init time: " + str(comp.grammar_init_time) + " secs / Total init time: " + str(get_clock_time() - first_import_time) + " secs")
             for stat_name, (no_copy, yes_copy) in self.recorded_stats.items():
-                if not stat_name.startswith("maybe_copy_"):
-                    continue
-                name = assert_remove_prefix(stat_name, "maybe_copy_")
-                self.printlog("\tGrammar copying stats (" + name + "):", no_copy, "not copied;", yes_copy, "copied")
+                if stat_name.startswith("maybe_copy_"):
+                    name = assert_remove_prefix(stat_name, "maybe_copy_")
+                    self.printlog("\tGrammar copying stats (" + name + "):", no_copy, "not copied;", yes_copy, "copied")
 
 
 total_block_time = defaultdict(int)
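
Note on the caching scheme: the heart of the patch is the (prefix, is_at_end)
cache in cached_parse above. Below is a minimal standalone sketch of just that
lookup rule, separated from the parser plumbing; lookup() and the toy cache
contents are illustrative stand-ins, not Coconut APIs.

    def lookup(cache, text):
        """Return the cached result of a prior parse that is still valid for text.

        Cache keys are (prefix, is_at_end) pairs: if the prior parse stopped
        before the end of its input (is_at_end is False), everything after its
        stopping point is free to change, so any text extending the prefix can
        reuse the result; if it consumed all of its input, the new text must
        match the prefix exactly.
        """
        for (prefix, is_at_end), tokens in cache.items():
            if (
                is_at_end and text == prefix
                or not is_at_end and text.startswith(prefix)
            ):
                return tokens  # cache hit: reuse the stored computation graph
        return None  # cache miss: caller must parse and may store a new entry

    # hypothetical usage: an unchanged first line is reused when later lines change
    cache = {("x = 1\n", False): "<computation graph for 'x = 1'>"}
    assert lookup(cache, "x = 1\ny = 2\n") == "<computation graph for 'x = 1'>"
    assert lookup(cache, "z = 3\n") is None

Combined with the per-parser "line_by_line" caches used in parse_line_by_line,
a hit lets the compiler skip rebuilding the computation graph for an unchanged
line, while any greedy evaluation during the original parse re-enables
final_evaluate_tokens and keeps that result from being cached at all.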