Skip to content

Commit

Permalink
Fix testing issue
Browse files Browse the repository at this point in the history
Former-commit-id: 436398e
  • Loading branch information
daochenzha committed May 7, 2021
1 parent d7d2fd7 commit f7a4050
Show file tree
Hide file tree
Showing 28 changed files with 47 additions and 90 deletions.
4 changes: 2 additions & 2 deletions rlcard/agents/cfr_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def eval_step(self, state):
Returns:
action (int): Predicted action
'''
probs = self.action_probs(state['obs'].tostring(), state['legal_actions'], self.average_policy)
probs = self.action_probs(state['obs'].tostring(), list(state['legal_actions'].keys()), self.average_policy)
action = np.random.choice(len(probs), p=probs)
return action, probs

Expand All @@ -168,7 +168,7 @@ def get_state(self, player_id):
legal_actions (list): Indices of legal actions
'''
state = self.env.get_state(player_id)
return state['obs'].tostring(), state['legal_actions']
return state['obs'].tostring(), list(state['legal_actions'].keys())

def save(self):
''' Save model
Expand Down
4 changes: 2 additions & 2 deletions rlcard/agents/dqn_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def step(self, state):
action (int): an action id
'''
A = self.predict(state['obs'])
A = remove_illegal(A, state['legal_actions'])
A = remove_illegal(A, list(state['legal_actions'].keys()))
action = np.random.choice(np.arange(len(A)), p=A)
return action

Expand All @@ -158,7 +158,7 @@ def eval_step(self, state):
action (int): an action id
'''
q_values = self.q_estimator.predict_nograd(np.expand_dims(state['obs'], 0))[0]
probs = remove_illegal(np.exp(q_values), state['legal_actions'])
probs = remove_illegal(np.exp(q_values), list(state['legal_actions'].keys()))
best_action = np.argmax(probs)
return best_action, probs

Expand Down
4 changes: 2 additions & 2 deletions rlcard/agents/nfsp_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def step(self, state):
action (int): An action id
'''
obs = state['obs']
legal_actions = state['legal_actions']
legal_actions = list(state['legal_actions'].keys())
if self._mode == MODE.best_response:
probs = self._rl_agent.predict(obs)
self._add_transition(obs, probs)
Expand All @@ -198,7 +198,7 @@ def eval_step(self, state):
action, probs = self._rl_agent.eval_step(state)
elif self.evaluate_with == 'average_policy':
obs = state['obs']
legal_actions = state['legal_actions']
legal_actions = list(state['legal_actions'].keys())
probs = self._act(obs)
probs = remove_illegal(probs, legal_actions)
action = np.random.choice(len(probs), p=probs)
Expand Down
2 changes: 1 addition & 1 deletion rlcard/envs/blackjack.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_scores_and_A(hand):
dealer_score, _ = get_scores_and_A(dealer_cards)
obs = np.array([my_score, dealer_score])

legal_actions = [i for i in range(len(self.actions))]
legal_actions = {i: None for i in range(len(self.actions))}
extracted_state = {'obs': obs, 'legal_actions': legal_actions}
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in self.actions]
Expand Down
1 change: 1 addition & 0 deletions rlcard/envs/doudizhu.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def _extract_state(self, state):
others_hand = _cards2array(state['others_hand'])

last_action = ''
print(state)
if len(state['trace']) != 0:
if state['trace'][-1][1] == 'pass':
last_action = state['trace'][-2][1]
Expand Down
2 changes: 1 addition & 1 deletion rlcard/envs/gin_rummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _get_legal_actions(self):
legal_actions (list): a list of legal actions' id
'''
legal_actions = self.game.judge.get_legal_actions()
legal_actions_ids = [action_event.action_id for action_event in legal_actions]
legal_actions_ids = {action_event.action_id: None for action_event in legal_actions}
return legal_actions_ids

def _load_model(self):
Expand Down
2 changes: 1 addition & 1 deletion rlcard/envs/leducholdem.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _extract_state(self, state):
'''
extracted_state = {}

legal_actions = [self.actions.index(a) for a in state['legal_actions']]
legal_actions = {self.actions.index(a): None for a in state['legal_actions']}
extracted_state['legal_actions'] = legal_actions

