Skip to content

Commit

Permalink
Unify num players and actions
Browse files Browse the repository at this point in the history
Former-commit-id: aac122c
  • Loading branch information
daochenzha committed May 14, 2021
1 parent cde2734 commit 5ef0639
Show file tree
Hide file tree
Showing 14 changed files with 29 additions and 29 deletions.
14 changes: 7 additions & 7 deletions examples/human/blackjack_human.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

# Make environment and enable human mode
# Set 'record_action' to True because we need it to print results
player_num = 2
env = rlcard.make('blackjack', config={'record_action': True, 'game_player_num': player_num})
human_agent = HumanAgent(env.action_num)
random_agent = RandomAgent(env.action_num)
num_players = 2
env = rlcard.make('blackjack', config={'record_action': True, 'game_num_players': num_players})
human_agent = HumanAgent(env.num_actions)
random_agent = RandomAgent(env.num_actions)
env.set_agents([human_agent, random_agent])

print(">> Blackjack human agent")
Expand All @@ -30,7 +30,7 @@
state = []
_action_list = []

for i in range(player_num):
for i in range(num_players):
final_state.append(trajectories[i][-1])
state.append(final_state[i]['raw_obs'])

Expand All @@ -45,12 +45,12 @@
print('=============== Dealer hand ===============')
print_card(state[0]['state'][1])

for i in range(player_num):
for i in range(num_players):
print('=============== Player {} Hand ==============='.format(i))
print_card(state[i]['state'][0])

print('=============== Result ===============')
for i in range(player_num):
for i in range(num_players):
if payoffs[i] == 1:
print('Player {} win {} chip!'.format(i, payoffs[i]))
elif payoffs[i] == 0:
Expand Down
4 changes: 2 additions & 2 deletions examples/human/gin_rummy_human.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@

def make_gin_rummy_env() -> 'GinRummyEnv':
gin_rummy_env = rlcard.make('gin-rummy')
# north_agent = RandomAgent(action_num=gin_rummy_env.action_num)
# north_agent = RandomAgent(num_actions=gin_rummy_env.num_actions)
north_agent = GinRummyNoviceRuleAgent()
south_agent = HumanAgent(gin_rummy_env.action_num)
south_agent = HumanAgent(gin_rummy_env.num_actions)
gin_rummy_env.set_agents([north_agent, south_agent])
gin_rummy_env.game.judge.scorer = scorers.GinRummyScorer(get_payoff=scorers.get_payoff_gin_rummy_v0)
return gin_rummy_env
Expand Down
2 changes: 1 addition & 1 deletion examples/human/leduc_holdem_human.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# Make environment
# Set 'record_action' to True because we need it to print results
env = rlcard.make('leduc-holdem', config={'record_action': True})
human_agent = HumanAgent(env.action_num)
human_agent = HumanAgent(env.num_actions)
cfr_agent = models.load('leduc-holdem-cfr').agents[0]
env.set_agents([human_agent, cfr_agent])

Expand Down
4 changes: 2 additions & 2 deletions examples/human/limit_holdem_human.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
# Make environment and enable human mode
# Set 'record_action' to True because we need it to print results
env = rlcard.make('limit-holdem', config={'record_action': True})
human_agent = HumanAgent(env.action_num)
agent_0 = RandomAgent(action_num=env.action_num)
human_agent = HumanAgent(env.num_actions)
agent_0 = RandomAgent(num_actions=env.num_actions)
env.set_agents([human_agent, agent_0])

print(">> Limit Hold'em random agent")
Expand Down
6 changes: 3 additions & 3 deletions examples/human/nolimit_holdem_human.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
# Set 'record_action' to True because we need it to print results
env = rlcard.make('no-limit-holdem', config={'record_action': True})

human_agent = HumanAgent(env.action_num)
human_agent2 = HumanAgent(env.action_num)
# random_agent = RandomAgent(action_num=env.action_num)
human_agent = HumanAgent(env.num_actions)
human_agent2 = HumanAgent(env.num_actions)
# random_agent = RandomAgent(num_actions=env.num_actions)

env.set_agents([human_agent, human_agent2])

Expand Down
2 changes: 1 addition & 1 deletion examples/human/uno_human.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# Make environment and enable human mode
# Set 'record_action' to True because we need it to print results
env = rlcard.make('uno', config={'record_action': True})
human_agent = HumanAgent(env.action_num)
human_agent = HumanAgent(env.num_actions)
cfr_agent = models.load('uno-rule-v1').agents[0]
env.set_agents([human_agent, cfr_agent])

