Skip to content

Commit

Permalink
add pro rule for caro game
Browse files Browse the repository at this point in the history
  • Loading branch information
NTT123 authored Jul 5, 2022
1 parent 9c6f96b commit 98dd2d1
Showing 1 changed file with 44 additions and 32 deletions.
76 changes: 44 additions & 32 deletions caro_game.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
"""Caro (Gomoku) game mechanics"""
"""Caro (Gomoku) game mechanics
Implement Pro rule. Reference: http://gomokuworld.com/gomoku/2
"""

from typing import Tuple

Expand Down Expand Up @@ -46,21 +50,18 @@ class CaroGame(Enviroment):
board: chex.Array
who_play: chex.Array
terminated: chex.Array
winner: chex.Array
num_cols: int
num_rows: int
count: chex.Array

def __init__(self, num_cols: int = 9, num_rows: int = 9):
def __init__(self, num_cols: int = 9, num_rows: int = 9, pro_rule_dist: int = 3):
super().__init__()
assert num_cols % 2 == 1 and num_rows % 2 == 1
assert pro_rule_dist in [3, 5]
self.pro_rule_dist = pro_rule_dist
self.num_rows = num_rows
self.winner_checker = CaroWinnerChecker()
self.board = jnp.zeros((num_rows * num_cols,), dtype=jnp.int32)
self.who_play = jnp.array(1, dtype=jnp.int32)
self.terminated = jnp.array(0, dtype=jnp.bool_)
self.winner = jnp.array(0, dtype=jnp.int32)
self.count = jnp.array(0, dtype=jnp.int32)
self.num_cols = num_cols
self.winner_checker = CaroWinnerChecker()
self.reset()

def num_actions(self):
Expand All @@ -70,12 +71,9 @@ def invalid_actions(self) -> chex.Array:
return self.board != 0

def reset(self):
assert self.num_rows % 2 == 1 and self.num_cols % 2 == 1
self.board = jnp.zeros((self.num_rows * self.num_cols,), dtype=jnp.int32)
self.board = self.board.at[self.num_rows * self.num_cols // 2].set(1)
self.who_play = jnp.array(-1, dtype=jnp.int32)
self.board = jnp.zeros((self.num_rows * self.num_cols), dtype=jnp.int32)
self.who_play = jnp.array(1, dtype=jnp.int32)
self.terminated = jnp.array(0, dtype=jnp.bool_)
self.winner = jnp.array(0, dtype=jnp.int32)
self.count = jnp.array(0, dtype=jnp.int32)

@pax.pure
Expand All @@ -84,11 +82,25 @@ def step(self, action: chex.Array) -> Tuple["CaroGame", chex.Array]:
An invalid move will terminate the game with reward -1.
"""
invalid_move = self.board[action] != 0
i, j = jnp.divmod(action, self.num_cols)
mid_i = self.num_rows // 2
mid_j = self.num_cols // 2
d_i = jnp.abs(mid_i - i)
d_j = jnp.abs(mid_j - j)
not_at_center = jnp.logical_or(d_i != 0, d_j != 0)
near_center = jnp.logical_and(
d_i < self.pro_rule_dist, d_j < self.pro_rule_dist
)
is_first_move = self.count == 0
is_third_move = self.count == 2
invalid_first_move = jnp.logical_and(is_first_move, not_at_center)
invalid_third_move = jnp.logical_and(is_third_move, near_center)
invalid_move = jnp.logical_or(invalid_first_move, invalid_third_move)
invalid_move = jnp.logical_or(invalid_move, self.board[action] != 0)
board_ = self.board.at[action].set(self.who_play)
self.board = select_tree(self.terminated, self.board, board_)
self.winner = self.winner_checker(self.observation())
reward_ = self.winner * self.who_play
winner = self.winner_checker(self.observation())
reward_ = winner * self.who_play
self.who_play = -self.who_play
self.count = self.count + 1
self.terminated = jnp.logical_or(self.terminated, reward_ != 0)
Expand All @@ -101,13 +113,17 @@ def step(self, action: chex.Array) -> Tuple["CaroGame", chex.Array]:

def step_xy(self, x: int, y: int):
"""step function with 2d actions."""
return self.step(y * self.num_cols + x)
return self.step(x * self.num_cols + y)

def render(self) -> None:
"""Render the game on screen."""
board = self.observation()
for row in reversed(range(self.num_rows)):
print(row, end=" ")
print(end=" ")
for col in range(self.num_cols):
print(chr(ord("a") + col), end=" ")
print()
for row in range(self.num_rows):
print(chr(ord("a") + row), end=" ")
for col in range(self.num_cols):
if board[row, col].item() == 1:
print("X", end=" ")
Expand All @@ -117,8 +133,6 @@ def render(self) -> None:
print(".", end=" ")
print()
print(end=" ")
for col in range(self.num_cols):
print(col, end=" ")
print()

def observation(self) -> chex.Array:
Expand Down Expand Up @@ -150,14 +164,12 @@ def symmetries(self, state, action_weights):

if __name__ == "__main__":
game = CaroGame()
while not game.is_terminated().item():
game.render()
i = input("> ")
a, b = list(i.strip().replace(" ", ""))
a, b = ord(a) - ord("a"), ord(b) - ord("a")
game, reward = game.step_xy(a, b)

print("Final board")
game.render()
game, reward = game.step_xy(3, 1)
game, reward = game.step_xy(2, 4)
game, reward = game.step_xy(3, 3)
game, reward = game.step_xy(3, 4)
game, reward = game.step_xy(3, 5)
game, reward = game.step_xy(1, 4)
game, reward = game.step_xy(3, 2)
game, reward = game.step_xy(5, 4)
game.render()
print("Reward", reward)

0 comments on commit 98dd2d1

Please sign in to comment.