Skip to content

Commit

Permalink
Clean code
Browse files Browse the repository at this point in the history
Former-commit-id: 31b6e45
  • Loading branch information
daochenzha committed May 6, 2021
1 parent 25bac35 commit d100952
Show file tree
Hide file tree
Showing 18 changed files with 42 additions and 102 deletions.
3 changes: 1 addition & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ python:
- "3.6"
- "3.7"
install:
- pip install -e .
- pip install torch==1.6.0
- pip install -e .[torch]
before_script:
- pip install python-coveralls
- pip install pytest-cover
Expand Down
8 changes: 3 additions & 5 deletions rlcard/envs/blackjack.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,9 @@ def get_scores_and_A(hand):

legal_actions = [i for i in range(len(self.actions))]
extracted_state = {'obs': obs, 'legal_actions': legal_actions}
if self.allow_raw_data:
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in self.actions]
if self.record_action:
extracted_state['action_record'] = self.action_recorder
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in self.actions]
extracted_state['action_record'] = self.action_recorder
return extracted_state

def get_payoffs(self):
Expand Down
12 changes: 3 additions & 9 deletions rlcard/envs/doudizhu.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,9 @@ def _extract_state(self, state):
self._encode_cards(obs[5], state['played_cards'])

extracted_state = {'obs': obs, 'legal_actions': self._get_legal_actions()}
if self.allow_raw_data:
extracted_state['raw_obs'] = state
# TODO: state['actions'] can be None, may have bugs
if state['actions'] == None:
extracted_state['raw_legal_actions'] = []
else:
extracted_state['raw_legal_actions'] = [a for a in state['actions']]
if self.record_action:
extracted_state['action_record'] = self.action_recorder
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in state['actions']]
extracted_state['action_record'] = self.action_recorder
return extracted_state

def get_payoffs(self):
Expand Down
19 changes: 3 additions & 16 deletions rlcard/envs/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ def __init__(self, config):
'seed' (int) - A environment local random seed.
'allow_step_back' (boolean) - True if allowing
step_back.
'allow_raw_data' (boolean) - True if allow
raw obs in state['raw_obs'] and raw legal actions in
state['raw_legal_actions'].
There can be some game specific configurations, e.g., the
number of players in the game. These fields should start with
'game_', e.g., 'game_player_num' which specify the number of
Expand All @@ -28,10 +25,7 @@ def __init__(self, config):
TODO: Support more game configurations in the future.
'''
self.allow_step_back = self.game.allow_step_back = config['allow_step_back']
self.allow_raw_data = config['allow_raw_data']
self.record_action = config['record_action']
if self.record_action:
self.action_recorder = []
self.action_recorder = []

# Game specific configurations
# Currently only support blackjack、limit-holdem、no-limit-holdem
Expand Down Expand Up @@ -65,8 +59,7 @@ def reset(self):
(int): The begining player
'''
state, player_id = self.game.init_game()
if self.record_action:
self.action_recorder = []
self.action_recorder = []
return self._extract_state(state), player_id

def step(self, action, raw_action=False):
Expand All @@ -87,8 +80,7 @@ def step(self, action, raw_action=False):

self.timestep += 1
# Record the action for human interface
if self.record_action:
self.action_recorder.append([self.get_player_id(), action])
self.action_recorder.append([self.get_player_id(), action])
next_state, player_id = self.game.step(action)

return self._extract_state(next_state), player_id
Expand Down Expand Up @@ -124,11 +116,6 @@ def set_agents(self, agents):
agents (list): List of Agent classes
'''
self.agents = agents
# If at least one agent needs raw data, we set self.allow_raw_data = True
for agent in self.agents:
if agent.use_raw:
self.allow_raw_data = True
break

def run(self, is_training=False):
'''
Expand Down
8 changes: 3 additions & 5 deletions rlcard/envs/leducholdem.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,9 @@ def _extract_state(self, state):
obs[state['all_chips'][1]+20] = 1
extracted_state['obs'] = obs

if self.allow_raw_data:
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in state['legal_actions']]
if self.record_action:
extracted_state['action_record'] = self.action_recorder
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in state['legal_actions']]
extracted_state['action_record'] = self.action_recorder

return extracted_state

Expand Down
9 changes: 4 additions & 5 deletions rlcard/envs/limitholdem.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,10 @@ def _extract_state(self, state):
obs[52 + i * 5 + num] = 1
extracted_state['obs'] = obs

if self.allow_raw_data:
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in state['legal_actions']]
if self.record_action:
extracted_state['action_record'] = self.action_recorder
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in state['legal_actions']]
extracted_state['action_record'] = self.action_recorder

return extracted_state

def get_payoffs(self):
Expand Down
9 changes: 4 additions & 5 deletions rlcard/envs/mahjong.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ def _extract_state(self, state):
obs = np.array(rep)

extracted_state = {'obs': obs, 'legal_actions': self._get_legal_actions()}
if self.allow_raw_data:
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in state['action_cards']]
if self.record_action:
extracted_state['action_record'] = self.action_recorder
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in state['action_cards']]
extracted_state['action_record'] = self.action_recorder

return extracted_state

def get_payoffs(self):
Expand Down
9 changes: 4 additions & 5 deletions rlcard/envs/nolimitholdem.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,10 @@ def _extract_state(self, state):
obs[53] = float(max(all_chips))
extracted_state['obs'] = obs

if self.allow_raw_data:
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in state['legal_actions']]
if self.record_action:
extracted_state['action_record'] = self.action_recorder
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in state['legal_actions']]
extracted_state['action_record'] = self.action_recorder

return extracted_state