Expand Down
6 changes: 3 additions & 3 deletions rlcard/agents/dqn_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def __init__(self, num_actions=2, learning_rate=0.001, state_shape=None, mlp_lay
''' Initialize an Estimator object.
Args:
action_num (int): the number of output actions
num_actions (int): the number of output actions
state_shape (list): the shape of the state space
mlp_layers (list): size of outputs of mlp layers
device (torch.device): whether to use cpu or gpu
Expand Down Expand Up @@ -303,10 +303,10 @@ def update(self, s, a, y):
a = torch.from_numpy(a).long().to(self.device)
y = torch.from_numpy(y).float().to(self.device)

# (batch, state_shape) -> (batch, action_num)
# (batch, state_shape) -> (batch, num_actions)
q_as = self.qnet(s)

# (batch, action_num) -> (batch, )
# (batch, num_actions) -> (batch, )
Q = torch.gather(q_as, dim=-1, index=a.unsqueeze(-1)).squeeze(-1)

# update model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ The example uses the following:
```python
def make_gin_rummy_env() -> 'GinRummyEnv':
gin_rummy_env = rlcard.make('gin-rummy')
# north_agent = RandomAgent(action_num=gin_rummy_env.action_num)
# north_agent = RandomAgent(num_actions=gin_rummy_env.num_actions)
north_agent = GinRummyNoviceRuleAgent()
south_agent = HumanAgent(gin_rummy_env.action_num)
south_agent = HumanAgent(gin_rummy_env.num_actions)
gin_rummy_env.set_agents([north_agent, south_agent])
return gin_rummy_env
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, make_gin_rummy_env: Callable[[], 'GinRummyEnv'] = None):
@staticmethod
def _make_gin_rummy_env() -> 'GinRummyEnv':
gin_rummy_env = rlcard.make('gin-rummy')
north_agent = RandomAgent(action_num=gin_rummy_env.action_num)
south_agent = HumanAgent(gin_rummy_env.action_num)
north_agent = RandomAgent(num_actions=gin_rummy_env.num_actions)
south_agent = HumanAgent(gin_rummy_env.num_actions)
gin_rummy_env.set_agents([north_agent, south_agent])
return gin_rummy_env
4 changes: 2 additions & 2 deletions rlcard/agents/nfsp_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,10 @@ def train_sl(self):
# (batch, state_size)
info_states = torch.from_numpy(np.array(info_states)).float().to(self.device)

# (batch, action_num)
# (batch, num_actions)
eval_action_probs = torch.from_numpy(np.array(action_probs)).float().to(self.device)

# (batch, action_num)
# (batch, num_actions)
log_forecast_action_probs = self.policy_network(info_states)

ce_loss = - (eval_action_probs * log_forecast_action_probs).sum(dim=-1).mean()
Expand Down
2 changes: 1 addition & 1 deletion rlcard/agents/random_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def __init__(self, num_actions):
''' Initialize the random agent
Args:
action_num (int): The size of the output action space
num_actions (int): The size of the output action space
'''
self.use_raw = False
self.num_actions = num_actions
Expand Down
2 changes: 1 addition & 1 deletion rlcard/envs/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, config):
step_back.
There can be some game specific configurations, e.g., the
number of players in the game. These fields should start with
'game_', e.g., 'game_player_num' which specifies the number of
'game_', e.g., 'game_num_players' which specifies the number of
players in the game. Since these configurations may be game-specific,
The default settings should be put in the Env class. For example,
the default game configurations for Blackjack should be in
Expand Down
2 changes: 1 addition & 1 deletion tests/games/test_doudizhu_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class TestDoudizhuGame(unittest.TestCase):

def test_get_player_num(self):
def test_get_num_players(self):
game = Game()
num_players = game.get_num_players()
self.assertEqual(num_players, 3)
Expand Down
2 changes: 1 addition & 1 deletion tests/games/test_limitholdem_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class TestLimitholdemMethods(unittest.TestCase):
def test_get_num_actions(self):
game = Game()
num_players = game.get_num_players()
self.assertEqual(player_num, 2)
self.assertEqual(num_players, 2)

def test_get_num_actions(self):
game = Game()
Expand Down

0 comments on commit 5ef0639

Please sign in to comment.