"""
===============================================
A demo of SpringCartPole environment with DQNAgent
===============================================
Illustration of the training and video rendering of DQN Agent in
SpringCartPole environment.
Agent is slightly tuned, but not optimal. This is just for illustration purpose.
.. video:: ../../video_plot_springcartpole.mp4
:width: 600
"""
# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_springcartpole.jpg'
from rlberry_research.envs.classic_control import SpringCartPole
from rlberry_research.agents.torch import DQNAgent
from gymnasium.wrappers.time_limit import TimeLimit
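
# Q-network architecture: a multi-layer perceptron with two hidden layers
# of 256 units each.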
model_configs = {
"type": "MultiLayerPerceptron",
"layer_sizes": (256, 256),
"reshape": False,
}
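
# DQN agent options: build the Q-network from the configuration above
# using rlberry's model factory.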
init_kwargs = dict(
q_net_constructor="rlberry_research.agents.torch.utils.training.model_factory_from_env",
q_net_kwargs=model_configs,
)
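
# Swing-up SpringCartPole environment, with episodes capped at 500 steps.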
env = SpringCartPole(obs_trans=False, swing_up=True)
env = TimeLimit(env, max_episode_steps=500)
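
# Create the DQN agent and train it.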
agent = DQNAgent(env, **init_kwargs)
agent.fit(budget=1e5)
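
# Roll out the trained policy with rendering enabled.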
env.enable_rendering()
observation, info = env.reset()
for tt in range(1000):
    action = agent.policy(observation)
    observation, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    if done:
        observation, info = env.reset()

# Save video
video = env.save_video("_video/video_plot_springcartpole.mp4")