Skip to content

Commit

Permalink
[Hex] Extract logic functions (#1288)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Dec 2, 2024
1 parent a5c90f1 commit c343851
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 58 deletions.
6 changes: 5 additions & 1 deletion pgx/_src/dwg/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import jax.numpy as jnp

from pgx.hex import State as HexState
from pgx.hex import _get_abs_board


def _get_abs_board(state):
return state._x.board if state._x.color == 0 else state._x.board * -1


r3 = jnp.sqrt(3)

Expand Down
110 changes: 53 additions & 57 deletions pgx/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import partial
from typing import NamedTuple
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp
Expand All @@ -29,7 +29,7 @@
class GameState(NamedTuple):
size: Array = jnp.int32(11)
# 0(black), 1(white)
turn: Array = jnp.int32(0)
step_count: Array = jnp.int32(0)
# 11x11 board
# [[ 0, 1, 2, ..., 8, 9, 10],
# [ 11, 12, 13, ..., 19, 20, 21],
Expand All @@ -39,6 +39,10 @@ class GameState(NamedTuple):
# [110, 111, 112, ..., 119, 120]]
board: Array = jnp.zeros(11 * 11, jnp.int32) # <0(oppo), 0(empty), 0<(self)

@property
def color(self) -> Array:
return self.step_count % 2


@dataclass
class State(core.State):
Expand All @@ -63,20 +67,37 @@ def __init__(self, *, size: int = 11):
self.size = size

def _init(self, key: PRNGKey) -> State:
return partial(_init, size=self.size)(rng=key)
current_player = jnp.int32(jax.random.bernoulli(key))
return State(_x=_init(self.size), current_player=current_player) # type:ignore

def _step(self, state: core.State, action: Array, key) -> State:
del key
assert isinstance(state, State)
return jax.lax.cond(
x = jax.lax.cond(
action != self.size * self.size,
lambda: partial(_step, size=self.size)(state, action),
lambda: partial(_swap, size=self.size)(state),
lambda: partial(_step, size=self.size)(state._x, action),
lambda: partial(_swap, size=self.size)(state._x),
)

terminated = _is_terminal(x, self.size)
reward = jax.lax.cond(
terminated,
lambda: jnp.float32([-1, -1]).at[state.current_player].set(1),
lambda: jnp.zeros(2, jnp.float32),
)

return state.replace( # type:ignore
current_player=1 - state.current_player,
legal_action_mask=state.legal_action_mask.at[:-1].set(state._x.board == 0).at[-1].set(state._step_count == 1),
rewards=reward,
terminated=terminated,
_x=x,
)

def _observe(self, state: core.State, player_id: Array) -> Array:
assert isinstance(state, State)
return partial(_observe, size=self.size)(state, player_id)
color = jax.lax.select(player_id == state.current_player, state._x.color, 1 - state._x.color)
return _observe(state._x, color, self.size)

@property
def id(self) -> core.EnvId:
Expand All @@ -91,14 +112,13 @@ def num_players(self) -> int:
return 2


def _init(rng: PRNGKey, size: int) -> State:
current_player = jnp.int32(jax.random.bernoulli(rng))
return State(_x=GameState(size=size), current_player=current_player) # type:ignore
def _init(size: int) -> GameState:
return GameState(size=size)


def _step(state: State, action: Array, size: int) -> State:
def _step(state: GameState, action: Array, size: int) -> GameState:
set_place_id = action + 1
board = state._x.board.at[action].set(set_place_id)
board = state.board.at[action].set(set_place_id)
neighbour = _neighbour(action, size)

def merge(i, b):
Expand All @@ -110,57 +130,37 @@ def merge(i, b):
)

board = jax.lax.fori_loop(0, 6, merge, board)
won = _is_game_end(board, size, state._x.turn)
reward = jax.lax.cond(
won,
lambda: jnp.float32([-1, -1]).at[state.current_player].set(1),
lambda: jnp.zeros(2, jnp.float32),
return state._replace(
step_count=state.step_count + 1,
board=board * -1,
)

state = state.replace( # type:ignore
current_player=1 - state.current_player,
_x=GameState(
turn=1 - state._x.turn,
board=board * -1,
),
rewards=reward,
terminated=won,
legal_action_mask=state.legal_action_mask.at[:-1].set(board == 0).at[-1].set(state._step_count == 1),
)

return state


def _swap(state: State, size: int) -> State:
ix = jnp.nonzero(state._x.board, size=1)[0]
def _swap(state: GameState, size: int) -> GameState:
ix = jnp.nonzero(state.board, size=1)[0]
row = ix // size
col = ix % size
swapped_ix = col * size + row
set_place_id = swapped_ix + 1
board = state._x.board.at[ix].set(0).at[swapped_ix].set(set_place_id)
return state.replace( # type: ignore
current_player=1 - state.current_player,
_x=GameState(
turn=1 - state._x.turn,
board=board * -1,
),
legal_action_mask=state.legal_action_mask.at[:-1].set(board == 0).at[-1].set(FALSE),
board = state.board.at[ix].set(0).at[swapped_ix].set(set_place_id)
return state._replace(
step_count=state.step_count + 1,
board=board * -1,
)


def _observe(state: State, player_id: Array, size) -> Array:
board = jax.lax.select(
player_id == state.current_player,
state._x.board.reshape((size, size)),
-state._x.board.reshape((size, size)),
)
def _observe(state: GameState, color: Optional[Array] = None, size: int = 11) -> Array:
if color is None:
color = state.color

board = jax.lax.select(color == state.color, state.board, -state.board)
board = board.reshape((size, size))

my_board = board * 1 > 0
opp_board = board * -1 > 0
ones = jnp.ones_like(my_board)
color = jax.lax.select(player_id == state.current_player, state._x.turn, 1 - state._x.turn)
color = color * ones
can_swap = state.legal_action_mask[-1] * ones
can_swap = (state.step_count == 1) * ones

return jnp.stack([my_board, opp_board, color, can_swap], 2, dtype=jnp.bool_)

Expand All @@ -179,18 +179,14 @@ def _neighbour(xy, size):
return jnp.where(on_board, xs * size + ys, -1)


def _is_game_end(board, size, turn):
def _is_terminal(x: GameState, size):
top, bottom = jax.lax.cond(
turn == 0,
lambda: (board[:size], board[-size:]),
lambda: (board[::size], board[size - 1 :: size]),
x.color == 0,
lambda: (x.board[::size], x.board[size - 1 :: size]),
lambda: (x.board[:size], x.board[-size:]),
)

def check_same_id_exist(_id):
return (_id > 0) & (_id == bottom).any()
return (_id < 0) & (_id == bottom).any()

return jax.vmap(check_same_id_exist)(top).any()


def _get_abs_board(state):
return jax.lax.cond(state._x.turn == 0, lambda: state._x.board, lambda: state._x.board * -1)

0 comments on commit c343851

Please sign in to comment.