
Commit 4867b10 ("update"), committed Jun 23, 2019, 0 parents

File tree: 8 files changed, +285 -0 lines changed
 

‎DQN.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class DQN(nn.Module):
    def __init__(self, h, w):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)

        # output size of a conv layer for a given input size, kernel size and stride
        def conv2d_size_out(size, kernel_size=5, stride=2):
            return (size - kernel_size) // stride + 1

        convw = w
        convh = h
        for i in range(3):
            convw = conv2d_size_out(convw)
            convh = conv2d_size_out(convh)

        self.head = nn.Linear(convw * convh * 32, 2)

    # x.size(): (N, input_channels, H, W)
    # output.size(): (N, 2)
    # the DQN estimates Q(s_t, a) for each of the two actions
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))

        return self.head(x.view(x.size(0), -1))
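
A minimal usage sketch, not part of the commit, to show the expected tensor shapes. The 40x90 input size is an assumption based on the crop and resize in utils.py:

import torch
from DQN import DQN

net = DQN(40, 90)                   # h=40, w=90 (assumed screen-patch size)
net.eval()                          # use BatchNorm running stats for a single sample
dummy = torch.zeros(1, 3, 40, 90)   # (N, C, H, W) screen-difference tensor
q_values = net(dummy)
print(q_values.shape)               # torch.Size([1, 2]): one Q-value per action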

‎DQN.pyc

1.62 KB
Binary file not shown.

‎README.md

# Deep Q-learning Network for Reinforcement Learning
A re-implementation of the PyTorch "Reinforcement Learning (DQN) Tutorial".

‎memory.py

import random
from collections import namedtuple

# a named tuple with one field per element of a transition
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        # wrap around so only the most recent self.capacity transitions are kept
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
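
A quick illustration, not part of the commit, of how the buffer wraps around once it is full; plain Python values stand in for the tensors that train.py actually stores:

from memory import ReplayMemory

buf = ReplayMemory(capacity=2)
buf.push('s0', 0, 's1', 1.0)    # state, action, next_state, reward
buf.push('s1', 1, 's2', 1.0)
buf.push('s2', 0, None, 0.0)    # overwrites the oldest transition
print(len(buf))                 # 2
print(buf.sample(2))            # two Transition tuples in random order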

‎memory.pyc

1.36 KB
Binary file not shown.

‎train.py

import random
import math
from itertools import count

import gym
import torch
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

from DQN import DQN
from utils import get_screen
from utils import plot_durations
from memory import ReplayMemory
from memory import Transition


def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-steps_done / float(EPS_DECAY))
    steps_done += 1

    # epsilon-greedy: early in training the threshold is high, so actions are mostly
    # random; as steps_done grows the policy network is used more and more
    if sample < eps_threshold:
        return torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long)
    else:
        return policy_net(state).max(1)[1].view(1, 1)


def optimize_model(policy_net, optimizer):
    # wait until the replay memory holds at least one full batch
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))
    # non_final_mask marks every sampled transition whose next_state is not None
    non_final_mask = tuple(map(lambda s: s is not None, batch.next_state))
    non_final_mask = torch.tensor(non_final_mask, device=device, dtype=torch.uint8)  # torch.bool on newer PyTorch
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # policy_net(state_batch) gives the Q-values of all actions;
    # gather picks the Q-value of the action that was actually taken
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(BATCH_SIZE, device=device)

    # compute V(s_{t+1}) with the target network; for final states V(s_{t+1}) stays 0
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Huber loss
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()


env = gym.make('CartPole-v0').unwrapped
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env.reset()

BATCH_SIZE = 128
# GAMMA is the discount factor
GAMMA = 0.999
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200

TARGET_UPDATE = 10

AVERAGE_SIZE = 10
episode_durations = []

init_screen = get_screen(env, device)
_, _, screen_height, screen_width = init_screen.shape

policy_net = DQN(screen_height, screen_width).to(device)
target_net = DQN(screen_height, screen_width).to(device)

target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(10000)

steps_done = 0
num_episodes = 300
for i_episode in range(num_episodes):
    env.reset()
    last_screen = get_screen(env, device)
    current_screen = get_screen(env, device)
    state = current_screen - last_screen
    for t in count():
        action = select_action(state)
        _, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)

        last_screen = current_screen
        current_screen = get_screen(env, device)

        if not done:
            next_state = current_screen - last_screen
        else:
            next_state = None

        memory.push(state, action, next_state, reward)

        state = next_state
        optimize_model(policy_net, optimizer)
        if done:
            episode_durations.append(t + 1)
            plot_durations(episode_durations, AVERAGE_SIZE)
            break

    # periodically copy the policy network's weights into the target network
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

print("Complete")
env.render()
env.close()
plt.ioff()
plt.show()
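
Two small illustrations, neither part of the commit. First, how the epsilon threshold in select_action decays from EPS_START toward EPS_END as steps_done grows; the values follow directly from the formula above:

import math

EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 200
for steps_done in (0, 200, 400, 1000):
    eps = EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / float(EPS_DECAY))
    print(steps_done, round(eps, 3))    # 0.9, 0.363, 0.165, 0.056

Second, the target built in optimize_model is the one-step Q-learning target r + GAMMA * max_a Q_target(s', a), with V(s') fixed to 0 for terminal transitions; the numbers below are made up for a batch of three:

import torch

GAMMA = 0.999
reward_batch = torch.tensor([1.0, 1.0, 1.0])
next_state_values = torch.tensor([0.7, 0.4, 0.0])   # third transition is terminal
expected = next_state_values * GAMMA + reward_batch
print(expected)                                     # tensor([1.6993, 1.3996, 1.0000])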

‎utils.py

from PIL import Image

import gym
import torch
import numpy as np
import torchvision.transforms as T
import matplotlib.pyplot as plt


# resize: several transforms composed together
resize = T.Compose([
    T.ToPILImage(),
    T.Resize(40, interpolation=Image.CUBIC),  # Image.BICUBIC in newer Pillow
    T.ToTensor()
])


def plot_durations(episode_durations, AVERAGE_SIZE):
    plt.figure(2)
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)

    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())

    if len(durations_t) >= AVERAGE_SIZE:
        dim = 0
        size = AVERAGE_SIZE
        step = 1
        # durations_t.unfold(dim, size, step).size(): (num_windows, AVERAGE_SIZE)
        # taking mean(1) gives the running average over the last AVERAGE_SIZE episodes
        means = durations_t.unfold(dim, size, step).mean(1).view(-1)
        means = torch.cat((torch.zeros(AVERAGE_SIZE - 1), means))
        plt.plot(means.numpy())

    plt.pause(0.001)


# returns the horizontal (x-axis) pixel position of the cart centre
def get_cart_location(env, screen_width):
    world_width = env.x_threshold * 2
    scale = screen_width / world_width
    return int(env.state[0] * scale + screen_width / 2.0)


def get_screen(env, device):
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    _, screen_height, screen_width = screen.shape
    # keep the horizontal strip that contains the cart
    screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]

    view_width = int(screen_width * 0.6)
    cart_location = get_cart_location(env, screen_width)

    # slice usage: slice(stop) or slice(start, stop)
    if cart_location < view_width // 2:
        # cart near the left edge
        slice_range = slice(view_width)
    elif cart_location > (screen_width - view_width // 2):
        # cart near the right edge
        slice_range = slice(-view_width, None)
    else:
        # cart in the middle: centre the window on it
        slice_range = slice(cart_location - view_width // 2, cart_location + view_width // 2)
    screen = screen[:, :, slice_range]
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255

    screen = torch.from_numpy(screen)
    # add a batch dimension: BCHW
    return resize(screen).unsqueeze(0).to(device)


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    env = gym.make('CartPole-v0').unwrapped
    env.reset()
    plt.figure()
    cart = get_screen(env, device).cpu().squeeze(0).permute(1, 2, 0).numpy()

    plt.imshow(cart, interpolation='none')
    plt.title('Cart')
    plt.show()
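
A shape check, not in the commit, assuming the default 600x400 CartPole frame: the crop keeps a 160-pixel-high strip, the slice keeps 60% of the width, and T.Resize(40) scales the short side to 40, so get_screen should return a (1, 3, 40, 90) tensor; this is the (h, w) that train.py passes to DQN.

import gym
import torch
from utils import get_screen

device = torch.device("cpu")
env = gym.make('CartPole-v0').unwrapped
env.reset()
print(get_screen(env, device).shape)    # expected: torch.Size([1, 3, 40, 90])
env.close()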

‎utils.pyc

2.79 KB
Binary file not shown.

0 commit comments
