From 6868516a6e9748b4edccd017189b1fee7301a198 Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Thu, 28 Dec 2023 13:32:58 +0900 Subject: [PATCH] [Go] Refactor naming (#1132) --- pgx/_src/dwg/go.py | 2 +- pgx/_src/games/go.py | 124 ++++++++++++++++++++--------------------- pgx/_src/visualizer.py | 8 +-- pgx/go.py | 18 +++--- tests/test_go.py | 30 +++++----- 5 files changed, 91 insertions(+), 91 deletions(-) diff --git a/pgx/_src/dwg/go.py b/pgx/_src/dwg/go.py index e0b7ba5c7..00c5dbc17 100644 --- a/pgx/_src/dwg/go.py +++ b/pgx/_src/dwg/go.py @@ -86,7 +86,7 @@ def _make_go_dwg(dwg, state: GoState, config): board_g.add(hoshi_g) # stones - board = jnp.clip(state._x._chain_id_board, -1, 1) + board = jnp.clip(state._x.chain_id_board, -1, 1) for xy, stone in enumerate(board): if stone == 0: continue diff --git a/pgx/_src/games/go.py b/pgx/_src/games/go.py index 4f09b4aee..6982084e7 100644 --- a/pgx/_src/games/go.py +++ b/pgx/_src/games/go.py @@ -26,30 +26,30 @@ @dataclass class GameState: - _size: Array = jnp.int32(19) + size: Array = jnp.int32(19) # ids of representative stone id (smallest) in the connected stones # positive for black, negative for white, and zero for empty. - _chain_id_board: Array = jnp.zeros(19 * 19, dtype=jnp.int32) - _board_history: Array = jnp.full((8, 19 * 19), 2, dtype=jnp.int32) - _turn: Array = jnp.int32(0) # 0 = black's turn, 1 = white's turn - _num_captured_stones: Array = jnp.zeros(2, dtype=jnp.int32) # [b, w] - _consecutive_pass_count: Array = jnp.int32(0) - _ko: Array = jnp.int32(-1) # by SSK - _komi: Array = jnp.float32(7.5) - _is_psk: Array = FALSE + chain_id_board: Array = jnp.zeros(19 * 19, dtype=jnp.int32) + board_history: Array = jnp.full((8, 19 * 19), 2, dtype=jnp.int32) + turn: Array = jnp.int32(0) # 0 = black's turn, 1 = white's turn + num_captured_stones: Array = jnp.zeros(2, dtype=jnp.int32) # [b, w] + consecutive_pass_count: Array = jnp.int32(0) + ko: Array = jnp.int32(-1) # by SSK + komi: Array = jnp.float32(7.5) + is_psk: Array = FALSE def init(size: int, komi: float) -> GameState: return GameState( - _size=jnp.int32(size), - _chain_id_board=jnp.zeros(size**2, dtype=jnp.int32), - _board_history=jnp.full((8, size**2), 2, dtype=jnp.int32), - _komi=jnp.float32(komi), + size=jnp.int32(size), + chain_id_board=jnp.zeros(size**2, dtype=jnp.int32), + board_history=jnp.full((8, size**2), 2, dtype=jnp.int32), + komi=jnp.float32(komi), ) def step(x: GameState, action: int, size: int) -> GameState: - x = x.replace(_ko=jnp.int32(-1)) # type: ignore + x = x.replace(ko=jnp.int32(-1)) # type: ignore # update state x = jax.lax.cond( @@ -59,17 +59,17 @@ def step(x: GameState, action: int, size: int) -> GameState: ) # increment turns - x = x.replace(_turn=(x._turn + 1) % 2) # type: ignore + x = x.replace(turn=(x.turn + 1) % 2) # type: ignore # update board history - board_history = jnp.roll(x._board_history, size**2) + board_history = jnp.roll(x.board_history, size**2) board_history = board_history.at[0].set( - jnp.clip(x._chain_id_board, -1, 1).astype(jnp.int32) + jnp.clip(x.chain_id_board, -1, 1).astype(jnp.int32) ) - x = x.replace(_board_history=board_history) # type: ignore + x = x.replace(board_history=board_history) # type: ignore # check PSK - x = x.replace(_is_psk=_check_PSK(x)) # type: ignore + x = x.replace(is_psk=_check_PSK(x)) # type: ignore return x @@ -80,7 +80,7 @@ def observe(x: GameState, my_turn, size, history_length): @jax.vmap def _make(i): color = jnp.int32([1, -1])[i % 2] * my_color - return x._board_history[i // 2] == color + return x.board_history[i // 2] == color log = _make(jnp.arange(history_length * 2)) color = jnp.full_like(log[0], my_turn) # black=0, white=1 @@ -90,18 +90,18 @@ def _make(i): def legal_action_mask(state: GameState, size: int) -> Array: """Logic is highly inspired by OpenSpiel's Go implementation""" - is_empty = state._chain_id_board == 0 + is_empty = state.chain_id_board == 0 my_color = _my_color(state) opp_color = _opponent_color(state) num_pseudo, idx_sum, idx_squared_sum = _count(state, size) - chain_ix = jnp.abs(state._chain_id_board) - 1 + chain_ix = jnp.abs(state.chain_id_board) - 1 # fmt: off in_atari = (idx_sum[chain_ix] ** 2) == idx_squared_sum[chain_ix] * num_pseudo[chain_ix] # fmt: on - has_liberty = (state._chain_id_board * my_color > 0) & ~in_atari - kills_opp = (state._chain_id_board * opp_color > 0) & in_atari + has_liberty = (state.chain_id_board * my_color > 0) & ~in_atari + kills_opp = (state.chain_id_board * opp_color > 0) & in_atari @jax.vmap def is_neighbor_ok(xy): @@ -120,47 +120,47 @@ def is_neighbor_ok(xy): legal_action_mask = is_empty & neighbor_ok legal_action_mask = jax.lax.cond( - (state._ko == -1), + (state.ko == -1), lambda: legal_action_mask, - lambda: legal_action_mask.at[state._ko].set(FALSE), + lambda: legal_action_mask.at[state.ko].set(FALSE), ) return jnp.append(legal_action_mask, TRUE) # pass is always legal def is_terminal(x: GameState): - two_consecutive_pass = x._consecutive_pass_count >= 2 - return two_consecutive_pass | x._is_psk + two_consecutive_pass = x.consecutive_pass_count >= 2 + return two_consecutive_pass | x.is_psk def terminal_values(x: GameState, size: int): score = _count_point(x, size) reward_bw = jax.lax.select( - score[0] - x._komi > score[1], + score[0] - x.komi > score[1], jnp.array([1, -1], dtype=jnp.float32), jnp.array([-1, 1], dtype=jnp.float32), ) - to_play = x._turn + to_play = x.turn reward_bw = jax.lax.select( - x._is_psk, jnp.float32([-1, -1]).at[to_play].set(1.0), reward_bw + x.is_psk, jnp.float32([-1, -1]).at[to_play].set(1.0), reward_bw ) return reward_bw def _pass_move(state: GameState) -> GameState: - return state.replace(_consecutive_pass_count=state._consecutive_pass_count + 1) # type: ignore + return state.replace(consecutive_pass_count=state.consecutive_pass_count + 1) # type: ignore def _not_pass_move(state: GameState, action, size) -> GameState: - state = state.replace(_consecutive_pass_count=0) # type: ignore + state = state.replace(consecutive_pass_count=0) # type: ignore xy = action - num_captured_stones_before = state._num_captured_stones[state._turn] + num_captured_stones_before = state.num_captured_stones[state.turn] ko_may_occur = _ko_may_occur(state, xy) # Remove killed stones adj_xy = _neighbour(xy, size) oppo_color = _opponent_color(state) - chain_id = state._chain_id_board[adj_xy] + chain_id = state.chain_id_board[adj_xy] num_pseudo, idx_sum, idx_squared_sum = _count(state, size) chain_ix = jnp.abs(chain_id) - 1 is_atari = (idx_sum[chain_ix] ** 2) == idx_squared_sum[ @@ -193,9 +193,9 @@ def _not_pass_move(state: GameState, action, size) -> GameState: # Check Ko # fmt: off state = jax.lax.cond( - state._num_captured_stones[state._turn] - num_captured_stones_before == 1, + state.num_captured_stones[state.turn] - num_captured_stones_before == 1, lambda: state, - lambda: state.replace(_ko=jnp.int32(-1)) # type:ignore + lambda: state.replace(ko=jnp.int32(-1)) # type:ignore ) # fmt: on @@ -206,7 +206,7 @@ def _merge_around_xy(i, state: GameState, xy, size): my_color = _my_color(state) adj_xy = _neighbour(xy, size)[i] is_off = adj_xy == -1 - is_my_chain = state._chain_id_board[adj_xy] * my_color > 0 + is_my_chain = state.chain_id_board[adj_xy] * my_color > 0 state = jax.lax.cond( ((~is_off) & is_my_chain), lambda: _merge_chain(state, xy, adj_xy), @@ -218,50 +218,50 @@ def _merge_around_xy(i, state: GameState, xy, size): def _set_stone(state: GameState, xy) -> GameState: my_color = _my_color(state) return state.replace( # type: ignore - _chain_id_board=state._chain_id_board.at[xy].set((xy + 1) * my_color), + chain_id_board=state.chain_id_board.at[xy].set((xy + 1) * my_color), ) def _merge_chain(state: GameState, xy, adj_xy): my_color = _my_color(state) - new_id = jnp.abs(state._chain_id_board[xy]) - adj_chain_id = jnp.abs(state._chain_id_board[adj_xy]) + new_id = jnp.abs(state.chain_id_board[xy]) + adj_chain_id = jnp.abs(state.chain_id_board[adj_xy]) small_id = jnp.minimum(new_id, adj_chain_id) * my_color large_id = jnp.maximum(new_id, adj_chain_id) * my_color # Keep larger chain ID and connect to the chain with smaller ID chain_id_board = jnp.where( - state._chain_id_board == large_id, + state.chain_id_board == large_id, small_id, - state._chain_id_board, + state.chain_id_board, ) - return state.replace(_chain_id_board=chain_id_board) # type: ignore + return state.replace(chain_id_board=chain_id_board) # type: ignore def _remove_stones( state: GameState, rm_chain_id, rm_stone_xy, ko_may_occur ) -> GameState: - surrounded_stones = state._chain_id_board == rm_chain_id + surrounded_stones = state.chain_id_board == rm_chain_id num_captured_stones = jnp.count_nonzero(surrounded_stones) - chain_id_board = jnp.where(surrounded_stones, 0, state._chain_id_board) + chain_id_board = jnp.where(surrounded_stones, 0, state.chain_id_board) ko = jax.lax.cond( ko_may_occur & (num_captured_stones == 1), lambda: jnp.int32(rm_stone_xy), - lambda: state._ko, + lambda: state.ko, ) return state.replace( # type: ignore - _chain_id_board=chain_id_board, - _num_captured_stones=state._num_captured_stones.at[state._turn].add( + chain_id_board=chain_id_board, + num_captured_stones=state.num_captured_stones.at[state.turn].add( num_captured_stones ), - _ko=ko, + ko=ko, ) def _count(state: GameState, size): ZERO = jnp.int32(0) - chain_id_board = jnp.abs(state._chain_id_board) + chain_id_board = jnp.abs(state.chain_id_board) is_empty = chain_id_board == 0 idx_sum = jnp.where(is_empty, jnp.arange(1, size**2 + 1), ZERO) idx_squared_sum = jnp.where( @@ -299,21 +299,21 @@ def _idx_squared_sum(x): def _my_color(state: GameState): - return jnp.int32([1, -1])[state._turn] + return jnp.int32([1, -1])[state.turn] def _opponent_color(state: GameState): - return jnp.int32([-1, 1])[state._turn] + return jnp.int32([-1, 1])[state.turn] def _ko_may_occur(state: GameState, xy: int) -> Array: - size = state._size + size = state.size x = xy // size y = xy % size oob = jnp.bool_([x - 1 < 0, x + 1 >= size, y - 1 < 0, y + 1 >= size]) oppo_color = _opponent_color(state) is_occupied_by_opp = ( - state._chain_id_board[_neighbour(xy, size)] * oppo_color > 0 + state.chain_id_board[_neighbour(xy, size)] * oppo_color > 0 ) return (oob | is_occupied_by_opp).all() @@ -359,8 +359,8 @@ def _check_PSK(state: GameState): Anyway, we believe it's effect is very small as PSK rarely happens, especially in 19x19 board. """ # fmt: off - not_passed = state._consecutive_pass_count == 0 - is_psk = not_passed & (jnp.abs(state._board_history[0] - state._board_history[1:]).sum(axis=1) == 0).any() + not_passed = state.consecutive_pass_count == 0 + is_psk = not_passed & (jnp.abs(state.board_history[0] - state.board_history[1:]).sum(axis=1) == 0).any() # fmt: on return is_psk @@ -369,18 +369,18 @@ def _count_point(state: GameState, size): return jnp.array( [ _count_ji(state, 1, size) - + jnp.count_nonzero(state._chain_id_board > 0), + + jnp.count_nonzero(state.chain_id_board > 0), _count_ji(state, -1, size) - + jnp.count_nonzero(state._chain_id_board < 0), + + jnp.count_nonzero(state.chain_id_board < 0), ], dtype=jnp.float32, ) def _count_ji(state: GameState, color: int, size: int): - board = jnp.zeros_like(state._chain_id_board) - board = jnp.where(state._chain_id_board * color > 0, 1, board) - board = jnp.where(state._chain_id_board * color < 0, -1, board) + board = jnp.zeros_like(state.chain_id_board) + board = jnp.where(state.chain_id_board * color > 0, 1, board) + board = jnp.where(state.chain_id_board * color < 0, -1, board) # 0 = empty, 1 = mine, -1 = opponent's neighbours = _neighbours(size) diff --git a/pgx/_src/visualizer.py b/pgx/_src/visualizer.py index 65090754c..f54ee8d89 100644 --- a/pgx/_src/visualizer.py +++ b/pgx/_src/visualizer.py @@ -372,17 +372,17 @@ def _set_config_by_state(self, _state: State): # noqa: C901 self.config["GRID_SIZE"] = 25 try: self.config["BOARD_WIDTH"] = int( - _state._x._size[0] # type:ignore + _state._x.size[0] # type:ignore ) self.config["BOARD_HEIGHT"] = int( - _state._x._size[0] # type:ignore + _state._x.size[0] # type:ignore ) except IndexError: self.config["BOARD_WIDTH"] = int( - _state._x._size # type: ignore + _state._x.size # type: ignore ) # type:ignore self.config["BOARD_HEIGHT"] = int( - _state._x._size # type: ignore + _state._x.size # type: ignore ) # type:ignore self._make_dwg_group = _make_go_dwg # type:ignore if ( diff --git a/pgx/go.py b/pgx/go.py index 57b06b54c..a6fa24dd7 100644 --- a/pgx/go.py +++ b/pgx/go.py @@ -38,9 +38,9 @@ class State(core.State): @property def env_id(self) -> core.EnvId: try: - size = int(self._x._size.item()) + size = int(self._x.size.item()) except TypeError: - size = int(self._x._size[0].item()) + size = int(self._x.size[0].item()) return f"go_{size}x{size}" # type: ignore @staticmethod @@ -90,7 +90,7 @@ def _step(self, state: core.State, action: Array, key) -> State: # fmt: on assert isinstance(state, State) reward_bw = go.terminal_values(state._x, self.size) - should_flip = state.current_player == state._x._turn + should_flip = state.current_player == state._x.turn rewards = jax.lax.select(should_flip, reward_bw, jnp.flip(reward_bw)) rewards = jax.lax.select( state.terminated, rewards, jnp.zeros_like(rewards) @@ -131,8 +131,8 @@ def _observe(self, state: core.State, player_id: Array) -> Array: assert isinstance(state, State) my_turn = jax.lax.select( player_id == state.current_player, - state._x._turn, - 1 - state._x._turn, + state._x.turn, + 1 - state._x.turn, ) return go.observe(state._x, my_turn, self.size, self.history_length) @@ -155,15 +155,15 @@ def _show(state: State) -> None: WHITE_CHAR = "O" POINT_CHAR = "+" print("===========") - for xy in range(state._x._size * state._x._size): - if state._x._chain_id_board[xy] > 0: + for xy in range(state._x.size * state._x.size): + if state._x.chain_id_board[xy] > 0: print(" " + BLACK_CHAR, end="") - elif state._x._chain_id_board[xy] < 0: + elif state._x.chain_id_board[xy] < 0: print(" " + WHITE_CHAR, end="") else: print(" " + POINT_CHAR, end="") - if xy % state._x._size == state._x._size - 1: + if xy % state._x.size == state._x.size - 1: print() diff --git a/tests/test_go.py b/tests/test_go.py index 7ec7c68fa..d5af10535 100644 --- a/tests/test_go.py +++ b/tests/test_go.py @@ -24,16 +24,16 @@ def test_end_by_pass(): state = init(key=key) state = step(state=state, action=25) - assert state._x._consecutive_pass_count == 1 + assert state._x.consecutive_pass_count == 1 assert not state.terminated state = step(state=state, action=0) - assert state._x._consecutive_pass_count == 0 + assert state._x.consecutive_pass_count == 0 assert not state.terminated state = step(state=state, action=25) - assert state._x._consecutive_pass_count == 1 + assert state._x.consecutive_pass_count == 1 assert not state.terminated state = step(state=state, action=25) - assert state._x._consecutive_pass_count == 2 + assert state._x.consecutive_pass_count == 2 assert state.terminated @@ -85,7 +85,7 @@ def test_step(): [3] O O @ + @ [4] O O O @ + """ - assert (jnp.clip(state._x._chain_id_board, -1, 1) == expected_board.ravel()).all() + assert (jnp.clip(state._x.chain_id_board, -1, 1) == expected_board.ravel()).all() assert state.terminated # 同点なのでコミの分 黒 == player_1 の負け @@ -108,7 +108,7 @@ def test_from_sgf(): [ 1, 0, 1, 1, -1, 0, 0, 0, 0], ] ) # type:ignore - assert (jnp.clip(state._x._chain_id_board, -1, 1) == expected_board.ravel()).all() + assert (jnp.clip(state._x.chain_id_board, -1, 1) == expected_board.ravel()).all() assert state.terminated @@ -121,7 +121,7 @@ def test_from_sgf(): # 初手からの分岐 state = State._from_sgf("(;FF[4]GM[1]CA[UTF-8]AP[besogo:0.0.0-alpha]SZ[9]ST[0](;B[ee])(;B[eg])(;B[ec]))") state.save_svg("tests/assets/go/from_sgf_003.svg") - board = jnp.clip(state._x._chain_id_board, -1, 1) + board = jnp.clip(state._x.chain_id_board, -1, 1) assert board[40] == 1 assert not state.terminated @@ -144,7 +144,7 @@ def test_from_sgf(): # 分岐あり state = State._from_sgf("(;FF[4]GM[1]CA[UTF-8]AP[besogo:0.0.0-alpha]SZ[19]ST[0];B[pd];W[qf];B[nc](;W[rd];B[qc];W[qi])(;W[qd];B[qc];W[rc];B[qe];W[rd];B[pf];W[re];B[pe];W[qg]))") state.save_svg("tests/assets/go/from_sgf_007.svg") - board = jnp.clip(state._x._chain_id_board, -1, 1) + board = jnp.clip(state._x.chain_id_board, -1, 1) assert board[168] == -1 assert board[55] == 0 @@ -184,7 +184,7 @@ def test_ko(): + + O + + + + + + + """ - assert state._x._ko == 12 + assert state._x.ko == 12 loser = state.current_player state1: State = step( @@ -198,7 +198,7 @@ def test_ko(): state2: State = step(state=state, action=0) # BLACK # 回避した場合 assert not state2.terminated - assert state2._x._ko == -1 + assert state2._x.ko == -1 # see #468 state: State = init(key=key) @@ -233,7 +233,7 @@ def test_ko(): state = step(state, action=14) state = step(state, action=23) state = step(state, action=0) - assert state._x._ko == -1 + assert state._x.ko == -1 # see #468 state: State = init(key=key) @@ -265,7 +265,7 @@ def test_ko(): state = step(state, action=25) state = step(state, action=3) state = step(state, action=20) - assert state._x._ko == -1 + assert state._x.ko == -1 # Ko after pass state: State = init(key=key) @@ -307,7 +307,7 @@ def test_ko(): state = step(state, action=13) state = step(state, action=24) state = step(state, action=25) # pass - assert state._x._ko == -1 + assert state._x.ko == -1 # see #479 actions = [107, 11, 56, 41, 300, 19, 228, 231, 344, 257, 35, 32, 57, 276, 0, 277, 164, 15, 187, 179, 357, 255, 150, 211, 256, @@ -333,7 +333,7 @@ def test_ko(): state = env.init(jax.random.PRNGKey(0)) for a in actions: state = env.step(state, a) - assert state._x._ko == -1 + assert state._x.ko == -1 assert state.legal_action_mask[231] def test_observe(): @@ -370,7 +370,7 @@ def test_observe(): ) # fmt: on assert state.current_player == 1 - assert state._x._turn % 2 == 0 # black turn + assert state._x.turn % 2 == 0 # black turn obs = observe(state, 0) # white assert obs.shape == (5, 5, 17) assert (obs[:, :, 0] == (curr_board == -1)).all()