Skip to content

Commit

Permalink
[Shogi] Extract game specific attributes (sotetsuk#1265)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Oct 30, 2024
1 parent 55dcb36 commit e9d7c87
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 71 deletions.
6 changes: 3 additions & 3 deletions pgx/_src/dwg/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def _make_shogi_dwg(dwg, state: ShogiState, config): # noqa: C901
if state._turn == 1:
if state._x.turn == 1:
from pgx.shogi import _flip

state = _flip(state)
Expand All @@ -17,7 +17,7 @@ def _sort_pieces(state, p1_hand, p2_hand):
"""
Sort the ShogiState hand into the order: rook, bishop, gold, silver, knight, lance, pawn
"""
hands = state._hand.flatten()[::-1]
hands = state._x.hand.flatten()[::-1]
tmp = hands
hands = hands.at[0].set(tmp[1])
hands = hands.at[1].set(tmp[2])
Expand Down Expand Up @@ -129,7 +129,7 @@ def _sort_pieces(state, p1_hand, p2_hand):
p1_pieces_g = dwg.g()
p2_pieces_g = dwg.g()
one_hot_board = np.zeros((29, 81))
board = state._board
board = state._x.board
for i in range(81):
piece = board[i]
one_hot_board[piece, i] = 1
Expand Down
8 changes: 4 additions & 4 deletions pgx/_src/shogi_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def _to_sfen(state):
"""
# NOTE: input must be flipped if white turn

pb = jnp.rot90(state._board.reshape((9, 9)), k=3)
pb = jnp.rot90(state._x.board.reshape((9, 9)), k=3)
sfen = ""
# fmt: off
board_char_dir = ["", "P", "L", "N", "S", "B", "R", "G", "K", "+P", "+L", "+N", "+S", "+B", "+R", "p", "l", "n", "s", "b", "r", "g", "k", "+p", "+l", "+n", "+s", "+b", "+r"]
Expand All @@ -209,18 +209,18 @@ def _to_sfen(state):
else:
sfen += " "
# Turn
if state._turn == 0:
if state._x.turn == 0:
sfen += "b "
else:
sfen += "w "
# Hand (prisoners)
if jnp.all(state._hand == 0):
if jnp.all(state._x.hand == 0):
sfen += "-"
else:
for i in range(2):
for j in range(7):
piece_type = hand_dir[i * 7 + j]
num_piece = state._hand.flatten()[piece_type]
num_piece = state._x.hand.flatten()[piece_type]
if num_piece == 0:
continue
if num_piece >= 2:
Expand Down
120 changes: 62 additions & 58 deletions pgx/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.


from typing import NamedTuple
from functools import partial

import jax
Expand Down Expand Up @@ -73,6 +74,16 @@
ALL_SQ = jnp.arange(81)


class GameState(NamedTuple):
turn: Array = jnp.int32(0) # 0 or 1
board: Array = INIT_PIECE_BOARD # (81,) flip in turn
hand: Array = jnp.zeros((2, 7), dtype=jnp.int32) # flip in turn
# cache
# Redundant information used only in _is_checked for speeding-up
cache_m2b: Array = -jnp.ones(8, dtype=jnp.int32)
cache_king: Array = jnp.int32(44)


@dataclass
class State(core.State):
current_player: Array = jnp.int32(0)
Expand All @@ -82,14 +93,7 @@ class State(core.State):
legal_action_mask: Array = INIT_LEGAL_ACTION_MASK # (27 * 81,)
observation: Array = jnp.zeros((119, 9, 9), dtype=jnp.bool_)
_step_count: Array = jnp.int32(0)
# --- Shogi specific ---
_turn: Array = jnp.int32(0) # 0 or 1
_board: Array = INIT_PIECE_BOARD # (81,) flip in turn
_hand: Array = jnp.zeros((2, 7), dtype=jnp.int32) # flip in turn
# cache
# Redundant information used only in _is_checked for speeding-up
_cache_m2b: Array = -jnp.ones(8, dtype=jnp.int32)
_cache_king: Array = jnp.int32(44)
_x: GameState = GameState()

@property
def env_id(self) -> core.EnvId:
Expand All @@ -99,7 +103,7 @@ def env_id(self) -> core.EnvId:
def _from_board(turn, piece_board: Array, hand: Array):
"""Mainly for debugging purpose.
terminated, reward, and current_player are not changed"""
state = State(_turn=turn, _board=piece_board, _hand=hand) # type: ignore
state = State(_x=GameState(turn=turn, board=piece_board, hand=hand)) # type: ignore
# fmt: off
state = jax.lax.cond(turn % 2 == 1, lambda: _flip(state), lambda: state)
# fmt: on
Expand All @@ -111,7 +115,7 @@ def _from_sfen(sfen):
return jax.jit(State._from_board)(turn, pb, hand).replace(_step_count=jnp.int32(step_count)) # type: ignore

def _to_sfen(self):
state = self if self._turn % 2 == 0 else _flip(self)
state = self if self._x.turn % 2 == 0 else _flip(self)
return _to_sfen(state)


Expand Down Expand Up @@ -215,11 +219,11 @@ def _from_dlshogi_action(state: State, action: Array):
is_promotion = (10 <= direction) & (direction < 20)
# LEGAL_FROM_IDX[UP, 19] = [20, 21, ... -1]
legal_from_idx = LEGAL_FROM_IDX[direction % 10, to] # (81,)
from_cand = state._board[legal_from_idx] # (8,)
from_cand = state._x.board[legal_from_idx] # (8,)
mask = (legal_from_idx >= 0) & (PAWN <= from_cand) & (from_cand < OPP_PAWN)
i = jnp.argmax(mask)
from_ = jax.lax.select(is_drop, 0, legal_from_idx[i])
piece = jax.lax.select(is_drop, direction - 20, state._board[from_])
piece = jax.lax.select(is_drop, direction - 20, state._x.board[from_])
return Action(is_drop=is_drop, piece=piece, to=to, from_=from_, is_promotion=is_promotion) # type: ignore


Expand All @@ -236,7 +240,7 @@ def _step(state: State, action: Array):
state = _flip(state)
state = state.replace( # type: ignore
current_player=(state.current_player + 1) % 2,
_turn=(state._turn + 1) % 2,
_x=state._x._replace(turn=(state._x.turn + 1) % 2),
)
legal_action_mask = _legal_action_mask(state)
terminated = ~legal_action_mask.any()
Expand All @@ -255,41 +259,40 @@ def _step(state: State, action: Array):


def _step_move(state: State, action: Action) -> State:
pb = state._board
pb = state._x.board
# remove piece from the original position
pb = pb.at[action.from_].set(EMPTY)
# capture the opponent if exists
captured = pb[action.to] # suppose >= OPP_PAWN, -1 if EMPTY
hand = jax.lax.cond(
captured == EMPTY,
lambda: state._hand,
lambda: state._x.hand,
# add captured piece to my hand after
# (1) turning opp piece into mine by (x + 14) % 28, and
# (2) filtering promoted piece by x % 8
lambda: state._hand.at[0, ((captured + 14) % 28) % 8].add(1),
lambda: state._x.hand.at[0, ((captured + 14) % 28) % 8].add(1),
)
# promote piece
piece = jax.lax.select(action.is_promotion, action.piece + 8, action.piece)
# set piece to the target position
pb = pb.at[action.to].set(piece)
# apply piece moves
return state.replace(_board=pb, _hand=hand) # type: ignore
return state.replace(_x=state._x._replace(board=pb, hand=hand)) # type: ignore


def _step_drop(state: State, action: Action) -> State:
# add piece to board
pb = state._board.at[action.to].set(action.piece)
pb = state._x.board.at[action.to].set(action.piece)
# remove piece from hand
hand = state._hand.at[0, action.piece].add(-1)
return state.replace(_board=pb, _hand=hand) # type: ignore
hand = state._x.hand.at[0, action.piece].add(-1)
return state.replace(_x=state._x._replace(board=pb, hand=hand)) # type: ignore


def _set_cache(state: State):
return state.replace( # type: ignore
_cache_m2b=jnp.nonzero(jax.vmap(_is_major_piece)(state._board), size=8, fill_value=-1)[0],
_cache_king=jnp.argmin(jnp.abs(state._board - KING)),
)

return state.replace(_x=state._x._replace( # type: ignore
cache_m2b=jnp.nonzero(jax.vmap(_is_major_piece)(state._x.board), size=8, fill_value=-1)[0],
cache_king=jnp.argmin(jnp.abs(state._x.board - KING)),
))

def _legal_action_mask(state: State):
# update cache
Expand Down Expand Up @@ -337,9 +340,9 @@ def is_legal_drop(i):

def _is_drop_pawn_mate(state: State):
# check pawn drop mate
opp_king_pos = jnp.argmin(jnp.abs(state._board - OPP_KING))
opp_king_pos = jnp.argmin(jnp.abs(state._x.board - OPP_KING))
to = opp_king_pos + 1
flip_state = _flip(state.replace(_board=state._board.at[to].set(PAWN))) # type: ignore
flip_state = _flip(state.replace(_x=state._x._replace(board=state._x.board.at[to].set(PAWN)))) # type: ignore
# Not checkmate if
# (1) can capture checking pawn, or
# (2) king can escape
Expand All @@ -359,17 +362,17 @@ def _is_drop_pawn_mate(state: State):


def _is_legal_drop_wo_piece(to: Array, state: State):
is_illegal = state._board[to] != EMPTY
is_illegal |= _is_checked(state.replace(_board=state._board.at[to].set(PAWN))) # type: ignore
is_illegal = state._x.board[to] != EMPTY
is_illegal |= _is_checked(state.replace(_x=state._x._replace(board=state._x.board.at[to].set(PAWN)))) # type: ignore
return ~is_illegal


def _is_legal_drop_wo_ignoring_check(piece: Array, to: Array, state: State):
is_illegal = state._board[to] != EMPTY
is_illegal = state._x.board[to] != EMPTY
# don't have the piece
is_illegal |= state._hand[0, piece] <= 0
is_illegal |= state._x.hand[0, piece] <= 0
# double pawn
is_illegal |= (piece == PAWN) & ((state._board == PAWN).reshape(9, 9).sum(axis=1) > 0)[to // 9]
is_illegal |= (piece == PAWN) & ((state._x.board == PAWN).reshape(9, 9).sum(axis=1) > 0)[to // 9]
# get stuck
is_illegal |= ((piece == PAWN) | (piece == LANCE)) & (to % 9 == 0)
is_illegal |= (piece == KNIGHT) & (to % 9 < 2)
Expand All @@ -383,13 +386,13 @@ def _is_legal_move_wo_pro(
):
ok = _is_pseudo_legal_move(from_, to, state)
ok &= ~_is_checked(
state.replace( # type: ignore
_board=state._board.at[from_].set(EMPTY).at[to].set(state._board[from_]),
_cache_king=jax.lax.select( # update cache
state._board[from_] == KING,
state.replace(_x=state._x._replace( # type: ignore
board=state._x.board.at[from_].set(EMPTY).at[to].set(state._x.board[from_]),
cache_king=jax.lax.select( # update cache
state._x.board[from_] == KING,
jnp.int32(to),
state._cache_king,
),
state._x.cache_king,
))
)
)
return ok
Expand All @@ -402,9 +405,9 @@ def _is_pseudo_legal_move(
):
ok = _is_pseudo_legal_move_wo_obstacles(from_, to, state)
# there is an obstacle between from_ and to
i = _major_piece_ix(state._board[from_])
i = _major_piece_ix(state._x.board[from_])
between_ix = BETWEEN_IX[i, from_, to, :]
is_illegal = (i >= 0) & ((between_ix >= 0) & (state._board[between_ix] != EMPTY)).any()
is_illegal = (i >= 0) & ((between_ix >= 0) & (state._x.board[between_ix] != EMPTY)).any()
return ok & ~is_illegal


Expand All @@ -413,7 +416,7 @@ def _is_pseudo_legal_move_wo_obstacles(
to: Array,
state: State,
):
board = state._board
board = state._x.board
# source is not my piece
piece = board[from_]
is_illegal = (from_ < 0) | ~((PAWN <= piece) & (piece < OPP_PAWN))
Expand All @@ -430,7 +433,7 @@ def _is_no_promotion_legal(
state: State,
):
# piece on the source square
piece = state._board[from_]
piece = state._x.board[from_]
# promotion
is_illegal = ((piece == PAWN) | (piece == LANCE)) & (to % 9 == 0) # Must promote
is_illegal |= (piece == KNIGHT) & (to % 9 < 2) # Must promote
Expand All @@ -443,7 +446,7 @@ def _is_promotion_legal(
state: State,
):
# piece on the source square
piece = state._board[from_]
piece = state._x.board[from_]
# promotion
is_illegal = (GOLD <= piece) & (piece <= DRAGON) # Pieces cannot promote
is_illegal |= (from_ % 9 >= 3) & (to % 9 >= 3) # irrelevant to the opponent's territory
Expand All @@ -452,8 +455,8 @@ def _is_promotion_legal(

def _is_checked(state):
# Use cached king position, simpler implementation is:
# jnp.argmin(jnp.abs(state.piece_board - KING))
king_pos = state._cache_king
# jnp.argmin(jnp.abs(state._x.board - KING))
king_pos = state._x.cache_king
flipped_king_pos = 80 - king_pos

@jax.vmap
Expand All @@ -467,7 +470,7 @@ def can_capture_king_local(from_):
# Simpler implementation without cache of major piece places
# from_ = CAN_MOVE_ANY[flipped_king_pos]
# return can_capture_king(from_).any()
from_ = 80 - state._cache_m2b
from_ = 80 - state._x.cache_m2b
from_ = jnp.where(from_ == 81, -1, from_)
neighbours = NEIGHBOUR_IX[flipped_king_pos]
return can_capture_king(from_).any() | can_capture_king_local(neighbours).any()
Expand All @@ -482,14 +485,15 @@ def _rotate(board: Array) -> Array:


def _flip(state):
empty_mask = state._board == EMPTY
pb = (state._board + 14) % 28
empty_mask = state._x.board == EMPTY
pb = (state._x.board + 14) % 28
pb = jnp.where(empty_mask, EMPTY, pb)
pb = pb[::-1]
return state.replace( # type: ignore
_board=pb,
_hand=state._hand[jnp.int32((1, 0))],
x = state._x._replace(
board=pb,
hand=state._x.hand[jnp.int32((1, 0))],
)
return state.replace(_x=x) # type: ignore


def _is_major_piece(piece):
Expand Down Expand Up @@ -534,24 +538,24 @@ def _observe(state: State, player_id: Array) -> Array:
def pieces(state):
# piece positions
my_pieces = jnp.arange(OPP_PAWN)
my_piece_feat = jax.vmap(lambda p: state._board == p)(my_pieces)
my_piece_feat = jax.vmap(lambda p: state._x.board == p)(my_pieces)
return my_piece_feat

def effect_all(state):
def effect(from_, to):
piece = state._board[from_]
piece = state._x.board[from_]
can_move = CAN_MOVE[piece, from_, to]
major_piece_ix = _major_piece_ix(piece)
between_ix = BETWEEN_IX[major_piece_ix, from_, to, :]
has_obstacles = jax.lax.select(
major_piece_ix >= 0,
((between_ix >= 0) & (state._board[between_ix] != EMPTY)).any(),
((between_ix >= 0) & (state._x.board[between_ix] != EMPTY)).any(),
FALSE,
)
return can_move & ~has_obstacles

effects = jax.vmap(jax.vmap(effect, (None, 0)), (0, None))(ALL_SQ, ALL_SQ)
mine = (PAWN <= state._board) & (state._board < OPP_PAWN)
mine = (PAWN <= state._x.board) & (state._x.board < OPP_PAWN)
return jnp.where(mine.reshape(81, 1), effects, FALSE)

def piece_and_effect(state):
Expand All @@ -560,7 +564,7 @@ def piece_and_effect(state):

@jax.vmap
def filter_effect(p):
mask = state._board == p
mask = state._x.board == p
return jnp.where(mask.reshape(81, 1), my_effect, FALSE).any(axis=0)

my_effect_feat = filter_effect(my_pieces)
Expand Down Expand Up @@ -595,8 +599,8 @@ def hand_feat(hand):
opp_piece_feat = opp_piece_feat[:, ::-1]
opp_effect_feat = opp_effect_feat[:, ::-1]
opp_effect_sum_feat = opp_effect_sum_feat[:, ::-1]
my_hand_feat = hand_feat(state._hand[0])
opp_hand_feat = hand_feat(state._hand[1])
my_hand_feat = hand_feat(state._x.hand[0])
opp_hand_feat = hand_feat(state._x.hand[1])
# NOTE: update cache
checked = jnp.tile(_is_checked(_set_cache(state)), reps=(1, 9, 9))
feat1 = [
Expand Down
Loading

0 comments on commit e9d7c87

Please sign in to comment.