Skip to content

Commit

Permalink
Improve perf and prep for comp graph caching
Browse files Browse the repository at this point in the history
  • Loading branch information
evhub committed Oct 15, 2024
1 parent 74003ce commit 8d57286
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 56 deletions.
68 changes: 34 additions & 34 deletions coconut/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ class Compiler(Grammar, pickleable_obj):
]

def __init__(self, *args, **kwargs):
"""Creates a new compiler with the given parsing parameters."""
"""Create a new compiler with the given parsing parameters."""
self.setup(*args, **kwargs)
self.reset()

Expand Down Expand Up @@ -467,7 +467,7 @@ def setup(self, target=None, strict=False, minify=False, line_numbers=True, keep
self.no_wrap = no_wrap

def __reduce__(self):
"""Return pickling information."""
"""Get pickling information."""
return (self.__class__, (self.target, self.strict, self.minify, self.line_numbers, self.keep_lines, self.no_tco, self.no_wrap))

def get_cli_args(self):
Expand Down Expand Up @@ -644,6 +644,8 @@ def method(original, loc, tokens_or_item):
if trim_arity:
self_method = _trim_arity(self_method)
return self_method(original, loc, tokens_or_item)
if kwargs:
method.__name__ = py_str(method.__name__ + "$(" + ", ".join(str(k) + "=" + repr(v) for k, v in kwargs.items()) + ")")
internal_assert(
hasattr(cls_method, "ignore_arguments") is hasattr(method, "ignore_arguments")
and hasattr(cls_method, "ignore_no_tokens") is hasattr(method, "ignore_no_tokens")
Expand Down Expand Up @@ -1086,18 +1088,20 @@ def wrap_comment(self, text):
"""Wrap a comment."""
return "#" + self.add_ref("comment", text) + unwrapper

def wrap_error(self, error):
def wrap_error(self, error_maker):
"""Create a symbol that will raise the given error in postprocessing."""
return errwrapper + self.add_ref("error", error) + unwrapper
return errwrapper + self.add_ref("error_maker", error_maker) + unwrapper

def raise_or_wrap_error(self, error):
"""Raise if USE_COMPUTATION_GRAPH else wrap."""
def raise_or_wrap_error(self, *args, **kwargs):
"""Raise or defer if USE_COMPUTATION_GRAPH else wrap."""
error_maker = partial(self.make_err, *args, **kwargs)
if not USE_COMPUTATION_GRAPH:
return self.wrap_error(error)
return self.wrap_error(error_maker)
# differently-ordered any ofs can push these errors earlier than they should be, requiring us to defer them
elif use_adaptive_any_of or reverse_any_of:
return ExceptionNode(error)
return ExceptionNode(error_maker)
else:
raise error
raise error_maker()

def type_ignore_comment(self):
"""Get a "type: ignore" comment."""
Expand Down Expand Up @@ -2742,7 +2746,7 @@ def deferred_code_proc(self, inputstring, add_code_at_start=False, ignore_names=
pre_err_line, err_line = raw_line.split(errwrapper, 1)
err_ref, post_err_line = err_line.split(unwrapper, 1)
if not ignore_errors:
raise self.get_ref("error", err_ref)
raise self.get_ref("error_maker", err_ref)()
raw_line = pre_err_line + " " + post_err_line

# look for functions
Expand Down Expand Up @@ -4890,6 +4894,7 @@ def where_stmt_handle(self, loc, tokens):

where_assigns = self.current_parsing_context("where")["assigns"]
internal_assert(lambda: where_assigns is not None, "missing where_assigns")
print(where_assigns)

where_init = "".join(body_stmts)
where_final = main_stmt + "\n"
Expand Down Expand Up @@ -4989,7 +4994,8 @@ def name_handle(self, original, loc, tokens, assign=False, classname=False, expr
if self.disable_name_check:
return name

if assign:
# register non-mid-expression variable assignments inside of where statements for later mangling
if assign and not expr_setname:
where_context = self.current_parsing_context("where")
if where_context is not None:
where_assigns = where_context["assigns"]
Expand Down Expand Up @@ -5020,13 +5026,11 @@ def name_handle(self, original, loc, tokens, assign=False, classname=False, expr
if typevar_info["typevar_locs"].get(name, None) != loc:
if assign:
return self.raise_or_wrap_error(
self.make_err(
CoconutSyntaxError,
"cannot reassign type variable '{name}'".format(name=name),
original,
loc,
extra="use explicit '\\{name}' syntax if intended".format(name=name),
),
CoconutSyntaxError,
"cannot reassign type variable '{name}'".format(name=name),
original,
loc,
extra="use explicit '\\{name}' syntax if intended".format(name=name),
)
return typevars[name]

Expand Down Expand Up @@ -5057,13 +5061,11 @@ def name_handle(self, original, loc, tokens, assign=False, classname=False, expr
return name
elif assign:
return self.raise_or_wrap_error(
self.make_err(
CoconutTargetError,
"found Python-3-only assignment to 'exec' as a variable name",
original,
loc,
target="3",
),
CoconutTargetError,
"found Python-3-only assignment to 'exec' as a variable name",
original,
loc,
target="3",
)
else:
return "_coconut_exec"
Expand All @@ -5076,12 +5078,10 @@ def name_handle(self, original, loc, tokens, assign=False, classname=False, expr
return name
elif not escaped and name.startswith(reserved_prefix) and name not in self.operators:
return self.raise_or_wrap_error(
self.make_err(
CoconutSyntaxError,
"variable names cannot start with reserved prefix '{prefix}' (use explicit '\\{name}' syntax if intending to access Coconut internals)".format(prefix=reserved_prefix, name=name),
original,
loc,
),
CoconutSyntaxError,
"variable names cannot start with reserved prefix '{prefix}' (use explicit '\\{name}' syntax if intending to access Coconut internals)".format(prefix=reserved_prefix, name=name),
original,
loc,
)
else:
return name
Expand All @@ -5104,7 +5104,7 @@ def check_strict(self, name, original, loc, tokens=(None,), only_warn=False, alw
else:
if always_warn:
kwargs["extra"] = "remove --strict to downgrade to a warning"
return self.raise_or_wrap_error(self.make_err(CoconutStyleError, message, original, loc, **kwargs))
return self.raise_or_wrap_error(CoconutStyleError, message, original, loc, **kwargs)
elif always_warn:
self.syntax_warning(message, original, loc)
return tokens[0]
Expand Down Expand Up @@ -5145,13 +5145,13 @@ def check_py(self, version, name, original, loc, tokens):
self.internal_assert(len(tokens) == 1, original, loc, "invalid " + name + " tokens", tokens)
version_info = get_target_info(version)
if self.target_info < version_info:
return self.raise_or_wrap_error(self.make_err(
return self.raise_or_wrap_error(
CoconutTargetError,
"found Python " + ".".join(str(v) for v in version_info) + " " + name,
original,
loc,
target=version,
))
)
else:
return tokens[0]

Expand Down
5 changes: 2 additions & 3 deletions coconut/compiler/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2616,7 +2616,6 @@ class Grammar(object):
decoratable_data_stmt,
match_stmt,
passthrough_stmt,
where_stmt,
)

flow_stmt = any_of(
Expand Down Expand Up @@ -2661,8 +2660,8 @@ class Grammar(object):
stmt <<= final(
compound_stmt
| simple_stmt # includes destructuring
# must be after destructuring due to ambiguity
| cases_stmt
| cases_stmt # must be after destructuring due to ambiguity
| where_stmt # slows down parsing when put before simple_stmt
# at the very end as a fallback case for the anything parser
| anything_stmt
)
Expand Down
54 changes: 36 additions & 18 deletions coconut/compiler/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def evaluate_tokens(tokens, **kwargs):
elif isinstance(tokens, ComputationNode):
result = tokens.evaluate()
if is_final and isinstance(result, ExceptionNode):
raise result.exception
result.evaluate()
elif isinstance(result, ParseResults):
return make_modified_tokens(result, cls=MergeNode)
elif isinstance(result, list):
Expand All @@ -286,7 +286,7 @@ def evaluate_tokens(tokens, **kwargs):

elif isinstance(tokens, ExceptionNode):
if is_final:
raise tokens.exception
tokens.evaluate()
return tokens

elif isinstance(tokens, DeferredNode):
Expand Down Expand Up @@ -321,9 +321,12 @@ def build_new_toks_for(tokens, new_toklist, unchanged=False):
return new_toklist


cached_trim_arity = memoize()(_trim_arity)


class ComputationNode(object):
"""A single node in the computation graph."""
__slots__ = ("action", "original", "loc", "tokens")
__slots__ = ("action", "original", "loc", "tokens", "trim_arity")
pprinting = False
override_original = None
add_to_loc = 0
Expand All @@ -339,7 +342,7 @@ def using_overrides(cls):
cls.override_original = override_original
cls.add_to_loc = add_to_loc

def __new__(cls, action, original, loc, tokens, ignore_no_tokens=False, ignore_one_token=False, greedy=False, trim_arity=True):
def __new__(cls, action, original, loc, tokens, trim_arity=True, ignore_no_tokens=False, ignore_one_token=False, greedy=False):
"""Create a ComputionNode to return from a parse action.
If ignore_no_tokens, then don't call the action if there are no tokens.
Expand All @@ -350,18 +353,20 @@ def __new__(cls, action, original, loc, tokens, ignore_no_tokens=False, ignore_o
return build_new_toks_for(tokens, tokens, unchanged=True)
else:
self = super(ComputationNode, cls).__new__(cls)
if trim_arity:
self.action = _trim_arity(action)
else:
self.action = action
self.original = original if self.override_original is None else self.override_original
self.loc = self.add_to_loc + loc
self.action = action
self.original = original
self.loc = loc
self.tokens = tokens
self.trim_arity = trim_arity
if greedy:
return self.evaluate()
else:
return self

def __reduce__(self):
"""Get pickling information."""
return (self.__class__, (self.action, self.original, self.loc, self.tokens, self.trim_arity))

@property
def name(self):
"""Get the name of the action."""
Expand All @@ -377,15 +382,23 @@ def evaluate(self):
# to actually be reevaluated
if logger.tracing and not final_evaluate_tokens.enabled:
logger.log_tag("cached_parse invalidated by", self)

if self.trim_arity:
using_action = cached_trim_arity(self.action)
else:
using_action = self.action
using_original = self.original if self.override_original is None else self.override_original
using_loc = self.loc + self.add_to_loc
evaluated_toks = evaluate_tokens(self.tokens)

if logger.tracing: # avoid the overhead of the call if not tracing
logger.log_trace(self.name, self.original, self.loc, evaluated_toks, self.tokens)
logger.log_trace(self.name, using_original, using_loc, evaluated_toks, self.tokens)
if isinstance(evaluated_toks, ExceptionNode):
return evaluated_toks # short-circuit if we got an ExceptionNode
try:
result = self.action(
self.original,
self.loc,
result = using_action(
using_original,
using_loc,
evaluated_toks,
)
except CoconutException:
Expand All @@ -398,6 +411,7 @@ def evaluate(self):
embed(depth=2)
else:
raise error

out = build_new_toks_for(evaluated_toks, result)
if logger.tracing: # avoid the overhead if not tracing
dropped_keys = set(self.tokens._ParseResults__tokdict.keys())
Expand Down Expand Up @@ -434,12 +448,16 @@ def evaluate(self):

class ExceptionNode(object):
"""A node in the computation graph that stores an exception that will be raised upon final evaluation."""
__slots__ = ("exception",)
__slots__ = ("exception_maker",)

def __init__(self, exception):
def __init__(self, exception_maker):
if not USE_COMPUTATION_GRAPH:
raise exception
self.exception = exception
raise exception_maker()
self.exception_maker = exception_maker

def evaluate(self):
"""Raise the stored exception."""
raise self.exception_maker()


class CombineToNode(Combine):
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "3.1.2"
VERSION_NAME = None
# False for release, int >= 1 for develop
DEVELOP = 1
DEVELOP = 2
ALPHA = False # for pre releases rather than post releases

assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1"
Expand Down

0 comments on commit 8d57286

Please sign in to comment.