Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] Refactor naming #1132

Merged
merged 11 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pgx/_src/dwg/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _make_go_dwg(dwg, state: GoState, config):
board_g.add(hoshi_g)

# stones
board = jnp.clip(state._x._chain_id_board, -1, 1)
board = jnp.clip(state._x.chain_id_board, -1, 1)
for xy, stone in enumerate(board):
if stone == 0:
continue
Expand Down
124 changes: 62 additions & 62 deletions pgx/_src/games/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,30 @@

@dataclass
class GameState:
_size: Array = jnp.int32(19)
size: Array = jnp.int32(19)
# ids of representative stone id (smallest) in the connected stones
# positive for black, negative for white, and zero for empty.
_chain_id_board: Array = jnp.zeros(19 * 19, dtype=jnp.int32)
_board_history: Array = jnp.full((8, 19 * 19), 2, dtype=jnp.int32)
_turn: Array = jnp.int32(0) # 0 = black's turn, 1 = white's turn
_num_captured_stones: Array = jnp.zeros(2, dtype=jnp.int32) # [b, w]
_consecutive_pass_count: Array = jnp.int32(0)
_ko: Array = jnp.int32(-1) # by SSK
_komi: Array = jnp.float32(7.5)
_is_psk: Array = FALSE
chain_id_board: Array = jnp.zeros(19 * 19, dtype=jnp.int32)
board_history: Array = jnp.full((8, 19 * 19), 2, dtype=jnp.int32)
turn: Array = jnp.int32(0) # 0 = black's turn, 1 = white's turn
num_captured_stones: Array = jnp.zeros(2, dtype=jnp.int32) # [b, w]
consecutive_pass_count: Array = jnp.int32(0)
ko: Array = jnp.int32(-1) # by SSK
komi: Array = jnp.float32(7.5)
is_psk: Array = FALSE


def init(size: int, komi: float) -> GameState:
return GameState(
_size=jnp.int32(size),
_chain_id_board=jnp.zeros(size**2, dtype=jnp.int32),
_board_history=jnp.full((8, size**2), 2, dtype=jnp.int32),
_komi=jnp.float32(komi),
size=jnp.int32(size),
chain_id_board=jnp.zeros(size**2, dtype=jnp.int32),
board_history=jnp.full((8, size**2), 2, dtype=jnp.int32),
komi=jnp.float32(komi),
)


def step(x: GameState, action: int, size: int) -> GameState:
x = x.replace(_ko=jnp.int32(-1)) # type: ignore
x = x.replace(ko=jnp.int32(-1)) # type: ignore

# update state
x = jax.lax.cond(
Expand All @@ -59,17 +59,17 @@ def step(x: GameState, action: int, size: int) -> GameState:
)

# increment turns
x = x.replace(_turn=(x._turn + 1) % 2) # type: ignore
x = x.replace(turn=(x.turn + 1) % 2) # type: ignore

# update board history
board_history = jnp.roll(x._board_history, size**2)
board_history = jnp.roll(x.board_history, size**2)
board_history = board_history.at[0].set(
jnp.clip(x._chain_id_board, -1, 1).astype(jnp.int32)
jnp.clip(x.chain_id_board, -1, 1).astype(jnp.int32)
)
x = x.replace(_board_history=board_history) # type: ignore
x = x.replace(board_history=board_history) # type: ignore

# check PSK
x = x.replace(_is_psk=_check_PSK(x)) # type: ignore
x = x.replace(is_psk=_check_PSK(x)) # type: ignore

return x

