@@ -154,6 +154,59 @@ def test_agent_deletion():
154
154
assert len (processor .episode_rewards .keys ()) == 0
155
155
156
156
157
def test_end_episode():
    """Verify that ``AgentProcessor.end_episode`` drops all per-agent state.

    Experiences are pushed into the processor under three different worker
    ids (standing in for different agents), then ``end_episode`` is called.
    The test expects the policy's ``remove_previous_action`` to have been
    called once per agent and the processor's bookkeeping dicts to be empty.
    """
    num_workers = 3
    steps_per_worker = 5

    mock_policy = create_mock_policy()
    trajectory_queue = mock.Mock()
    behavior_name = "test_brain_name"
    processor = AgentProcessor(
        mock_policy,
        behavior_name,
        max_trajectory_length=5,
        stats_reporter=StatsReporter("testcat"),
    )

    # A single-agent step and the action info that accompanies it.
    batched_step = mb.create_mock_batchedstep(
        num_agents=1,
        num_vector_observations=8,
        action_shape=[2],
        num_vis_observations=0,
    )
    action_outputs = {
        "action": [0.1],
        "entropy": np.array([1.0], dtype=np.float32),
        "learning_rate": 1.0,
        "pre_action": [0.1],
        "log_probs": [0.1],
    }
    action_info = ActionInfo(
        action=[0.1],
        value=[0.1],
        outputs=action_outputs,
        agent_ids=batched_step.agent_id,
    )

    processor.publish_trajectory_queue(trajectory_queue)
    # Mimic the first step right after an environment reset: no action yet.
    processor.add_experiences(batched_step, 0, ActionInfo.empty())

    # One expected removal per worker, keyed by that worker's global agent id.
    expected_removals = [
        mock.call([get_global_agent_id(worker_id, 0)])
        for worker_id in range(num_workers)
    ]

    # Feed experiences under each worker id in turn to simulate distinct agents.
    for worker_id in range(num_workers):
        for _step in range(steps_per_worker):
            processor.add_experiences(batched_step, worker_id, action_info)

    processor.end_episode()

    # Every agent must have had its previous action cleared on the policy.
    mock_policy.remove_previous_action.assert_has_calls(expected_removals)

    # None of the per-agent containers may retain entries afterwards.
    for state_name in (
        "experience_buffers",
        "last_take_action_outputs",
        "episode_steps",
        "episode_rewards",
    ):
        assert len(getattr(processor, state_name).keys()) == 0
208
+
209
+
157
210
def test_agent_manager ():
158
211
policy = create_mock_policy ()
159
212
name_behavior_id = "test_brain_name"
0 commit comments