Commit 803e62f

Author: Ervin T
Clear agent processor properly on episode reset (#3437)
1 parent ff99fb0 · commit 803e62f

File tree

2 files changed: +56, -3 lines

ml-agents/mlagents/trainers/agent_processor.py (+3, -3)
@@ -203,9 +203,9 @@ def end_episode(self) -> None:
         Ends the episode, terminating the current trajectory and stopping stats collection for that
         episode. Used for forceful reset (e.g. in curriculum or generalization training.)
         """
-        self.experience_buffers.clear()
-        self.episode_rewards.clear()
-        self.episode_steps.clear()
+        all_gids = list(self.experience_buffers.keys())  # Need to make copy
+        for _gid in all_gids:
+            self._clean_agent_data(_gid)


 class AgentManagerQueue(Generic[T]):
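
For context, a minimal, self-contained sketch of the pattern in this hunk. The real _clean_agent_data is not shown in the diff, so its body below is an assumption inferred from the assertions in the new test; the "Need to make copy" comment exists because deleting entries from a dict while iterating over it raises a RuntimeError in Python 3.

# Sketch only, not the ml-agents implementation. _clean_agent_data's body
# is inferred from the asserts in test_end_episode below.
class AgentProcessorSketch:
    def __init__(self) -> None:
        # Per-agent bookkeeping, keyed by global agent id.
        self.experience_buffers = {"gid-0": [], "gid-1": []}
        self.episode_rewards = {"gid-0": 0.0, "gid-1": 1.5}
        self.episode_steps = {"gid-0": 10, "gid-1": 7}

    def _clean_agent_data(self, global_id: str) -> None:
        # Drop every record kept for this agent.
        del self.experience_buffers[global_id]
        del self.episode_rewards[global_id]
        del self.episode_steps[global_id]

    def end_episode(self) -> None:
        # Copy the keys first: "del" inside a loop over the dict itself would
        # raise "RuntimeError: dictionary changed size during iteration".
        all_gids = list(self.experience_buffers.keys())
        for _gid in all_gids:
            self._clean_agent_data(_gid)

processor = AgentProcessorSketch()
processor.end_episode()
assert not processor.experience_buffers
assert not processor.episode_rewards
assert not processor.episode_steps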

ml-agents/mlagents/trainers/tests/test_agent_processor.py (+53)
@@ -154,6 +154,59 @@ def test_agent_deletion():
     assert len(processor.episode_rewards.keys()) == 0


+def test_end_episode():
+    policy = create_mock_policy()
+    tqueue = mock.Mock()
+    name_behavior_id = "test_brain_name"
+    processor = AgentProcessor(
+        policy,
+        name_behavior_id,
+        max_trajectory_length=5,
+        stats_reporter=StatsReporter("testcat"),
+    )
+
+    fake_action_outputs = {
+        "action": [0.1],
+        "entropy": np.array([1.0], dtype=np.float32),
+        "learning_rate": 1.0,
+        "pre_action": [0.1],
+        "log_probs": [0.1],
+    }
+    mock_step = mb.create_mock_batchedstep(
+        num_agents=1,
+        num_vector_observations=8,
+        action_shape=[2],
+        num_vis_observations=0,
+    )
+    fake_action_info = ActionInfo(
+        action=[0.1],
+        value=[0.1],
+        outputs=fake_action_outputs,
+        agent_ids=mock_step.agent_id,
+    )
+
+    processor.publish_trajectory_queue(tqueue)
+    # This is like the initial state after the env reset
+    processor.add_experiences(mock_step, 0, ActionInfo.empty())
+    # Run 3 trajectories, with different workers (to simulate different agents)
+    remove_calls = []
+    for _ep in range(3):
+        remove_calls.append(mock.call([get_global_agent_id(_ep, 0)]))
+        for _ in range(5):
+            processor.add_experiences(mock_step, _ep, fake_action_info)
+        # Make sure we don't add experiences from the prior agents after the done
+
+    # Call end episode
+    processor.end_episode()
+    # Check that we removed every agent
+    policy.remove_previous_action.assert_has_calls(remove_calls)
+    # Check that there are no experiences left
+    assert len(processor.experience_buffers.keys()) == 0
+    assert len(processor.last_take_action_outputs.keys()) == 0
+    assert len(processor.episode_steps.keys()) == 0
+    assert len(processor.episode_rewards.keys()) == 0
+
+
 def test_agent_manager():
     policy = create_mock_policy()
     name_behavior_id = "test_brain_name"
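
The key check in the new test is policy.remove_previous_action.assert_has_calls(remove_calls). As a usage note, a self-contained sketch of those unittest.mock semantics, using placeholder ids (the real ids come from get_global_agent_id, whose format is not shown in this diff): assert_has_calls passes when the given calls appear consecutively and in order in the mock's call list, with extra calls allowed before and after.

from unittest import mock

policy = mock.Mock()
# Simulate the processor cleaning up three agents, one per worker.
policy.remove_previous_action(["0-0"])  # placeholder global agent ids
policy.remove_previous_action(["1-0"])
policy.remove_previous_action(["2-0"])

expected = [mock.call(["0-0"]), mock.call(["1-0"]), mock.call(["2-0"])]
policy.remove_previous_action.assert_has_calls(expected)  # passes
# A call that never happened would raise AssertionError:
# policy.remove_previous_action.assert_has_calls([mock.call(["3-0"])])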