def get_payoffs(self):
Expand Down
2 changes: 0 additions & 2 deletions rlcard/envs/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
# Default Config
DEFAULT_CONFIG = {
'allow_step_back': False,
'allow_raw_data': False,
'record_action' : False,
'seed': None,
}

Expand Down
9 changes: 4 additions & 5 deletions rlcard/envs/simpledoudizhu.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,10 @@ def _extract_state(self, state):
self._encode_cards(obs[5], state['played_cards'])

extracted_state = {'obs': obs, 'legal_actions': self._get_legal_actions()}
if self.allow_raw_data:
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in state['actions']]
if self.record_action:
extracted_state['action_record'] = self.action_recorder
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in state['actions']]
extracted_state['action_record'] = self.action_recorder

return extracted_state

def get_payoffs(self):
Expand Down
9 changes: 3 additions & 6 deletions rlcard/envs/uno.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,9 @@ def _extract_state(self, state):
encode_target(obs[3], state['target'])
legal_action_id = self._get_legal_actions()
extracted_state = {'obs': obs, 'legal_actions': legal_action_id}
if self.allow_raw_data:
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [
a for a in state['legal_actions']]
if self.record_action:
extracted_state['action_record'] = self.action_recorder
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in state['legal_actions']]
extracted_state['action_record'] = self.action_recorder
return extracted_state

def get_payoffs(self):
Expand Down
2 changes: 1 addition & 1 deletion rlcard/games/doudizhu/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def get_state(self, player_id):
player = self.players[player_id]
others_hands = self._get_others_current_hand(player)
if self.is_over():
actions = None
actions = []
else:
actions = list(player.available_actions(self.round.greater_player, self.judger))
state = player.get_state(self.round.public, others_hands, actions)
Expand Down
2 changes: 1 addition & 1 deletion rlcard/games/simpledoudizhu/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def get_state(self, player_id):
player = self.players[player_id]
others_hands = self._get_others_current_hand(player)
if self.is_over():
actions = None
actions = []
else:
actions = list(player.available_actions(self.round.greater_player, self.judger))
state = player.get_state(self.round.public, others_hands, actions)
Expand Down
22 changes: 0 additions & 22 deletions rlcard/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,28 +149,6 @@ def reorganize(trajectories, payoffs):
new_trajectories[player].append(transition)
return new_trajectories

def set_global_seed(seed):
''' Set the global see for reproducing results
Args:
seed (int): The seed
Note: If using other modules with randomness, they also need to be seeded
'''
if seed is not None:
import subprocess
import sys

reqs = subprocess.check_output([sys.executable, '-m', 'pip', 'freeze'])
installed_packages = [r.decode().split('==')[0] for r in reqs.split()]
if 'torch' in installed_packages:
import torch
torch.backends.cudnn.deterministic = True
torch.manual_seed(seed)
np.random.seed(seed)
import random
random.seed(seed)

def remove_illegal(action_probs, legal_actions):
''' Remove illegal actions and normalize the
probability vector
Expand Down
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
with open("README.md", "r", encoding="utf8") as fh:
long_description = fh.read()

extras = {
'torch': ['torch==1.6', 'matplotlib'],
}

def _get_version():
with open('rlcard/__init__.py') as f:
for line in f:
Expand Down Expand Up @@ -43,6 +47,7 @@ def _get_version():
'numpy>=1.16.3',
'termcolor'
],
extras_require=extras,
requires_python='>=3.6',
classifiers=[
"Programming Language :: Python :: 3.9",
Expand Down
4 changes: 1 addition & 3 deletions tests/games/test_doudizhu_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,9 @@ def test_proceed_game(self):
player = game.players[player_id]
self.assertEqual((player.player_id+1)%len(game.players), next_player_id)
player_id = next_player_id
if not game.is_over():
self.assertIsNotNone(state['actions'])
for player_id in range(3):
state = game.get_state(player_id)
self.assertIsNone(state['actions'])
self.assertEqual(state['actions'], [])

def test_step_back(self):
#case 1: action, stepback
Expand Down
4 changes: 1 addition & 3 deletions tests/games/test_simpledoudizhu_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,9 @@ def test_proceed_game(self):
player = game.players[player_id]
self.assertEqual((player.player_id+1)%len(game.players), next_player_id)
player_id = next_player_id
if not game.is_over():
self.assertIsNotNone(state['actions'])
for player_id in range(3):
state = game.get_state(player_id)
self.assertIsNone(state['actions'])
self.assertEqual(state['actions'], [])

def test_step_back(self):
#case 1: action, stepback
Expand Down
8 changes: 1 addition & 7 deletions tests/utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
import numpy as np
from rlcard.utils.utils import init_54_deck, init_standard_deck, rank2int, print_card, elegent_form, reorganize, set_global_seed, tournament
from rlcard.utils.utils import init_54_deck, init_standard_deck, rank2int, print_card, elegent_form, reorganize, tournament
import rlcard
from rlcard.agents.random_agent import RandomAgent

Expand Down Expand Up @@ -35,17 +35,11 @@ def test_reorganize(self):
trajectories = reorganize([[[1,2],1,[4,5]]], [1])
self.assertEqual(np.array(trajectories).shape, (1, 1, 5))

def test_set_global_seed(self):
set_global_seed(0)
self.assertEqual(np.random.get_state()[1][0], 0)

def test_tournament(self):
env = rlcard.make('leduc-holdem')
env.set_agents([RandomAgent(env.action_num), RandomAgent(env.action_num)])
payoffs = tournament(env,1000)
self.assertEqual(len(payoffs), 2)



if __name__ == '__main__':
unittest.main()

0 comments on commit d100952

Please sign in to comment.