Skip to content

Commit

Permalink
Improve 'set_wrapper_attr'.
Browse files Browse the repository at this point in the history
  • Loading branch information
duburcqa committed Jan 18, 2025
1 parent 08a28d3 commit 5d7de3e
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions gymnasium/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,12 @@ def get_wrapper_attr(self, name: str) -> Any:
"""Gets the attribute `name` from the environment."""
return getattr(self, name)

def set_wrapper_attr(self, name: str, value: Any):
def set_wrapper_attr(self, name: str, value: Any, *, force: bool = False):
"""Sets the attribute `name` on the environment with `value`."""
setattr(self, name, value)
if force or hasattr(self, name):
setattr(self, name, value)
else:
raise AttributeError(f"{self} has no attribute {name!r}")


WrapperObsType = TypeVar("WrapperObsType")
Expand Down Expand Up @@ -425,30 +428,27 @@ def get_wrapper_attr(self, name: str) -> Any:
f"wrapper {self.class_name()} has no attribute {name!r}"
) from e

def set_wrapper_attr(self, name: str, value: Any):
def set_wrapper_attr(self, name: str, value: Any, *, force: bool = True):
"""Sets an attribute on this wrapper or lower environment if `name` is already defined.
Args:
name: The variable name
value: The new variable value
force: Whether to create the attribute on this wrapper if it does not exists on the
lower environment instead of raising an exception
"""
sub_env = self

# loop through all the wrappers, checking if it has the variable name then setting it
# otherwise stripping the wrapper to check the next.
# end when the core env is reached
while isinstance(sub_env, Wrapper):
if hasattr(sub_env, name):
setattr(sub_env, name, value)
return

sub_env = sub_env.env

# check if the base environment has the wrapper, otherwise, we set it on the top (this) wrapper
if hasattr(sub_env, name):
setattr(sub_env, name, value)
else:
if hasattr(self, name):
setattr(self, name, value)
else:
try:
self.env.set_wrapper_attr(name, value, force=False)
except AttributeError as e:
if force:
setattr(self, name, value)
else:
raise AttributeError(
f"wrapper {self.class_name()} has no attribute {name!r}"
) from e

def __str__(self):
"""Returns the wrapper name and the :attr:`env` representation string."""
Expand Down

0 comments on commit 5d7de3e

Please sign in to comment.