Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hex] Extract logic functions #1288

Merged
19 commits merged on
Dec 2, 2024
6 changes: 5 additions & 1 deletion pgx/_src/dwg/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import jax.numpy as jnp

from pgx.hex import State as HexState
from pgx.hex import _get_abs_board


def _get_abs_board(state):
    """Return the board with a fixed sign convention, independent of whose turn it is.

    Reads ``state._x.board`` and ``state._x.color``. The board is returned
    unchanged when ``state._x.color == 0`` and with every entry negated
    otherwise.
    """
    board = state._x.board
    # NOTE(review): this is a Python-level branch, so ``color`` must be a
    # concrete value here (not a traced JAX value) -- confirm callers never
    # invoke this under jit/vmap.
    return board if state._x.color == 0 else board * -1


r3 = jnp.sqrt(3)

Expand Down
110 changes: 53 additions & 57 deletions pgx/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import partial
from typing import NamedTuple
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp
Expand All @@ -29,7 +29,7 @@
class GameState(NamedTuple):
size: Array = jnp.int32(11)
# 0(black), 1(white)
turn: Array = jnp.int32(0)
step_count: Array = jnp.int32(0)
# 11x11 board
# [[ 0, 1, 2, ..., 8, 9, 10],
# [ 11, 12, 13, ..., 19, 20, 21],
Expand All @@ -39,6 +39,10 @@
# [110, 111, 112, ..., 119, 120]]
board: Array = jnp.zeros(11 * 11, jnp.int32) # <0(oppo), 0(empty), 0<(self)

@property
def color(self) -> Array:
return self.step_count % 2


@dataclass
class State(core.State):
Expand All @@ -63,20 +67,37 @@
self.size = size

def _init(self, key: PRNGKey) -> State:
return partial(_init, size=self.size)(rng=key)
current_player = jnp.int32(jax.random.bernoulli(key))
return State(_x=_init(self.size), current_player=current_player) # type:ignore

def _step(self, state: core.State, action: Array, key) -> State:
del key
assert isinstance(state, State)
return jax.lax.cond(
x = jax.lax.cond(
action != self.size * self.size,
lambda: partial(_step, size=self.size)(state, action),
lambda: partial(_swap, size=self.size)(state),
lambda: partial(_step, size=self.size)(state._x, action),
lambda: partial(_swap, size=self.size)(state._x),
)

terminated = _is_terminal(x, self.size)
reward = jax.lax.cond(
terminated,
lambda: jnp.float32([-1, -1]).at[state.current_player].set(1),
lambda: jnp.zeros(2, jnp.float32),
)

return state.replace( # type:ignore
current_player=1 - state.current_player,
legal_action_mask=state.legal_action_mask.at[:-1].set(state._x.board == 0).at[-1].set(state._step_count == 1),
rewards=reward,
terminated=terminated,
_x=x,
)

def _observe(self, state: core.State, player_id: Array) -> Array:
assert isinstance(state, State)
return partial(_observe, size=self.size)(state, player_id)
color = jax.lax.select(player_id == state.current_player, state._x.color, 1 - state._x.color)
return _observe(state._x, color, self.size)

@property
def id(self) -> core.EnvId:
Expand All @@ -91,14 +112,13 @@
return 2


def _init(rng: PRNGKey, size: int) -> State:
current_player = jnp.int32(jax.random.bernoulli(rng))
return State(_x=GameState(size=size), current_player=current_player) # type:ignore
def _init(size: int) -> GameState:
return GameState(size=size)


def _step(state: State, action: Array, size: int) -> State:
def _step(state: GameState, action: Array, size: int) -> GameState:
set_place_id = action + 1
board = state._x.board.at[action].set(set_place_id)
board = state.board.at[action].set(set_place_id)
neighbour = _neighbour(action, size)

def merge(i, b):
Expand All @@ -110,57 +130,37 @@
)

board = jax.lax.fori_loop(0, 6, merge, board)
won = _is_game_end(board, size, state._x.turn)
reward = jax.lax.cond(
won,
lambda: jnp.float32([-1, -1]).at[state.current_player].set(1),
lambda: jnp.zeros(2, jnp.float32),
return state._replace(
step_count=state.step_count + 1,
board=board * -1,
)

state = state.replace( # type:ignore
current_player=1 - state.current_player,
_x=GameState(
turn=1 - state._x.turn,
board=board * -1,
),
rewards=reward,
terminated=won,
legal_action_mask=state.legal_action_mask.at[:-1].set(board == 0).at[-1].set(state._step_count == 1),
)

return state


def _swap(state: State, size: int) -> State:
ix = jnp.nonzero(state._x.board, size=1)[0]
def _swap(state: GameState, size: int) -> GameState:
ix = jnp.nonzero(state.board, size=1)[0]
row = ix // size
col = ix % size
swapped_ix = col * size + row
set_place_id = swapped_ix + 1
board = state._x.board.at[ix].set(0).at[swapped_ix].set(set_place_id)
return state.replace( # type: ignore
current_player=1 - state.current_player,
_x=GameState(
turn=1 - state._x.turn,
board=board * -1,
),
legal_action_mask=state.legal_action_mask.at[:-1].set(board == 0).at[-1].set(FALSE),
board = state.board.at[ix].set(0).at[swapped_ix].set(set_place_id)
return state._replace(
step_count=state.step_count + 1,
board=board * -1,
)


def _observe(state: State, player_id: Array, size) -> Array:
board = jax.lax.select(
player_id == state.current_player,
state._x.board.reshape((size, size)),
-state._x.board.reshape((size, size)),
)
def _observe(state: GameState, color: Optional[Array] = None, size: int = 11) -> Array:
if color is None:
color = state.color

Check warning on line 154 in pgx/hex.py

View check run for this annotation

Codecov / codecov/patch

pgx/hex.py#L154

Added line #L154 was not covered by tests

board = jax.lax.select(color == state.color, state.board, -state.board)
board = board.reshape((size, size))

my_board = board * 1 > 0
opp_board = board * -1 > 0
ones = jnp.ones_like(my_board)
color = jax.lax.select(player_id == state.current_player, state._x.turn, 1 - state._x.turn)
color = color * ones
can_swap = state.legal_action_mask[-1] * ones
can_swap = (state.step_count == 1) * ones
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line changes the observation in the terminated state:
the can_swap plane changes from True to False.


return jnp.stack([my_board, opp_board, color, can_swap], 2, dtype=jnp.bool_)

Expand All @@ -179,18 +179,14 @@
return jnp.where(on_board, xs * size + ys, -1)


def _is_game_end(board, size, turn):
def _is_terminal(x: GameState, size):
top, bottom = jax.lax.cond(
turn == 0,
lambda: (board[:size], board[-size:]),
lambda: (board[::size], board[size - 1 :: size]),
x.color == 0,
lambda: (x.board[::size], x.board[size - 1 :: size]),
lambda: (x.board[:size], x.board[-size:]),
)

def check_same_id_exist(_id):
return (_id > 0) & (_id == bottom).any()
return (_id < 0) & (_id == bottom).any()

return jax.vmap(check_same_id_exist)(top).any()


def _get_abs_board(state):
return jax.lax.cond(state._x.turn == 0, lambda: state._x.board, lambda: state._x.board * -1)