Expand All @@ -80,7 +80,7 @@ def observe(x: GameState, my_turn, size, history_length):
@jax.vmap
def _make(i):
color = jnp.int32([1, -1])[i % 2] * my_color
return x._board_history[i // 2] == color
return x.board_history[i // 2] == color

log = _make(jnp.arange(history_length * 2))
color = jnp.full_like(log[0], my_turn) # black=0, white=1
Expand All @@ -90,18 +90,18 @@ def _make(i):

def legal_action_mask(state: GameState, size: int) -> Array:
"""Logic is highly inspired by OpenSpiel's Go implementation"""
is_empty = state._chain_id_board == 0
is_empty = state.chain_id_board == 0

my_color = _my_color(state)
opp_color = _opponent_color(state)
num_pseudo, idx_sum, idx_squared_sum = _count(state, size)

chain_ix = jnp.abs(state._chain_id_board) - 1
chain_ix = jnp.abs(state.chain_id_board) - 1
# fmt: off
in_atari = (idx_sum[chain_ix] ** 2) == idx_squared_sum[chain_ix] * num_pseudo[chain_ix]
# fmt: on
has_liberty = (state._chain_id_board * my_color > 0) & ~in_atari
kills_opp = (state._chain_id_board * opp_color > 0) & in_atari
has_liberty = (state.chain_id_board * my_color > 0) & ~in_atari
kills_opp = (state.chain_id_board * opp_color > 0) & in_atari

@jax.vmap
def is_neighbor_ok(xy):
Expand All @@ -120,47 +120,47 @@ def is_neighbor_ok(xy):
legal_action_mask = is_empty & neighbor_ok

legal_action_mask = jax.lax.cond(
(state._ko == -1),
(state.ko == -1),
lambda: legal_action_mask,
lambda: legal_action_mask.at[state._ko].set(FALSE),
lambda: legal_action_mask.at[state.ko].set(FALSE),
)
return jnp.append(legal_action_mask, TRUE) # pass is always legal


def is_terminal(x: GameState):
two_consecutive_pass = x._consecutive_pass_count >= 2
return two_consecutive_pass | x._is_psk
two_consecutive_pass = x.consecutive_pass_count >= 2
return two_consecutive_pass | x.is_psk


def terminal_values(x: GameState, size: int):
score = _count_point(x, size)
reward_bw = jax.lax.select(
score[0] - x._komi > score[1],
score[0] - x.komi > score[1],
jnp.array([1, -1], dtype=jnp.float32),
jnp.array([-1, 1], dtype=jnp.float32),
)
to_play = x._turn
to_play = x.turn
reward_bw = jax.lax.select(
x._is_psk, jnp.float32([-1, -1]).at[to_play].set(1.0), reward_bw
x.is_psk, jnp.float32([-1, -1]).at[to_play].set(1.0), reward_bw
)
return reward_bw


def _pass_move(state: GameState) -> GameState:
return state.replace(_consecutive_pass_count=state._consecutive_pass_count + 1) # type: ignore
return state.replace(consecutive_pass_count=state.consecutive_pass_count + 1) # type: ignore


def _not_pass_move(state: GameState, action, size) -> GameState:
state = state.replace(_consecutive_pass_count=0) # type: ignore
state = state.replace(consecutive_pass_count=0) # type: ignore
xy = action
num_captured_stones_before = state._num_captured_stones[state._turn]
num_captured_stones_before = state.num_captured_stones[state.turn]

ko_may_occur = _ko_may_occur(state, xy)

# Remove killed stones
adj_xy = _neighbour(xy, size)
oppo_color = _opponent_color(state)
chain_id = state._chain_id_board[adj_xy]
chain_id = state.chain_id_board[adj_xy]
num_pseudo, idx_sum, idx_squared_sum = _count(state, size)
chain_ix = jnp.abs(chain_id) - 1
is_atari = (idx_sum[chain_ix] ** 2) == idx_squared_sum[
Expand Down Expand Up @@ -193,9 +193,9 @@ def _not_pass_move(state: GameState, action, size) -> GameState:
# Check Ko
# fmt: off
state = jax.lax.cond(
state._num_captured_stones[state._turn] - num_captured_stones_before == 1,
state.num_captured_stones[state.turn] - num_captured_stones_before == 1,
lambda: state,
lambda: state.replace(_ko=jnp.int32(-1)) # type:ignore
lambda: state.replace(ko=jnp.int32(-1)) # type:ignore
)
# fmt: on

Expand All @@ -206,7 +206,7 @@ def _merge_around_xy(i, state: GameState, xy, size):
my_color = _my_color(state)
adj_xy = _neighbour(xy, size)[i]
is_off = adj_xy == -1
is_my_chain = state._chain_id_board[adj_xy] * my_color > 0
is_my_chain = state.chain_id_board[adj_xy] * my_color > 0
state = jax.lax.cond(
((~is_off) & is_my_chain),
lambda: _merge_chain(state, xy, adj_xy),
Expand All @@ -218,50 +218,50 @@ def _merge_around_xy(i, state: GameState, xy, size):
def _set_stone(state: GameState, xy) -> GameState:
my_color = _my_color(state)
return state.replace( # type: ignore
_chain_id_board=state._chain_id_board.at[xy].set((xy + 1) * my_color),
chain_id_board=state.chain_id_board.at[xy].set((xy + 1) * my_color),
)


def _merge_chain(state: GameState, xy, adj_xy):
my_color = _my_color(state)
new_id = jnp.abs(state._chain_id_board[xy])
adj_chain_id = jnp.abs(state._chain_id_board[adj_xy])
new_id = jnp.abs(state.chain_id_board[xy])
adj_chain_id = jnp.abs(state.chain_id_board[adj_xy])
small_id = jnp.minimum(new_id, adj_chain_id) * my_color
large_id = jnp.maximum(new_id, adj_chain_id) * my_color

# Keep larger chain ID and connect to the chain with smaller ID
chain_id_board = jnp.where(
state._chain_id_board == large_id,
state.chain_id_board == large_id,
small_id,
state._chain_id_board,
state.chain_id_board,
)

return state.replace(_chain_id_board=chain_id_board) # type: ignore
return state.replace(chain_id_board=chain_id_board) # type: ignore


def _remove_stones(
state: GameState, rm_chain_id, rm_stone_xy, ko_may_occur
) -> GameState:
surrounded_stones = state._chain_id_board == rm_chain_id
surrounded_stones = state.chain_id_board == rm_chain_id
num_captured_stones = jnp.count_nonzero(surrounded_stones)
chain_id_board = jnp.where(surrounded_stones, 0, state._chain_id_board)
chain_id_board = jnp.where(surrounded_stones, 0, state.chain_id_board)
ko = jax.lax.cond(
ko_may_occur & (num_captured_stones == 1),
lambda: jnp.int32(rm_stone_xy),
lambda: state._ko,
lambda: state.ko,
)
return state.replace( # type: ignore
_chain_id_board=chain_id_board,
_num_captured_stones=state._num_captured_stones.at[state._turn].add(
chain_id_board=chain_id_board,
num_captured_stones=state.num_captured_stones.at[state.turn].add(
num_captured_stones
),
_ko=ko,
ko=ko,
)


def _count(state: GameState, size):
ZERO = jnp.int32(0)
chain_id_board = jnp.abs(state._chain_id_board)
chain_id_board = jnp.abs(state.chain_id_board)
is_empty = chain_id_board == 0
idx_sum = jnp.where(is_empty, jnp.arange(1, size**2 + 1), ZERO)
idx_squared_sum = jnp.where(
Expand Down Expand Up @@ -299,21 +299,21 @@ def _idx_squared_sum(x):


def _my_color(state: GameState):
return jnp.int32([1, -1])[state._turn]
return jnp.int32([1, -1])[state.turn]


def _opponent_color(state: GameState):
return jnp.int32([-1, 1])[state._turn]
return jnp.int32([-1, 1])[state.turn]


def _ko_may_occur(state: GameState, xy: int) -> Array:
size = state._size
size = state.size
x = xy // size
y = xy % size
oob = jnp.bool_([x - 1 < 0, x + 1 >= size, y - 1 < 0, y + 1 >= size])
oppo_color = _opponent_color(state)
is_occupied_by_opp = (
state._chain_id_board[_neighbour(xy, size)] * oppo_color > 0
state.chain_id_board[_neighbour(xy, size)] * oppo_color > 0
)
return (oob | is_occupied_by_opp).all()

Expand Down Expand Up @@ -359,8 +359,8 @@ def _check_PSK(state: GameState):
Anyway, we believe it's effect is very small as PSK rarely happens, especially in 19x19 board.
"""
# fmt: off
not_passed = state._consecutive_pass_count == 0
is_psk = not_passed & (jnp.abs(state._board_history[0] - state._board_history[1:]).sum(axis=1) == 0).any()
not_passed = state.consecutive_pass_count == 0
is_psk = not_passed & (jnp.abs(state.board_history[0] - state.board_history[1:]).sum(axis=1) == 0).any()
# fmt: on
return is_psk

Expand All @@ -369,18 +369,18 @@ def _count_point(state: GameState, size):
return jnp.array(
[
_count_ji(state, 1, size)
+ jnp.count_nonzero(state._chain_id_board > 0),
+ jnp.count_nonzero(state.chain_id_board > 0),
_count_ji(state, -1, size)
+ jnp.count_nonzero(state._chain_id_board < 0),
+ jnp.count_nonzero(state.chain_id_board < 0),
],
dtype=jnp.float32,
)


def _count_ji(state: GameState, color: int, size: int):
board = jnp.zeros_like(state._chain_id_board)
board = jnp.where(state._chain_id_board * color > 0, 1, board)
board = jnp.where(state._chain_id_board * color < 0, -1, board)
board = jnp.zeros_like(state.chain_id_board)
board = jnp.where(state.chain_id_board * color > 0, 1, board)
board = jnp.where(state.chain_id_board * color < 0, -1, board)
# 0 = empty, 1 = mine, -1 = opponent's

neighbours = _neighbours(size)
Expand Down
8 changes: 4 additions & 4 deletions pgx/_src/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,17 +372,17 @@ def _set_config_by_state(self, _state: State): # noqa: C901
self.config["GRID_SIZE"] = 25
try:
self.config["BOARD_WIDTH"] = int(
_state._x._size[0] # type:ignore
_state._x.size[0] # type:ignore
)
self.config["BOARD_HEIGHT"] = int(
_state._x._size[0] # type:ignore
_state._x.size[0] # type:ignore
)
except IndexError:
self.config["BOARD_WIDTH"] = int(
_state._x._size # type: ignore
_state._x.size # type: ignore
) # type:ignore
self.config["BOARD_HEIGHT"] = int(
_state._x._size # type: ignore
_state._x.size # type: ignore
) # type:ignore
self._make_dwg_group = _make_go_dwg # type:ignore
if (
Expand Down
18 changes: 9 additions & 9 deletions pgx/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@
@property
def env_id(self) -> core.EnvId:
try:
size = int(self._x._size.item())
size = int(self._x.size.item())
except TypeError:
size = int(self._x._size[0].item())
size = int(self._x.size[0].item())
return f"go_{size}x{size}" # type: ignore

@staticmethod
Expand Down Expand Up @@ -90,7 +90,7 @@
# fmt: on
assert isinstance(state, State)
reward_bw = go.terminal_values(state._x, self.size)
should_flip = state.current_player == state._x._turn
should_flip = state.current_player == state._x.turn
rewards = jax.lax.select(should_flip, reward_bw, jnp.flip(reward_bw))
rewards = jax.lax.select(
state.terminated, rewards, jnp.zeros_like(rewards)
Expand Down Expand Up @@ -131,8 +131,8 @@
assert isinstance(state, State)
my_turn = jax.lax.select(
player_id == state.current_player,
state._x._turn,
1 - state._x._turn,
state._x.turn,
1 - state._x.turn,
)
return go.observe(state._x, my_turn, self.size, self.history_length)

Expand All @@ -155,15 +155,15 @@
WHITE_CHAR = "O"
POINT_CHAR = "+"
print("===========")
for xy in range(state._x._size * state._x._size):
if state._x._chain_id_board[xy] > 0:
for xy in range(state._x.size * state._x.size):
if state._x.chain_id_board[xy] > 0:

Check warning on line 159 in pgx/go.py

View check run for this annotation

Codecov / codecov/patch

pgx/go.py#L158-L159

Added lines #L158 - L159 were not covered by tests
print(" " + BLACK_CHAR, end="")
elif state._x._chain_id_board[xy] < 0:
elif state._x.chain_id_board[xy] < 0:

Check warning on line 161 in pgx/go.py

View check run for this annotation

Codecov / codecov/patch

pgx/go.py#L161

Added line #L161 was not covered by tests
print(" " + WHITE_CHAR, end="")
else:
print(" " + POINT_CHAR, end="")

if xy % state._x._size == state._x._size - 1:
if xy % state._x.size == state._x.size - 1:

Check warning on line 166 in pgx/go.py

View check run for this annotation

Codecov / codecov/patch

pgx/go.py#L166

Added line #L166 was not covered by tests
print()


Expand Down
Loading
Loading