|
| 1 | +import numpy as np |
| 2 | +from copy import copy, deepcopy |
| 3 | +from itertools import product |
| 4 | + |
| 5 | +from envs.env import Env, Direction |
| 6 | + |
| 7 | + |
| 8 | +class ApplesState(object): |
| 9 | + ''' |
| 10 | + state of the environment; describes positions of all objects in the env. |
| 11 | + ''' |
| 12 | + def __init__(self, agent_pos, tree_states, bucket_states, carrying_apple): |
| 13 | + """ |
| 14 | + agent_pos: (orientation, x, y) tuple for the agent's location |
| 15 | + tree_states: Dictionary mapping (x, y) tuples to booleans. |
| 16 | + bucket_states: Dictionary mapping (x, y) tuples to integers. |
| 17 | + carrying_apple: Boolean, True if carrying an apple, False otherwise. |
| 18 | + """ |
| 19 | + self.agent_pos = agent_pos |
| 20 | + self.tree_states = tree_states |
| 21 | + self.bucket_states = bucket_states |
| 22 | + self.carrying_apple = carrying_apple |
| 23 | + |
| 24 | + def __eq__(self, other): |
| 25 | + return isinstance(other, ApplesState) and \ |
| 26 | + self.agent_pos == other.agent_pos and \ |
| 27 | + self.tree_states == other.tree_states and \ |
| 28 | + self.bucket_states == other.bucket_states and \ |
| 29 | + self.carrying_apple == other.carrying_apple |
| 30 | + |
| 31 | + def __hash__(self): |
| 32 | + def get_vals(dictionary): |
| 33 | + return tuple([dictionary[loc] for loc in sorted(dictionary.keys())]) |
| 34 | + return hash(self.agent_pos + get_vals(self.tree_states) + get_vals(self.bucket_states) + (self.carrying_apple,)) |
| 35 | + |
| 36 | + |
| 37 | +class ApplesEnv(Env): |
| 38 | + def __init__(self, spec, compute_transitions=True): |
| 39 | + """ |
| 40 | + height: Integer, height of the grid. Y coordinates are in [0, height). |
| 41 | + width: Integer, width of the grid. X coordinates are in [0, width). |
| 42 | + init_state: ApplesState, initial state of the environment |
| 43 | + vase_locations: List of (x, y) tuples, locations of vases |
| 44 | + num_vases: Integer, number of vases |
| 45 | + carpet_locations: Set of (x, y) tuples, locations of carpets |
| 46 | + feature_locations: List of (x, y) tuples, locations of features |
| 47 | + s: ApplesState, Current state |
| 48 | + nA: Integer, number of actions |
| 49 | + """ |
| 50 | + self.height = spec.height |
| 51 | + self.width = spec.width |
| 52 | + self.apple_regen_probability = spec.apple_regen_probability |
| 53 | + self.bucket_capacity = spec.bucket_capacity |
| 54 | + self.init_state = deepcopy(spec.init_state) |
| 55 | + self.include_location_features = spec.include_location_features |
| 56 | + |
| 57 | + self.tree_locations = list(self.init_state.tree_states.keys()) |
| 58 | + self.bucket_locations = list(self.init_state.bucket_states.keys()) |
| 59 | + used_locations = set(self.tree_locations + self.bucket_locations) |
| 60 | + self.possible_agent_locations = list(filter( |
| 61 | + lambda pos: pos not in used_locations, |
| 62 | + product(range(self.width), range(self.height)))) |
| 63 | + |
| 64 | + self.num_trees = len(self.tree_locations) |
| 65 | + self.num_buckets = len(self.bucket_locations) |
| 66 | + |
| 67 | + self.default_action = Direction.get_number_from_direction(Direction.STAY) |
| 68 | + self.nA = 6 |
| 69 | + self.num_features = len(self.s_to_f(self.init_state)) |
| 70 | + |
| 71 | + self.reset() |
| 72 | + |
| 73 | + if compute_transitions: |
| 74 | + states = self.enumerate_states() |
| 75 | + self.make_transition_matrices( |
| 76 | + states, range(self.nA), self.nS, self.nA) |
| 77 | + self.make_f_matrix(self.nS, self.num_features) |
| 78 | + |
| 79 | + |
| 80 | + def enumerate_states(self): |
| 81 | + all_agent_positions = filter( |
| 82 | + lambda pos: (pos[1], pos[2]) in self.possible_agent_locations, |
| 83 | + product(range(4), range(self.width), range(self.height))) |
| 84 | + all_tree_states = map( |
| 85 | + lambda tree_vals: dict(zip(self.tree_locations, tree_vals)), |
| 86 | + product([True, False], repeat=self.num_trees)) |
| 87 | + all_bucket_states = map( |
| 88 | + lambda bucket_vals: dict(zip(self.bucket_locations, bucket_vals)), |
| 89 | + product(range(self.bucket_capacity + 1), repeat=self.num_buckets)) |
| 90 | + all_states = map( |
| 91 | + lambda x: ApplesState(*x), |
| 92 | + product(all_agent_positions, all_tree_states, all_bucket_states, [True, False])) |
| 93 | + |
| 94 | + state_num = {} |
| 95 | + for state in all_states: |
| 96 | + if state not in state_num: |
| 97 | + state_num[state] = len(state_num) |
| 98 | + |
| 99 | + self.state_num = state_num |
| 100 | + self.num_state = {v: k for k, v in self.state_num.items()} |
| 101 | + self.nS = len(state_num) |
| 102 | + |
| 103 | + return state_num.keys() |
| 104 | + |
| 105 | + def get_num_from_state(self, state): |
| 106 | + return self.state_num[state] |
| 107 | + |
| 108 | + def get_state_from_num(self, num): |
| 109 | + return self.num_state[num] |
| 110 | + |
| 111 | + |
| 112 | + def s_to_f(self, s): |
| 113 | + ''' |
| 114 | + Returns features of the state: |
| 115 | + - Number of apples in buckets |
| 116 | + - Number of apples on trees |
| 117 | + - Whether the agent is carrying an apple |
| 118 | + - For each other location, whether the agent is on that location |
| 119 | + ''' |
| 120 | + num_bucket_apples = sum(s.bucket_states.values()) |
| 121 | + num_tree_apples = sum(map(int, s.tree_states.values())) |
| 122 | + carrying_apple = int(s.carrying_apple) |
| 123 | + agent_pos = s.agent_pos[1], s.agent_pos[2] # Drop orientation |
| 124 | + features = [num_bucket_apples, num_tree_apples, carrying_apple] |
| 125 | + if self.include_location_features: |
| 126 | + features = features + [int(agent_pos == pos) for pos in self.possible_agent_locations] |
| 127 | + return np.array(features) |
| 128 | + |
| 129 | + |
| 130 | + def get_next_states(self, state, action): |
| 131 | + '''returns the next state given a state and an action''' |
| 132 | + action = int(action) |
| 133 | + orientation, x, y = state.agent_pos |
| 134 | + new_orientation, new_x, new_y = state.agent_pos |
| 135 | + new_tree_states = deepcopy(state.tree_states) |
| 136 | + new_bucket_states = deepcopy(state.bucket_states) |
| 137 | + new_carrying_apple = state.carrying_apple |
| 138 | + |
| 139 | + if action == Direction.get_number_from_direction(Direction.STAY): |
| 140 | + pass |
| 141 | + elif action < len(Direction.ALL_DIRECTIONS): |
| 142 | + new_orientation = action |
| 143 | + move_x, move_y = Direction.move_in_direction_number((x, y), action) |
| 144 | + # New position is legal |
| 145 | + if (0 <= move_x < self.width and \ |
| 146 | + 0 <= move_y < self.height and \ |
| 147 | + (move_x, move_y) in self.possible_agent_locations): |
| 148 | + new_x, new_y = move_x, move_y |
| 149 | + else: |
| 150 | + # Move only changes orientation, which we already handled |
| 151 | + pass |
| 152 | + elif action == 5: |
| 153 | + obj_pos = Direction.move_in_direction_number((x, y), orientation) |
| 154 | + if state.carrying_apple: |
| 155 | + # We always drop the apple |
| 156 | + new_carrying_apple = False |
| 157 | + # If we're facing a bucket, it goes there |
| 158 | + if obj_pos in new_bucket_states: |
| 159 | + prev_apples = new_bucket_states[obj_pos] |
| 160 | + new_bucket_states[obj_pos] = min(prev_apples + 1, self.bucket_capacity) |
| 161 | + elif obj_pos in new_tree_states and new_tree_states[obj_pos]: |
| 162 | + new_carrying_apple = True |
| 163 | + new_tree_states[obj_pos] = False |
| 164 | + else: |
| 165 | + # Interact while holding nothing and not facing a tree. |
| 166 | + pass |
| 167 | + else: |
| 168 | + raise ValueError('Invalid action {}'.format(action)) |
| 169 | + |
| 170 | + new_pos = new_orientation, new_x, new_y |
| 171 | + |
| 172 | + def make_state(prob_apples_tuple): |
| 173 | + prob, tree_apples = prob_apples_tuple |
| 174 | + trees = dict(zip(self.tree_locations, tree_apples)) |
| 175 | + s = ApplesState(new_pos, trees, new_bucket_states, new_carrying_apple) |
| 176 | + return (prob, s, 0) |
| 177 | + |
| 178 | + # For apple regeneration, don't regenerate apples that were just picked, |
| 179 | + # so use the apple booleans from the original state |
| 180 | + old_tree_apples = [state.tree_states[loc] for loc in self.tree_locations] |
| 181 | + new_tree_apples = [new_tree_states[loc] for loc in self.tree_locations] |
| 182 | + return list(map(make_state, self.regen_apples(old_tree_apples, new_tree_apples))) |
| 183 | + |
| 184 | + def regen_apples(self, old_tree_apples, new_tree_apples): |
| 185 | + if len(old_tree_apples) == 0: |
| 186 | + yield (1, []) |
| 187 | + return |
| 188 | + for prob, apples in self.regen_apples(old_tree_apples[1:], new_tree_apples[1:]): |
| 189 | + if old_tree_apples[0]: |
| 190 | + yield prob, [new_tree_apples[0]] + apples |
| 191 | + else: |
| 192 | + yield prob * self.apple_regen_probability, [True] + apples |
| 193 | + yield prob * (1 - self.apple_regen_probability), [False] + apples |
| 194 | + |
| 195 | + |
| 196 | + def print_state(self, state): |
| 197 | + '''Renders the state.''' |
| 198 | + h, w = self.height, self.width |
| 199 | + canvas = np.zeros(tuple([2*h-1, 2*w+1]), dtype='int8') |
| 200 | + |
| 201 | + # cell borders |
| 202 | + for y in range(1, canvas.shape[0], 2): |
| 203 | + canvas[y, :] = 1 |
| 204 | + for x in range(0, canvas.shape[1], 2): |
| 205 | + canvas[:, x] = 2 |
| 206 | + |
| 207 | + # trees |
| 208 | + for (x, y), has_apple in state.tree_states.items(): |
| 209 | + canvas[2*y, 2*x+1] = 3 if has_apple else 4 |
| 210 | + |
| 211 | + for x, y in self.bucket_locations: |
| 212 | + canvas[2*y, 2*x+1] = 5 |
| 213 | + |
| 214 | + # agent |
| 215 | + orientation, x, y = state.agent_pos |
| 216 | + canvas[2*y, 2*x+1] = 6 |
| 217 | + |
| 218 | + black_color = '\x1b[0m' |
| 219 | + purple_background_color = '\x1b[0;35;85m' |
| 220 | + |
| 221 | + for line in canvas: |
| 222 | + for char_num in line: |
| 223 | + if char_num==0: |
| 224 | + print('\u2003', end='') |
| 225 | + elif char_num==1: |
| 226 | + print('─', end='') |
| 227 | + elif char_num==2: |
| 228 | + print('│', end='') |
| 229 | + elif char_num==3: |
| 230 | + print('\x1b[0;32;85m█'+black_color , end='') |
| 231 | + elif char_num==4: |
| 232 | + print('\033[91m█'+black_color, end='') |
| 233 | + elif char_num==5: |
| 234 | + print('\033[93m█'+black_color, end='') |
| 235 | + elif char_num==6: |
| 236 | + orientation_char = self.get_orientation_char(orientation) |
| 237 | + agent_color = '\x1b[1;42;42m' if state.carrying_apple else '\x1b[0m' |
| 238 | + print(agent_color+orientation_char+black_color, end='') |
| 239 | + print('') |
| 240 | + |
| 241 | + def get_orientation_char(self, orientation): |
| 242 | + DIRECTION_TO_CHAR = { |
| 243 | + Direction.NORTH: '↑', |
| 244 | + Direction.SOUTH: '↓', |
| 245 | + Direction.WEST: '←', |
| 246 | + Direction.EAST: '→', |
| 247 | + Direction.STAY: '*' |
| 248 | + } |
| 249 | + direction = Direction.get_direction_from_number(orientation) |
| 250 | + return DIRECTION_TO_CHAR[direction] |
0 commit comments