replay.py
import collections
import random

import numpy as np


class ExperienceBuffer:
    """
    Regular experience replay buffer (i.e. it stores individual experiences).
    """

    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def append(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        # Sample without replacement; raises ValueError if batch_size
        # exceeds the number of stored experiences.
        indices = random.sample(range(len(self.buffer)), batch_size)
        return [self.buffer[x] for x in indices]

    def clear(self):
        self.buffer.clear()
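

# A minimal usage sketch (hypothetical demo helper, not part of the original
# API; the buffer itself is agnostic to what each stored experience contains):
def _demo_experience_buffer():
    buffer = ExperienceBuffer(capacity=1000)
    # Each experience here is assumed to be a (state, action, reward,
    # next_state, done) tuple, as is common for DQN-style agents.
    for step in range(10):
        buffer.append((step, 0, 1.0, step + 1, False))
    print(buffer.sample(batch_size=4))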


class EpisodeExperienceBuffer:
    """
    Experience replay buffer that stores whole episodes and samples
    sequences of experiences from within them.
    """

    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def append(self, episode):
        self.buffer.append(episode)

    def sample(self, batch_size, seq_length):
        states = []
        observations = []
        actions = []
        next_states = []
        next_observations = []
        rewards = []
        is_terminals = []
        # Sample episodes without replacement; raises ValueError if
        # batch_size exceeds the number of stored episodes.
        episode_ids = random.sample(range(len(self.buffer)), batch_size)
        for episode_id in episode_ids:
            episode_length = len(self.buffer[episode_id])
            if seq_length >= episode_length:  # use the full episode
                start_ts = 0
                end_ts = episode_length
            else:  # sample a random subsequence of exactly seq_length steps
                # randint's upper bound is exclusive, so + 1 is needed to
                # allow sequences that end on the episode's final step
                start_ts = np.random.randint(0, episode_length - seq_length + 1)
                end_ts = start_ts + seq_length
            sequence = self.buffer[episode_id][start_ts:end_ts]
            # Transpose the list of transition tuples into per-field tuples.
            state, obs, action, next_state, next_obs, reward, is_terminal = zip(*sequence)
            states.append(state)
            observations.append(obs)
            actions.append(action)
            next_states.append(next_state)
            next_observations.append(next_obs)
            rewards.append(reward)
            is_terminals.append(is_terminal)
        return states, observations, actions, next_states, next_observations, rewards, is_terminals

    def clear(self):
        self.buffer.clear()
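

# A minimal usage sketch for the episodic buffer (hypothetical demo helper,
# not part of the original API; the 7-tuple layout below matches what
# sample() unpacks):
def _demo_episode_experience_buffer():
    buffer = EpisodeExperienceBuffer(capacity=100)
    # Ten-step dummy episode of (state, obs, action, next_state, next_obs,
    # reward, is_terminal) transitions.
    episode = [(t, t, 0, t + 1, t + 1, 1.0, t == 9) for t in range(10)]
    buffer.append(episode)
    buffer.append(list(episode))
    states, obs, actions, next_states, next_obs, rewards, terminals = buffer.sample(
        batch_size=2, seq_length=5)
    print(np.asarray(states).shape)  # (batch_size, seq_length)


if __name__ == "__main__":
    _demo_experience_buffer()
    _demo_episode_experience_buffer()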