public_card = state['public_card']
Expand Down
2 changes: 1 addition & 1 deletion rlcard/envs/limitholdem.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _extract_state(self, state):
'''
extracted_state = {}

legal_actions = [self.actions.index(a) for a in state['legal_actions']]
legal_actions = {self.actions.index(a): None for a in state['legal_actions']}
extracted_state['legal_actions'] = legal_actions

public_cards = state['public_cards']
Expand Down
4 changes: 2 additions & 2 deletions rlcard/envs/mahjong.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ def _get_legal_actions(self):
print(legal_actions)
legal_actions (list): a list of legal actions' id
'''
legal_action_id = []
legal_action_id = {}
legal_actions = self.game.get_legal_actions(self.game.get_state(self.game.round.current_player))
if legal_actions:
for action in legal_actions:
if isinstance(action, Card):
action = action.get_str()
action_id = self.action_id[action]
legal_action_id.append(action_id)
legal_action_id[action_id] = None
else:
print("##########################")
print("No Legal Actions")
Expand Down
2 changes: 1 addition & 1 deletion rlcard/envs/nolimitholdem.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _extract_state(self, state):
'''
extracted_state = {}

legal_actions = [action.value for action in state['legal_actions']]
legal_actions = {action.value: None for action in state['legal_actions']}
extracted_state['legal_actions'] = legal_actions

public_cards = state['public_cards']
Expand Down
4 changes: 2 additions & 2 deletions rlcard/envs/simpledoudizhu.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,14 @@ def _get_legal_actions(self):
Returns:
legal_actions (list): a list of legal actions' id
'''
legal_action_id = []
legal_action_id = {}
legal_actions = self.game.state['actions']
if legal_actions:
for action in legal_actions:
for abstract in self._SPECIFIC_MAP[action]:
action_id = self._ACTION_SPACE[abstract]
if action_id not in legal_action_id:
legal_action_id.append(action_id)
legal_action_id[action_id] = None
return legal_action_id

def get_perfect_information(self):
Expand Down
2 changes: 1 addition & 1 deletion rlcard/envs/uno.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _decode_action(self, action_id):

def _get_legal_actions(self):
legal_actions = self.game.get_legal_actions()
legal_ids = [ACTION_SPACE[action] for action in legal_actions]
legal_ids = {ACTION_SPACE[action]: None for action in legal_actions}
return legal_ids

def get_perfect_information(self):
Expand Down
2 changes: 1 addition & 1 deletion rlcard/games/doudizhu/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def get_action_num():
Returns:
int: the total number of abstract actions of doudizhu
'''
return 309
return 27472

def get_player_id(self):
''' Return current player's id
Expand Down
4 changes: 2 additions & 2 deletions rlcard/games/doudizhu/round.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def step_back(self, players):
if (cards != 'pass'):
for card in cards:
# self.played_cards.remove(card)
self.played_cards[CARD_RANK_STR_INDEX[card]] -= 1
self.public['played_cards'] = self.cards_ndarray_to_list(self.played_cards)
self.played_cards[player_id][CARD_RANK_STR_INDEX[card]] -= 1
self.public['played_cards'] = self.cards_ndarray_to_str(self.played_cards)
greater_player_id = self.find_last_greater_player_id_in_trace()
if (greater_player_id is not None):
self.greater_player = players[greater_player_id]
Expand Down
29 changes: 0 additions & 29 deletions rlcard/games/doudizhu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,35 +116,6 @@ def get_landlord_score(current_hand):
i += 1
return score


def get_optimal_action(probs, legal_actions, np_random):
''' Determine the optimal action from legal actions
according to the probabilities of abstract actions.
Args:
probs (list): list of probabilities of abstract actions
legal_actions (list): list of legal actions
Returns:
str: optimal legal action
'''
abstract_actions = [SPECIFIC_MAP[action] for action in legal_actions]
action_probs = []
for actions in abstract_actions:
max_prob = -1
for action in actions:
prob = probs[ACTION_SPACE[action]]
if prob > max_prob:
max_prob = prob
action_probs.append(max_prob)
optimal_prob = max(action_probs)
optimal_actions = [legal_actions[index] for index,
prob in enumerate(action_probs) if prob == optimal_prob]
if len(optimal_actions) > 1:
return np_random.choice(optimal_actions)
return optimal_actions[0]


