Skip to content

Commit

Permalink
[Hex] Extract game specific attributes (sotetsuk#1287)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Dec 2, 2024
1 parent 18799f8 commit a5c90f1
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 38 deletions.
2 changes: 1 addition & 1 deletion pgx/_src/dwg/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

def _make_hex_dwg(dwg, state: HexState, config):
GRID_SIZE = config["GRID_SIZE"] / 2 # 六角形の1辺
BOARD_SIZE = int(state._size)
BOARD_SIZE = int(state._x.size)
color_set = config["COLOR_SET"]

# background
Expand Down
6 changes: 3 additions & 3 deletions pgx/_src/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,9 @@ def _set_config_by_state(self, _state: State): # noqa: C901
from pgx._src.dwg.hex import _make_hex_dwg, four_dig

self.config["GRID_SIZE"] = 30
size = int(jnp.array(_state._size).ravel()[0])
self.config["BOARD_WIDTH"] = four_dig(size * 1.5) # type:ignore
self.config["BOARD_HEIGHT"] = four_dig(size * jnp.sqrt(3) / 2) # type:ignore
size = int(jnp.array(_state._x.size).ravel()[0])
self.config["BOARD_WIDTH"] = four_dig(size * 1.5) # type:ignore
self.config["BOARD_HEIGHT"] = four_dig(size * jnp.sqrt(3) / 2) # type:ignore
self._make_dwg_group = _make_hex_dwg # type:ignore
if (self.config["COLOR_THEME"] is None and self.config["COLOR_THEME"] == "dark") or self.config[
"COLOR_THEME"
Expand Down
58 changes: 33 additions & 25 deletions pgx/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp
Expand All @@ -25,6 +26,20 @@
TRUE = jnp.bool_(True)


class GameState(NamedTuple):
    """Hex-specific part of the environment state.

    The defaults describe an empty 11x11 board with black to move.
    """

    # Edge length of the board; cells are addressed as row * size + col.
    size: Array = jnp.int32(11)
    # Side to move: 0 (black), 1 (white).
    turn: Array = jnp.int32(0)
    # Flattened 11x11 board, row-major:
    # [[  0,   1,   2, ...,   8,   9,  10],
    #  [ 11,  12,  13, ...,  19,  20,  21],
    #    .
    #    .
    #    .
    #  [110, 111, 112, ..., 119, 120]]
    # Cell values: < 0 (opponent), 0 (empty), > 0 (self).
    board: Array = jnp.zeros(11 * 11, jnp.int32)


@dataclass
class State(core.State):
current_player: Array = jnp.int32(0)
Expand All @@ -34,18 +49,7 @@ class State(core.State):
truncated: Array = FALSE
legal_action_mask: Array = jnp.ones(11 * 11 + 1, dtype=jnp.bool_).at[-1].set(FALSE)
_step_count: Array = jnp.int32(0)
# --- Hex specific ---
_size: Array = jnp.int32(11)
# 0(black), 1(white)
_turn: Array = jnp.int32(0)
# 11x11 board
# [[ 0, 1, 2, ..., 8, 9, 10],
# [ 11, 12, 13, ..., 19, 20, 21],
# .
# .
# .
# [110, 111, 112, ..., 119, 120]]
_board: Array = jnp.zeros(11 * 11, jnp.int32) # <0(oppo), 0(empty), 0<(self)
_x: GameState = GameState()

@property
def env_id(self) -> core.EnvId:
Expand Down Expand Up @@ -89,12 +93,12 @@ def num_players(self) -> int:

def _init(rng: PRNGKey, size: int) -> State:
    """Create the initial state for a board with edge length ``size``.

    Args:
        rng: PRNG key used to draw which player id moves first.
        size: Edge length of the board.

    Returns:
        A fresh ``State`` whose ``current_player`` is 0 or 1 (a Bernoulli
        draw from ``rng``) and whose game-specific part records ``size``.
    """
    current_player = jnp.int32(jax.random.bernoulli(rng))
    # Store size as an int32 array so the field has the same type as the
    # GameState default, rather than a bare Python int.
    return State(_x=GameState(size=jnp.int32(size)), current_player=current_player)  # type:ignore


def _step(state: State, action: Array, size: int) -> State:
set_place_id = action + 1
board = state._board.at[action].set(set_place_id)
board = state._x.board.at[action].set(set_place_id)
neighbour = _neighbour(action, size)

def merge(i, b):
Expand All @@ -106,7 +110,7 @@ def merge(i, b):
)

board = jax.lax.fori_loop(0, 6, merge, board)
won = _is_game_end(board, size, state._turn)
won = _is_game_end(board, size, state._x.turn)
reward = jax.lax.cond(
won,
lambda: jnp.float32([-1, -1]).at[state.current_player].set(1),
Expand All @@ -115,8 +119,10 @@ def merge(i, b):

state = state.replace( # type:ignore
current_player=1 - state.current_player,
_turn=1 - state._turn,
_board=board * -1,
_x=GameState(
turn=1 - state._x.turn,
board=board * -1,
),
rewards=reward,
terminated=won,
legal_action_mask=state.legal_action_mask.at[:-1].set(board == 0).at[-1].set(state._step_count == 1),
Expand All @@ -126,31 +132,33 @@ def merge(i, b):


def _swap(state: State, size: int) -> State:
    """Apply the swap action: mirror the placed stone across the main diagonal.

    The first non-zero cell of the board is located, cleared, and re-placed at
    the transposed position (row and column exchanged). The board is then
    negated and the turn handed to the other player.

    Args:
        state: State before the swap.
        size: Edge length of the board, used to convert between flat indices
            and (row, col).

    Returns:
        The state after the swap, with the swap action (last entry of
        ``legal_action_mask``) disabled.
    """
    # size=1 gives a fixed-length result: the index of the first non-zero cell.
    ix = jnp.nonzero(state._x.board, size=1)[0]
    row = ix // size
    col = ix % size
    swapped_ix = col * size + row
    # Stones are stored as (flat index + 1) so that 0 can mean "empty".
    set_place_id = swapped_ix + 1
    board = state._x.board.at[ix].set(0).at[swapped_ix].set(set_place_id)
    return state.replace(  # type: ignore
        current_player=1 - state.current_player,
        # _replace keeps the remaining fields (notably ``size``); building a
        # fresh GameState here would silently reset them to their defaults.
        _x=state._x._replace(
            turn=1 - state._x.turn,
            # Negate so the board is expressed from the next player's side.
            board=board * -1,
        ),
        legal_action_mask=state.legal_action_mask.at[:-1].set(board == 0).at[-1].set(FALSE),
    )


def _observe(state: State, player_id: Array, size) -> Array:
board = jax.lax.select(
player_id == state.current_player,
state._board.reshape((size, size)),
-state._board.reshape((size, size)),
state._x.board.reshape((size, size)),
-state._x.board.reshape((size, size)),
)

my_board = board * 1 > 0
opp_board = board * -1 > 0
ones = jnp.ones_like(my_board)
color = jax.lax.select(player_id == state.current_player, state._turn, 1 - state._turn)
color = jax.lax.select(player_id == state.current_player, state._x.turn, 1 - state._x.turn)
color = color * ones
can_swap = state.legal_action_mask[-1] * ones

Expand Down Expand Up @@ -185,4 +193,4 @@ def check_same_id_exist(_id):


def _get_abs_board(state):
return jax.lax.cond(state._turn == 0, lambda: state._board, lambda: state._board * -1)
return jax.lax.cond(state._x.turn == 0, lambda: state._x.board, lambda: state._x.board * -1)
18 changes: 9 additions & 9 deletions tests/test_hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_merge():
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
# fmt:on
assert jnp.all(state._board == expected)
assert jnp.all(state._x.board == expected)


def test_swap():
Expand All @@ -50,26 +50,26 @@ def test_swap():
assert ~state.legal_action_mask[-1]
state = step(state, 1)
state.save_svg("tests/assets/hex/swap_01.svg")
assert (state._board != 0).sum() == 1
assert state._board[1] == -2
assert (state._x.board != 0).sum() == 1
assert state._x.board[1] == -2
assert state.legal_action_mask[-1]
state = step(state, 121) # swap!
state.save_svg("tests/assets/hex/swap_02.svg")
assert (state._board != 0).sum() == 1
assert state._board[11] == -12
assert (state._x.board != 0).sum() == 1
assert state._x.board[11] == -12
assert ~state.legal_action_mask[-1]

key = jax.random.PRNGKey(0)
state = init(key=key)
state = step(state, 0)
state.save_svg("tests/assets/hex/swap_03.svg")
assert (state._board != 0).sum() == 1
assert state._board[0] == -1
assert (state._x.board != 0).sum() == 1
assert state._x.board[0] == -1
assert state.legal_action_mask[-1]
state = step(state, 121) # swap!
state.save_svg("tests/assets/hex/swap_04.svg")
assert (state._board != 0).sum() == 1
assert state._board[0] == -1
assert (state._x.board != 0).sum() == 1
assert state._x.board[0] == -1
assert ~state.legal_action_mask[-1]

key = jax.random.PRNGKey(0)
Expand Down

0 comments on commit a5c90f1

Please sign in to comment.