Update eval_step
Former-commit-id: 174c026
daochenzha committed May 27, 2021
1 parent 5b7463e commit 8477756
Showing 21 changed files with 48 additions and 31 deletions.
7 changes: 6 additions & 1 deletion rlcard/agents/cfr_agent.py
@@ -151,10 +151,15 @@ def eval_step(self, state):
Returns:
action (int): Predicted action
+ info (dict): A dictionary containing information
'''
probs = self.action_probs(state['obs'].tostring(), list(state['legal_actions'].keys()), self.average_policy)
action = np.random.choice(len(probs), p=probs)
- return action, probs

+ info = {}
+ info['probs'] = {state['raw_legal_actions'][i]: probs[list(state['legal_actions'].keys())[i]] for i in range(len(state['legal_actions']))}

+ return action, info

def get_state(self, player_id):
''' Get state_str of the player
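
With this change CFRAgent.eval_step returns an (action, info) pair instead of (action, probs): info['probs'] maps each raw legal action name to the probability the average policy assigns it. A minimal usage sketch, not part of this commit (the trained agent and the state dict are assumed; only the return contract comes from the diff above):

    # Sketch: `agent` is assumed to be a trained CFRAgent and `state` a dict
    # produced by the environment's _extract_state.
    action, info = agent.eval_step(state)

    # info['probs'] is keyed by human-readable actions, so it can be logged
    # directly without decoding action ids.
    for raw_action, prob in info['probs'].items():
        print(f'{raw_action}: {prob:.3f}')
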
6 changes: 5 additions & 1 deletion rlcard/agents/dqn_agent.py
@@ -158,11 +158,15 @@ def eval_step(self, state):
Returns:
action (int): an action id
+ info (dict): A dictionary containing information
'''
q_values = self.predict(state)
best_action = np.argmax(q_values)

- return best_action, None
+ info = {}
+ info['values'] = {state['raw_legal_actions'][i]: q_values[list(state['legal_actions'].keys())[i]] for i in range(len(state['legal_actions']))}

+ return best_action, info

def predict(self, state):
''' Predict the masked Q-values
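
The DQN agent follows the same pattern but exposes Q-values rather than probabilities: info['values'] maps each raw legal action to its predicted Q-value. A hedged sketch of how a caller might inspect it (the agent and state are again assumed):

    # Sketch: `agent` is assumed to be a trained DQNAgent, `state` an env state dict.
    best_action, info = agent.eval_step(state)

    # Rank the legal actions by predicted Q-value for debugging or display.
    ranked = sorted(info['values'].items(), key=lambda kv: kv[1], reverse=True)
    for raw_action, q in ranked:
        print(f'{raw_action}: {q:.3f}')
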
3 changes: 1 addition & 2 deletions rlcard/agents/human_agents/blackjack_human_agent.py
@@ -39,9 +39,8 @@ def eval_step(self, state):
Returns:
action (int): the action predicted (randomly chosen) by the random agent
- probs (list): The list of action probabilities
'''
- return self.step(state), []
+ return self.step(state), {}

def _print_state(state, raw_legal_actions, action_record):
''' Print out the state
@@ -62,6 +62,5 @@ def eval_step(self, state):
Returns:
action (int): the action predicted (randomly chosen) by the random agent
- probabilities (list): The list of action probabilities
'''
- return self.step(state), []
+ return self.step(state), {}
3 changes: 1 addition & 2 deletions rlcard/agents/human_agents/leduc_holdem_human_agent.py
@@ -39,9 +39,8 @@ def eval_step(self, state):
Returns:
action (int): the action predicted (randomly chosen) by the random agent
- probs (list): The list of action probabilities
'''
- return self.step(state), []
+ return self.step(state), {}

def _print_state(state, action_record):
''' Print out the state
3 changes: 1 addition & 2 deletions rlcard/agents/human_agents/limit_holdem_human_agent.py
@@ -39,9 +39,8 @@ def eval_step(self, state):
Returns:
action (int): the action predicted (randomly chosen) by the random agent
- probs (list): The list of action probabilities
'''
- return self.step(state), []
+ return self.step(state), {}

def _print_state(state, action_record):
''' Print out the state
3 changes: 1 addition & 2 deletions rlcard/agents/human_agents/nolimit_holdem_human_agent.py
@@ -39,9 +39,8 @@ def eval_step(self, state):
Returns:
action (int): the action predicted (randomly chosen) by the random agent
- probs (list): The list of action probabilities
'''
- return self.step(state), []
+ return self.step(state), {}

def _print_state(state, action_record):
''' Print out the state
3 changes: 1 addition & 2 deletions rlcard/agents/human_agents/uno_human_agent.py
@@ -39,9 +39,8 @@ def eval_step(self, state):
Returns:
action (int): the action predicted (randomly chosen) by the random agent
- probs (list): The list of action probabilities
'''
- return self.step(state), []
+ return self.step(state), {}

def _print_state(state, action_record):
''' Print out the state of a given player
7 changes: 5 additions & 2 deletions rlcard/agents/nfsp_agent.py
@@ -189,18 +189,21 @@ def eval_step(self, state):
Returns:
action (int): An action id.
+ info (dict): A dictionary containing information
'''
if self.evaluate_with == 'best_response':
- action, probs = self._rl_agent.eval_step(state)
+ action, info = self._rl_agent.eval_step(state)
elif self.evaluate_with == 'average_policy':
obs = state['obs']
legal_actions = list(state['legal_actions'].keys())
probs = self._act(obs)
probs = remove_illegal(probs, legal_actions)
action = np.random.choice(len(probs), p=probs)
+ info = {}
+ info['probs'] = {state['raw_legal_actions'][i]: probs[list(state['legal_actions'].keys())[i]] for i in range(len(state['legal_actions']))}
else:
raise ValueError("'evaluate_with' should be either 'average_policy' or 'best_response'.")
- return action, probs
+ return action, info

def sample_episode_policy(self):
''' Sample average/best_response policy
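
Because the best_response branch delegates to the inner RL agent (which fills info['values']) while the average_policy branch fills info['probs'], code that wants a single view of the per-action information has to check both keys. A small hypothetical helper, not part of this commit:

    # Hypothetical helper: print whatever per-action information eval_step
    # returned, whether probabilities (CFR, NFSP average policy, random agent)
    # or Q-values (DQN / NFSP best response).
    def describe_eval(agent, state):
        action, info = agent.eval_step(state)
        per_action = info.get('probs') or info.get('values') or {}
        for raw_action, value in per_action.items():
            print(f'{raw_action}: {value:.4f}')
        return action
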
4 changes: 4 additions & 0 deletions rlcard/agents/random_agent.py
@@ -40,4 +40,8 @@ def eval_step(self, state):
probs = [0 for _ in range(self.num_actions)]
for i in state['legal_actions']:
probs[i] = 1/len(state['legal_actions'])

+ info = {}
+ info['probs'] = {state['raw_legal_actions'][i]: probs[list(state['legal_actions'].keys())[i]] for i in range(len(state['legal_actions']))}

return self.step(state), probs
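
All of the agents above build info with the same comprehension: the i-th entry of raw_legal_actions is paired with the probability (or value) stored at the i-th legal-action id, so the two sequences must share one order. That is what the OrderedDict changes to the environments below guarantee. An illustrative sketch with invented numbers:

    from collections import OrderedDict

    # Illustrative state fragment; the real dicts come from each
    # environment's _extract_state.
    state = {
        'legal_actions': OrderedDict({0: None, 2: None}),  # action ids, in order
        'raw_legal_actions': ['call', 'fold'],              # aligned with the ids
    }
    probs = [0.7, 0.0, 0.3]  # indexed by action id

    # The shared pattern: pair the i-th raw action with the value of the
    # i-th legal action id.
    keys = list(state['legal_actions'].keys())
    info = {state['raw_legal_actions'][i]: probs[keys[i]] for i in range(len(keys))}
    # info == {'call': 0.7, 'fold': 0.3}
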
3 changes: 2 additions & 1 deletion rlcard/envs/blackjack.py
@@ -1,4 +1,5 @@
import numpy as np
+ from collections import OrderedDict

from rlcard.envs import Env
from rlcard.games.blackjack import Game
@@ -62,7 +63,7 @@ def get_scores_and_A(hand):
dealer_score, _ = get_scores_and_A(dealer_cards)
obs = np.array([my_score, dealer_score])

- legal_actions = {i: None for i in range(len(self.actions))}
+ legal_actions = OrderedDict({i: None for i in range(len(self.actions))})
extracted_state = {'obs': obs, 'legal_actions': legal_actions}
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in self.actions]
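
Wrapping legal_actions in OrderedDict makes the key-order guarantee explicit: plain dicts only promise insertion order from Python 3.7 on, and the agents' positional indexing silently breaks if the ids fall out of step with raw_legal_actions. A short sketch of the resulting state shape (the action names are illustrative):

    from collections import OrderedDict

    actions = ['hit', 'stand']  # illustrative action list
    legal_actions = OrderedDict({i: None for i in range(len(actions))})
    extracted_state = {
        'legal_actions': legal_actions,             # ids in a fixed, explicit order
        'raw_legal_actions': [a for a in actions],  # aligned with the ids above
    }
    assert list(extracted_state['legal_actions'].keys()) == [0, 1]
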
4 changes: 2 additions & 2 deletions rlcard/envs/doudizhu.py
@@ -1,4 +1,4 @@
- from collections import Counter
+ from collections import Counter, OrderedDict
import numpy as np

from rlcard.envs import Env
@@ -82,7 +82,7 @@ def _extract_state(self, state):
landlord_num_cards_left,
teammate_num_cards_left))

- extracted_state = {'obs': obs, 'legal_actions': self._get_legal_actions()}
+ extracted_state = OrderedDict({'obs': obs, 'legal_actions': self._get_legal_actions()})
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in state['actions']]
extracted_state['action_record'] = self.action_recorder
5 changes: 3 additions & 2 deletions rlcard/envs/gin_rummy.py
@@ -4,6 +4,7 @@
Date created: 2/12/2020
'''
import numpy as np
+ from collections import OrderedDict

from rlcard.envs import Env

@@ -56,7 +57,7 @@ def _extract_state(self, state): # 200213 don't use state ???
unknown_cards_rep = self._utils.encode_cards(unknown_cards)
rep = [hand_rep, top_discard_rep, dead_cards_rep, known_cards_rep, unknown_cards_rep]
obs = np.array(rep)
- extracted_state = {'obs': obs, 'legal_actions': self._get_legal_actions()}
+ extracted_state = {'obs': obs, 'legal_actions': self._get_legal_actions(), 'raw_legal_actions': list(self._get_legal_actions().keys())}
return extracted_state

def get_payoffs(self):
@@ -93,4 +94,4 @@ def _get_legal_actions(self):
'''
legal_actions = self.game.judge.get_legal_actions()
legal_actions_ids = {action_event.action_id: None for action_event in legal_actions}
- return legal_actions_ids
+ return OrderedDict(legal_actions_ids)
3 changes: 2 additions & 1 deletion rlcard/envs/leducholdem.py
@@ -1,6 +1,7 @@
import json
import os
import numpy as np
+ from collections import OrderedDict

import rlcard
from rlcard.envs import Env
@@ -46,7 +47,7 @@ def _extract_state(self, state):
'''
extracted_state = {}

- legal_actions = {self.actions.index(a): None for a in state['legal_actions']}
+ legal_actions = OrderedDict({self.actions.index(a): None for a in state['legal_actions']})
extracted_state['legal_actions'] = legal_actions

public_card = state['public_card']
3 changes: 2 additions & 1 deletion rlcard/envs/limitholdem.py
@@ -1,6 +1,7 @@
import json
import os
import numpy as np
+ from collections import OrderedDict

import rlcard
from rlcard.envs import Env
@@ -49,7 +50,7 @@ def _extract_state(self, state):
'''
extracted_state = {}

- legal_actions = {self.actions.index(a): None for a in state['legal_actions']}
+ legal_actions = OrderedDict({self.actions.index(a): None for a in state['legal_actions']})
extracted_state['legal_actions'] = legal_actions

public_cards = state['public_cards']
3 changes: 2 additions & 1 deletion rlcard/envs/mahjong.py
@@ -1,4 +1,5 @@
import numpy as np
+ from collections import OrderedDict

from rlcard.envs import Env
from rlcard.games.mahjong import Game
@@ -107,4 +108,4 @@ def _get_legal_actions(self):
print([len(p.pile) for p in self.game.players])
#print(self.game.get_state(self.game.round.current_player))
#exit()
- return legal_action_id
+ return OrderedDict(legal_action_id)
3 changes: 2 additions & 1 deletion rlcard/envs/nolimitholdem.py
@@ -1,6 +1,7 @@
import json
import os
import numpy as np
+ from collections import OrderedDict

import rlcard
from rlcard.envs import Env
@@ -54,7 +55,7 @@ def _extract_state(self, state):
'''
extracted_state = {}

- legal_actions = {action.value: None for action in state['legal_actions']}
+ legal_actions = OrderedDict({action.value: None for action in state['legal_actions']})
extracted_state['legal_actions'] = legal_actions

public_cards = state['public_cards']
3 changes: 2 additions & 1 deletion rlcard/envs/uno.py
@@ -1,4 +1,5 @@
import numpy as np
+ from collections import OrderedDict

from rlcard.envs import Env
from rlcard.games.uno import Game
@@ -42,7 +43,7 @@ def _decode_action(self, action_id):
def _get_legal_actions(self):
legal_actions = self.game.get_legal_actions()
legal_ids = {ACTION_SPACE[action]: None for action in legal_actions}
- return legal_ids
+ return OrderedDict(legal_ids)

def get_perfect_information(self):
''' Get the perfect information of the current state
2 changes: 1 addition & 1 deletion tests/agents/test_cfr.py
@@ -14,7 +14,7 @@ def test_train(self):
for _ in range(100):
agent.train()

- state = {'obs': np.array([1., 1., 0., 0., 0., 0.]), 'legal_actions': {0: None,2: None}}
+ state = {'obs': np.array([1., 1., 0., 0., 0., 0.]), 'legal_actions': {0: None,2: None}, 'raw_legal_actions': ['call', 'fold']}
action, _ = agent.eval_step(state)

self.assertIn(action, [0, 2])
4 changes: 2 additions & 2 deletions tests/agents/test_dqn.py
@@ -40,12 +40,12 @@ def test_train(self):
mlp_layers=[10,10],
device=torch.device('cpu'))

- predicted_action, _ = agent.eval_step({'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}})
+ predicted_action, _ = agent.eval_step({'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}, 'raw_legal_actions': ['call', 'raise']})
self.assertGreaterEqual(predicted_action, 0)
self.assertLessEqual(predicted_action, 1)

for _ in range(num_steps):
- ts = [{'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}}, np.random.randint(2), 0, {'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}}, True]
+ ts = [{'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}}, np.random.randint(2), 0, {'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}, 'raw_legal_actions': ['call', 'raise']}, True]
agent.feed(ts)

predicted_action = agent.step({'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}})
4 changes: 2 additions & 2 deletions tests/agents/test_nfsp.py
@@ -33,7 +33,7 @@ def test_train(self):
q_mlp_layers=[10,10],
device=torch.device('cpu'))

- predicted_action, _ = agent.eval_step({'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}})
+ predicted_action, _ = agent.eval_step({'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}, 'raw_legal_actions': ['call', 'raise']})
self.assertGreaterEqual(predicted_action, 0)
self.assertLessEqual(predicted_action, 1)

@@ -43,5 +43,5 @@ def test_train(self):
self.assertGreaterEqual(predicted_action, 0)
self.assertLessEqual(predicted_action, 1)

- ts = [{'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}}, np.random.randint(2), 0, {'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}}, True]
+ ts = [{'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}}, np.random.randint(2), 0, {'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}, 'raw_legal_actions': ['call', 'raise']}, True]
agent.feed(ts)