def cards2str_with_suit(cards):
''' Get the corresponding string representation of cards with suit
Expand Down
2 changes: 1 addition & 1 deletion tests/agents/test_cfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_train(self):
for _ in range(100):
agent.train()

state = {'obs': np.array([1., 1., 0., 0., 0., 0.]), 'legal_actions': [0,2]}
state = {'obs': np.array([1., 1., 0., 0., 0., 0.]), 'legal_actions': {0: None,2: None}}
action, _ = agent.eval_step(state)

self.assertIn(action, [0, 2])
Expand Down
6 changes: 3 additions & 3 deletions tests/agents/test_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ def test_train(self):
mlp_layers=[10,10],
device=torch.device('cpu'))

predicted_action, _ = agent.eval_step({'obs': np.random.random_sample((2,)), 'legal_actions': [0, 1]})
predicted_action, _ = agent.eval_step({'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}})
self.assertGreaterEqual(predicted_action, 0)
self.assertLessEqual(predicted_action, 1)

for _ in range(step_num):
ts = [{'obs': np.random.random_sample((2,)), 'legal_actions': [0, 1]}, np.random.randint(2), 0, {'obs': np.random.random_sample((2,)), 'legal_actions': [0, 1]}, True]
ts = [{'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}}, np.random.randint(2), 0, {'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}}, True]
agent.feed(ts)
state_dict = agent.get_state_dict()
self.assertIsInstance(state_dict, dict)

predicted_action = agent.step({'obs': np.random.random_sample((2,)), 'legal_actions': [0, 1]})
predicted_action = agent.step({'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}})
self.assertGreaterEqual(predicted_action, 0)
self.assertLessEqual(predicted_action, 1)
6 changes: 3 additions & 3 deletions tests/agents/test_nfsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@ def test_train(self):
q_mlp_layers=[10,10],
device=torch.device('cpu'))

predicted_action, _ = agent.eval_step({'obs': np.random.random_sample((2,)), 'legal_actions': [0, 1]})
predicted_action, _ = agent.eval_step({'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}})
self.assertGreaterEqual(predicted_action, 0)
self.assertLessEqual(predicted_action, 1)

for _ in range(step_num):
agent.sample_episode_policy()
predicted_action = agent.step({'obs': np.random.random_sample((2,)), 'legal_actions': [0, 1]})
predicted_action = agent.step({'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}})
self.assertGreaterEqual(predicted_action, 0)
self.assertLessEqual(predicted_action, 1)

ts = [{'obs': np.random.random_sample((2,)), 'legal_actions': [0, 1]}, np.random.randint(2), 0, {'obs': np.random.random_sample((2,)), 'legal_actions': [0, 1]}, True]
ts = [{'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}}, np.random.randint(2), 0, {'obs': np.random.random_sample((2,)), 'legal_actions': {0: None, 1: None}}, True]
agent.feed(ts)
state_dict = agent.get_state_dict()
self.assertIsInstance(state_dict, dict)
2 changes: 1 addition & 1 deletion tests/envs/determism_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def gather_observations(env, actions, num_rand_steps):
while not env.is_over() and action_idx < len(actions):
# Agent plays
rand_iter(num_rand_steps)
legals = state['legal_actions']
legals = list(state['legal_actions'].keys())
action = legals[actions[action_idx]%len(legals)]
# Environment steps
next_state, next_player_id = env.step(action)
Expand Down
8 changes: 4 additions & 4 deletions tests/envs/test_doudizhu_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class TestDoudizhuEnv(unittest.TestCase):
def test_reset_and_extract_state(self):
env = rlcard.make('doudizhu')
state, _ = env.reset()
self.assertEqual(state['obs'].size, 450)
self.assertEqual(state['obs'].size, 790)

def test_is_deterministic(self):
self.assertTrue(is_deterministic('doudizhu'))
Expand All @@ -27,7 +27,7 @@ def test_step(self):
env = rlcard.make('doudizhu')
_, player_id = env.reset()
player = env.game.players[player_id]
_, next_player_id = env.step(env.action_num-1)
_, next_player_id = env.step(env.action_num-2)
self.assertEqual(next_player_id, (player.player_id+1)%len(env.game.players))

def test_step_back(self):
Expand Down Expand Up @@ -62,8 +62,8 @@ def test_decode_action(self):
env.reset()
env.game.state['actions'] = ['33366', '33355']
env.game.judger.playable_cards[0] = ['5', '6', '55', '555', '33366', '33355']
decoded = env._decode_action(54)
self.assertEqual(decoded, '33366')
decoded = env._decode_action(3)
self.assertEqual(decoded, '6')
env.game.state['actions'] = ['444', '44466', '44455']
decoded = env._decode_action(29)
self.assertEqual(decoded, '444')
Expand Down
2 changes: 1 addition & 1 deletion tests/envs/test_gin_rummy_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_get_legal_actions(self):
def test_step(self):
env = rlcard.make('gin-rummy')
state, _ = env.reset()
action = np.random.choice(state['legal_actions'])
action = np.random.choice(list(state['legal_actions'].keys()))
_, player_id = env.step(action)
current_player_id = env.game.round.get_current_player().player_id
self.assertEqual(player_id, current_player_id)
Expand Down
2 changes: 1 addition & 1 deletion tests/envs/test_leducholdem_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_step(self):
env = rlcard.make('leduc-holdem')
state, player_id = env.reset()
self.assertEqual(player_id, env.get_player_id())
action = state['legal_actions'][0]
action = list(state['legal_actions'].keys())[0]
_, player_id = env.step(action)
self.assertEqual(player_id, env.get_player_id())

Expand Down
2 changes: 1 addition & 1 deletion tests/envs/test_limitholdem_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_step(self):
env = rlcard.make('limit-holdem')
state, player_id = env.reset()
self.assertEqual(player_id, env.get_player_id())
action = state['legal_actions'][0]
action = list(state['legal_actions'].keys())[0]
_, player_id = env.step(action)
self.assertEqual(player_id, env.get_player_id())

Expand Down
6 changes: 3 additions & 3 deletions tests/envs/test_mahjong.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,21 @@ def test_get_legal_actions(self):
def test_step(self):
env = rlcard.make('mahjong')
state, _ = env.reset()
action = np.random.choice(state['legal_actions'])
action = np.random.choice(list(state['legal_actions'].keys()))
_, player_id = env.step(action)
self.assertEqual(player_id, env.game.round.current_player)

def test_step_back(self):
env = rlcard.make('mahjong', config={'allow_step_back':True})
state, player_id = env.reset()
action = np.random.choice(state['legal_actions'])
action = np.random.choice(list(state['legal_actions'].keys()))
env.step(action)
env.step_back()
self.assertEqual(env.game.round.current_player, player_id)

env = rlcard.make('mahjong', config={'allow_step_back':False})
state, player_id = env.reset()
action = np.random.choice(state['legal_actions'])
action = np.random.choice(list(state['legal_actions'].keys()))
env.step(action)
# env.step_back()
self.assertRaises(Exception, env.step_back)
Expand Down
2 changes: 1 addition & 1 deletion tests/envs/test_nolimitholdem_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_step(self):
env = rlcard.make('no-limit-holdem')
state, player_id = env.reset()
self.assertEqual(player_id, env.get_player_id())
action = state['legal_actions'][0]
action = list(state['legal_actions'].keys())[0]
_, player_id = env.step(action)
self.assertEqual(player_id, env.get_player_id())

Expand Down
6 changes: 3 additions & 3 deletions tests/envs/test_uno_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,21 @@ def test_get_legal_actions(self):
def test_step(self):
env = rlcard.make('uno')
state, _ = env.reset()
action = np.random.choice(state['legal_actions'])
action = np.random.choice(list(state['legal_actions'].keys()))
_, player_id = env.step(action)
self.assertEqual(player_id, env.game.round.current_player)

def test_step_back(self):
env = rlcard.make('uno', config={'allow_step_back':True})
state, player_id = env.reset()
action = np.random.choice(state['legal_actions'])
action = np.random.choice(list(state['legal_actions'].keys()))
env.step(action)
env.step_back()
self.assertEqual(env.game.round.current_player, player_id)

env = rlcard.make('uno', config={'allow_step_back':False})
state, player_id = env.reset()
action = np.random.choice(state['legal_actions'])
action = np.random.choice(list(state['legal_actions'].keys()))
env.step(action)
# env.step_back()
self.assertRaises(Exception, env.step_back)
Expand Down
Loading

0 comments on commit f7a4050

Please sign in to comment.