Skip to content

Commit

Permalink
[Go] Refactor naming (#1132)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Dec 28, 2023
1 parent fe993ff commit 6868516
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 91 deletions.
2 changes: 1 addition & 1 deletion pgx/_src/dwg/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _make_go_dwg(dwg, state: GoState, config):
board_g.add(hoshi_g)

# stones
board = jnp.clip(state._x._chain_id_board, -1, 1)
board = jnp.clip(state._x.chain_id_board, -1, 1)
for xy, stone in enumerate(board):
if stone == 0:
continue
Expand Down
124 changes: 62 additions & 62 deletions pgx/_src/games/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,30 @@

@dataclass
class GameState:
_size: Array = jnp.int32(19)
size: Array = jnp.int32(19)
# id of the representative stone (the smallest id) in each chain of connected stones
# positive for black, negative for white, and zero for empty.
_chain_id_board: Array = jnp.zeros(19 * 19, dtype=jnp.int32)
_board_history: Array = jnp.full((8, 19 * 19), 2, dtype=jnp.int32)
_turn: Array = jnp.int32(0) # 0 = black's turn, 1 = white's turn
_num_captured_stones: Array = jnp.zeros(2, dtype=jnp.int32) # [b, w]
_consecutive_pass_count: Array = jnp.int32(0)
_ko: Array = jnp.int32(-1) # by SSK
_komi: Array = jnp.float32(7.5)
_is_psk: Array = FALSE
chain_id_board: Array = jnp.zeros(19 * 19, dtype=jnp.int32)
board_history: Array = jnp.full((8, 19 * 19), 2, dtype=jnp.int32)
turn: Array = jnp.int32(0) # 0 = black's turn, 1 = white's turn
num_captured_stones: Array = jnp.zeros(2, dtype=jnp.int32) # [b, w]
consecutive_pass_count: Array = jnp.int32(0)
ko: Array = jnp.int32(-1) # by SSK
komi: Array = jnp.float32(7.5)
is_psk: Array = FALSE


def init(size: int, komi: float) -> GameState:
return GameState(
_size=jnp.int32(size),
_chain_id_board=jnp.zeros(size**2, dtype=jnp.int32),
_board_history=jnp.full((8, size**2), 2, dtype=jnp.int32),
_komi=jnp.float32(komi),
size=jnp.int32(size),
chain_id_board=jnp.zeros(size**2, dtype=jnp.int32),
board_history=jnp.full((8, size**2), 2, dtype=jnp.int32),
komi=jnp.float32(komi),
)


def step(x: GameState, action: int, size: int) -> GameState:
x = x.replace(_ko=jnp.int32(-1)) # type: ignore
x = x.replace(ko=jnp.int32(-1)) # type: ignore

# update state
x = jax.lax.cond(
Expand All @@ -59,17 +59,17 @@ def step(x: GameState, action: int, size: int) -> GameState:
)

# increment turns
x = x.replace(_turn=(x._turn + 1) % 2) # type: ignore
x = x.replace(turn=(x.turn + 1) % 2) # type: ignore

# update board history
board_history = jnp.roll(x._board_history, size**2)
board_history = jnp.roll(x.board_history, size**2)
board_history = board_history.at[0].set(
jnp.clip(x._chain_id_board, -1, 1).astype(jnp.int32)
jnp.clip(x.chain_id_board, -1, 1).astype(jnp.int32)
)
x = x.replace(_board_history=board_history) # type: ignore
x = x.replace(board_history=board_history) # type: ignore

# check PSK
x = x.replace(_is_psk=_check_PSK(x)) # type: ignore
x = x.replace(is_psk=_check_PSK(x)) # type: ignore

return x

Expand All @@ -80,7 +80,7 @@ def observe(x: GameState, my_turn, size, history_length):
@jax.vmap
def _make(i):
color = jnp.int32([1, -1])[i % 2] * my_color
return x._board_history[i // 2] == color
return x.board_history[i // 2] == color

log = _make(jnp.arange(history_length * 2))
color = jnp.full_like(log[0], my_turn) # black=0, white=1
Expand All @@ -90,18 +90,18 @@ def _make(i):

def legal_action_mask(state: GameState, size: int) -> Array:
"""Logic is highly inspired by OpenSpiel's Go implementation"""
is_empty = state._chain_id_board == 0
is_empty = state.chain_id_board == 0

my_color = _my_color(state)
opp_color = _opponent_color(state)
num_pseudo, idx_sum, idx_squared_sum = _count(state, size)

chain_ix = jnp.abs(state._chain_id_board) - 1
chain_ix = jnp.abs(state.chain_id_board) - 1
# fmt: off
in_atari = (idx_sum[chain_ix] ** 2) == idx_squared_sum[chain_ix] * num_pseudo[chain_ix]
# fmt: on
has_liberty = (state._chain_id_board * my_color > 0) & ~in_atari
kills_opp = (state._chain_id_board * opp_color > 0) & in_atari
has_liberty = (state.chain_id_board * my_color > 0) & ~in_atari
kills_opp = (state.chain_id_board * opp_color > 0) & in_atari

@jax.vmap
def is_neighbor_ok(xy):
Expand All @@ -120,47 +120,47 @@ def is_neighbor_ok(xy):
legal_action_mask = is_empty & neighbor_ok

legal_action_mask = jax.lax.cond(
(state._ko == -1),
(state.ko == -1),
lambda: legal_action_mask,
lambda: legal_action_mask.at[state._ko].set(FALSE),
lambda: legal_action_mask.at[state.ko].set(FALSE),
)
return jnp.append(legal_action_mask, TRUE) # pass is always legal


def is_terminal(x: GameState):
two_consecutive_pass = x._consecutive_pass_count >= 2
return two_consecutive_pass | x._is_psk
two_consecutive_pass = x.consecutive_pass_count >= 2
return two_consecutive_pass | x.is_psk


def terminal_values(x: GameState, size: int):
score = _count_point(x, size)
reward_bw = jax.lax.select(
score[0] - x._komi > score[1],
score[0] - x.komi > score[1],
jnp.array([1, -1], dtype=jnp.float32),
jnp.array([-1, 1], dtype=jnp.float32),
)
to_play = x._turn
to_play = x.turn
reward_bw = jax.lax.select(
x._is_psk, jnp.float32([-1, -1]).at[to_play].set(1.0), reward_bw
x.is_psk, jnp.float32([-1, -1]).at[to_play].set(1.0), reward_bw
)
return reward_bw


def _pass_move(state: GameState) -> GameState:
return state.replace(_consecutive_pass_count=state._consecutive_pass_count + 1) # type: ignore
return state.replace(consecutive_pass_count=state.consecutive_pass_count + 1) # type: ignore


def _not_pass_move(state: GameState, action, size) -> GameState:
state = state.replace(_consecutive_pass_count=0) # type: ignore
state = state.replace(consecutive_pass_count=0) # type: ignore
xy = action
num_captured_stones_before = state._num_captured_stones[state._turn]
num_captured_stones_before = state.num_captured_stones[state.turn]

ko_may_occur = _ko_may_occur(state, xy)

# Remove killed stones
adj_xy = _neighbour(xy, size)
oppo_color = _opponent_color(state)
chain_id = state._chain_id_board[adj_xy]
chain_id = state.chain_id_board[adj_xy]
num_pseudo, idx_sum, idx_squared_sum = _count(state, size)
chain_ix = jnp.abs(chain_id) - 1
is_atari = (idx_sum[chain_ix] ** 2) == idx_squared_sum[
Expand Down Expand Up @@ -193,9 +193,9 @@ def _not_pass_move(state: GameState, action, size) -> GameState:
# Check Ko
# fmt: off
state = jax.lax.cond(
state._num_captured_stones[state._turn] - num_captured_stones_before == 1,
state.num_captured_stones[state.turn] - num_captured_stones_before == 1,
lambda: state,
lambda: state.replace(_ko=jnp.int32(-1)) # type:ignore
lambda: state.replace(ko=jnp.int32(-1)) # type:ignore
)
# fmt: on

Expand All @@ -206,7 +206,7 @@ def _merge_around_xy(i, state: GameState, xy, size):
my_color = _my_color(state)
adj_xy = _neighbour(xy, size)[i]
is_off = adj_xy == -1
is_my_chain = state._chain_id_board[adj_xy] * my_color > 0
is_my_chain = state.chain_id_board[adj_xy] * my_color > 0
state = jax.lax.cond(
((~is_off) & is_my_chain),
lambda: _merge_chain(state, xy, adj_xy),
Expand All @@ -218,50 +218,50 @@ def _merge_around_xy(i, state: GameState, xy, size):
def _set_stone(state: GameState, xy) -> GameState:
my_color = _my_color(state)
return state.replace( # type: ignore
_chain_id_board=state._chain_id_board.at[xy].set((xy + 1) * my_color),
chain_id_board=state.chain_id_board.at[xy].set((xy + 1) * my_color),
)


def _merge_chain(state: GameState, xy, adj_xy):
my_color = _my_color(state)
new_id = jnp.abs(state._chain_id_board[xy])
adj_chain_id = jnp.abs(state._chain_id_board[adj_xy])
new_id = jnp.abs(state.chain_id_board[xy])
adj_chain_id = jnp.abs(state.chain_id_board[adj_xy])
small_id = jnp.minimum(new_id, adj_chain_id) * my_color
large_id = jnp.maximum(new_id, adj_chain_id) * my_color

# Relabel the chain with the larger ID so that it joins the chain with the smaller ID
chain_id_board = jnp.where(
state._chain_id_board == large_id,
state.chain_id_board == large_id,
small_id,
state._chain_id_board,
state.chain_id_board,
)

return state.replace(_chain_id_board=chain_id_board) # type: ignore
return state.replace(chain_id_board=chain_id_board) # type: ignore


def _remove_stones(
state: GameState, rm_chain_id, rm_stone_xy, ko_may_occur
) -> GameState:
surrounded_stones = state._chain_id_board == rm_chain_id
surrounded_stones = state.chain_id_board == rm_chain_id
num_captured_stones = jnp.count_nonzero(surrounded_stones)
chain_id_board = jnp.where(surrounded_stones, 0, state._chain_id_board)
chain_id_board = jnp.where(surrounded_stones, 0, state.chain_id_board)
ko = jax.lax.cond(
ko_may_occur & (num_captured_stones == 1),
lambda: jnp.int32(rm_stone_xy),
lambda: state._ko,
lambda: state.ko,
)
return state.replace( # type: ignore
_chain_id_board=chain_id_board,
_num_captured_stones=state._num_captured_stones.at[state._turn].add(
chain_id_board=chain_id_board,
num_captured_stones=state.num_captured_stones.at[state.turn].add(
num_captured_stones
),
_ko=ko,
ko=ko,
)


def _count(state: GameState, size):
ZERO = jnp.int32(0)
chain_id_board = jnp.abs(state._chain_id_board)
chain_id_board = jnp.abs(state.chain_id_board)
is_empty = chain_id_board == 0
idx_sum = jnp.where(is_empty, jnp.arange(1, size**2 + 1), ZERO)
idx_squared_sum = jnp.where(
Expand Down Expand Up @@ -299,21 +299,21 @@ def _idx_squared_sum(x):


def _my_color(state: GameState):
return jnp.int32([1, -1])[state._turn]
return jnp.int32([1, -1])[state.turn]


def _opponent_color(state: GameState):
return jnp.int32([-1, 1])[state._turn]
return jnp.int32([-1, 1])[state.turn]


def _ko_may_occur(state: GameState, xy: int) -> Array:
size = state._size
size = state.size
x = xy // size
y = xy % size
oob = jnp.bool_([x - 1 < 0, x + 1 >= size, y - 1 < 0, y + 1 >= size])
oppo_color = _opponent_color(state)
is_occupied_by_opp = (
state._chain_id_board[_neighbour(xy, size)] * oppo_color > 0
state.chain_id_board[_neighbour(xy, size)] * oppo_color > 0
)
return (oob | is_occupied_by_opp).all()

Expand Down Expand Up @@ -359,8 +359,8 @@ def _check_PSK(state: GameState):
Anyway, we believe its effect is very small as PSK rarely happens, especially on a 19x19 board.
"""
# fmt: off
not_passed = state._consecutive_pass_count == 0
is_psk = not_passed & (jnp.abs(state._board_history[0] - state._board_history[1:]).sum(axis=1) == 0).any()
not_passed = state.consecutive_pass_count == 0
is_psk = not_passed & (jnp.abs(state.board_history[0] - state.board_history[1:]).sum(axis=1) == 0).any()
# fmt: on
return is_psk

Expand All @@ -369,18 +369,18 @@ def _count_point(state: GameState, size):
return jnp.array(
[
_count_ji(state, 1, size)
+ jnp.count_nonzero(state._chain_id_board > 0),
+ jnp.count_nonzero(state.chain_id_board > 0),
_count_ji(state, -1, size)
+ jnp.count_nonzero(state._chain_id_board < 0),
+ jnp.count_nonzero(state.chain_id_board < 0),
],
dtype=jnp.float32,
)


def _count_ji(state: GameState, color: int, size: int):
board = jnp.zeros_like(state._chain_id_board)
board = jnp.where(state._chain_id_board * color > 0, 1, board)
board = jnp.where(state._chain_id_board * color < 0, -1, board)
board = jnp.zeros_like(state.chain_id_board)
board = jnp.where(state.chain_id_board * color > 0, 1, board)
board = jnp.where(state.chain_id_board * color < 0, -1, board)
# 0 = empty, 1 = mine, -1 = opponent's

neighbours = _neighbours(size)
Expand Down
8 changes: 4 additions & 4 deletions pgx/_src/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,17 +372,17 @@ def _set_config_by_state(self, _state: State): # noqa: C901
self.config["GRID_SIZE"] = 25
try:
self.config["BOARD_WIDTH"] = int(
_state._x._size[0] # type:ignore
_state._x.size[0] # type:ignore
)
self.config["BOARD_HEIGHT"] = int(
_state._x._size[0] # type:ignore
_state._x.size[0] # type:ignore
)
except IndexError:
self.config["BOARD_WIDTH"] = int(
_state._x._size # type: ignore
_state._x.size # type: ignore
) # type:ignore
self.config["BOARD_HEIGHT"] = int(
_state._x._size # type: ignore
_state._x.size # type: ignore
) # type:ignore
self._make_dwg_group = _make_go_dwg # type:ignore
if (
Expand Down
18 changes: 9 additions & 9 deletions pgx/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ class State(core.State):
@property
def env_id(self) -> core.EnvId:
try:
size = int(self._x._size.item())
size = int(self._x.size.item())
except TypeError:
size = int(self._x._size[0].item())
size = int(self._x.size[0].item())
return f"go_{size}x{size}" # type: ignore

@staticmethod
Expand Down Expand Up @@ -90,7 +90,7 @@ def _step(self, state: core.State, action: Array, key) -> State:
# fmt: on
assert isinstance(state, State)
reward_bw = go.terminal_values(state._x, self.size)
should_flip = state.current_player == state._x._turn
should_flip = state.current_player == state._x.turn
rewards = jax.lax.select(should_flip, reward_bw, jnp.flip(reward_bw))
rewards = jax.lax.select(
state.terminated, rewards, jnp.zeros_like(rewards)
Expand Down Expand Up @@ -131,8 +131,8 @@ def _observe(self, state: core.State, player_id: Array) -> Array:
assert isinstance(state, State)
my_turn = jax.lax.select(
player_id == state.current_player,
state._x._turn,
1 - state._x._turn,
state._x.turn,
1 - state._x.turn,
)
return go.observe(state._x, my_turn, self.size, self.history_length)

Expand All @@ -155,15 +155,15 @@ def _show(state: State) -> None:
WHITE_CHAR = "O"
POINT_CHAR = "+"
print("===========")
for xy in range(state._x._size * state._x._size):
if state._x._chain_id_board[xy] > 0:
for xy in range(state._x.size * state._x.size):
if state._x.chain_id_board[xy] > 0:
print(" " + BLACK_CHAR, end="")
elif state._x._chain_id_board[xy] < 0:
elif state._x.chain_id_board[xy] < 0:
print(" " + WHITE_CHAR, end="")
else:
print(" " + POINT_CHAR, end="")

if xy % state._x._size == state._x._size - 1:
if xy % state._x.size == state._x.size - 1:
print()


Expand Down
Loading

0 comments on commit 6868516

Please sign in to comment.