# Coach.py (forked from suragnair/alpha-zero-general)
from collections import deque

import numpy as np


class Coach():
    def __init__(self, game, nnet, mcts):
        # maintain self.board, self.curPlayer
        self.game = game
        self.board = game.getInitBoard()
        self.nnet = nnet
        self.curPlayer = 1
        self.numIters = 1000
        self.numEps = 100
        self.mcts = mcts(self.game, self.nnet)
        self.maxlenOfQueue = 5000
        # other hyperparams (numIters, numEps, MCTS params etc.)

    def executeEpisode(self):
        # Performs one full game of self-play.
        # Returns a list of training examples from this episode, each of the
        # form (canonicalBoard, pi, r), where r is the final result from the
        # perspective of the player whose turn it was in that position.
        trainExamples = []
        self.board = self.game.getInitBoard()
        self.curPlayer = 1
        while True:
            pi = self.mcts.getActionProb(self.board, self.curPlayer)
            canonicalBoard = self.game.getCanonicalForm(self.board, self.curPlayer)
            # store 0 as a placeholder result; it is filled in once the game ends
            trainExamples.append((canonicalBoard, self.curPlayer, pi, 0))
            action = np.argmax(pi)
            self.board, self.curPlayer = self.game.getNextState(self.board, self.curPlayer, action)
            r = self.game.getGameEnded(self.board, self.curPlayer)
            if r != 0:
                # propagate the result to every stored example, flipping its
                # sign for the opponent's positions; tuples are immutable, so
                # rebuild each example as (board, pi, result)
                return [(e[0], e[2], r if self.curPlayer == e[1] else -r)
                        for e in trainExamples]

    def learn(self):
        # performs numIters x numEps games of self-play
        # after every iteration, retrains nnet; a complete version would only
        # accept the new net if it wins more than some cutoff % of games
        # against the previous one
        trainExamples = deque([], maxlen=self.maxlenOfQueue)
        for i in range(self.numIters):
            for eps in range(self.numEps):
                # executeEpisode returns a list, so extend rather than append
                trainExamples.extend(self.executeEpisode())
            self.nnet.trainNNet(trainExamples)
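

# --- Usage sketch (not from the original repo) ---
# A minimal smoke test, assuming only the interfaces Coach actually calls:
# a game object with getInitBoard/getCanonicalForm/getNextState/getGameEnded,
# a network wrapper with trainNNet, and an MCTS class constructed as
# mcts(game, nnet) exposing getActionProb(board, player). The Dummy* classes
# below are hypothetical stand-ins, not part of alpha-zero-general.
if __name__ == "__main__":

    class DummyGame:
        # a trivial single-move "game" that ends immediately with a win
        def getInitBoard(self):
            return np.zeros(1)

        def getCanonicalForm(self, board, player):
            return board * player

        def getNextState(self, board, player, action):
            return board, -player

        def getGameEnded(self, board, player):
            return 1

    class DummyNNet:
        def trainNNet(self, examples):
            print("training on %d examples" % len(examples))

    class DummyMCTS:
        def __init__(self, game, nnet):
            self.game, self.nnet = game, nnet

        def getActionProb(self, board, player):
            return [1.0]  # policy over the single available action

    coach = Coach(DummyGame(), DummyNNet(), DummyMCTS)
    coach.numIters, coach.numEps = 2, 3  # shrink the loops for a quick run
    coach.learn()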