Skip to content

Commit

Permalink
uncomment logger
Browse files Browse the repository at this point in the history
  • Loading branch information
herilalaina committed Jun 25, 2020
1 parent b2c40d8 commit 400aae7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 27 deletions.
30 changes: 12 additions & 18 deletions mosaic/mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,21 +96,15 @@ def MCT_SEARCH(self):
"""
self.logger.info(
"#########################Iteration={0}##################################".format(self.n_iter))
self.logger.info("Begin SELECTION")
"## ITERATION={0} ##".format(self.n_iter))
front = self.TREEPOLICY()
self.logger.info("End SELECTION")

self.logger.info("Begin PLAYOUT")
reward, config = self.PLAYOUT(front)
self.logger.info("End PLAYOUT")

if config is None:
return 0, None

self.logger.info("Begin BACKUP")
self.BACKUP(front, reward)
self.logger.info("End BACKUP")
self.n_iter += 1

return reward, config
Expand All @@ -123,7 +117,7 @@ def TREEPOLICY(self):
return self.EXPAND(node)
else:
if not self.tree.fully_expanded(node, self.env):
self.logger.info("Not fully expanded.")
# self.logger.info("Not fully expanded.")
return self.EXPAND(node)
else:
current_node = self.tree.get_info_node(node)
Expand All @@ -137,45 +131,45 @@ def TREEPOLICY(self):
[x[1] for x in children],
[x[2] for x in children],
state=self.tree.get_path_to_node(node))
self.logger.info("Selection\t node={0}".format(node))
# self.logger.info("Selection\t node={0}".format(node))
else:
self.logger.error(
"Empty list of valid children\n current node {0}\t List of children {1}".format(
current_node,
self.tree.get_children(node)))
# self.logger.error(
# "Empty list of valid children\n current node {0}\t List of children {1}".format(
# current_node,
# self.tree.get_children(node)))
return node
return node

def EXPAND(self, node):
    """Expand ``node`` by adding one new child to the tree.

    The expansion policy is called with ``self.env.next_move`` and a
    two-element list: the path to ``node`` (``tree.get_path_to_node``) and
    the ``name``/``value`` info of the children ``node`` currently has.
    The ``(name, value, terminal)`` triple it returns is inserted in the
    tree with ``node`` as parent.

    Args:
        node: identifier of the tree node to expand.

    Returns:
        Whatever ``self.tree.add_node`` returns for the new child
        (presumably the new node's identifier -- confirm against the tree
        implementation).
    """
    # Look the path up once; it is needed both for the log line and as
    # input to the expansion policy.
    history = self.tree.get_path_to_node(node)
    self.logger.info(
        "Expand on node {0}\n Current history: {1}".format(node, history))

    name, value, terminal = self.policy.expansion(
        self.env.next_move,
        [history, self.tree.get_children(node, info=["name", "value"])])

    # Named ``child_id`` rather than ``id`` so the builtin is not shadowed.
    child_id = self.tree.add_node(
        name=name, value=value, terminal=terminal, parent_node=node)
    self.logger.info(
        "Expand\t id={0}\t name={1}\t value={2}\t terminal={3}".format(
            child_id, name, value, terminal))
    return child_id

def PLAYOUT(self, node_id):
    """Run a playout from ``node_id`` and score the result.

    The path to ``node_id`` is handed to ``self.env.rollout``; the value it
    returns is then scored through ``self.policy.evaluate`` using
    ``self.env._evaluate``.

    Args:
        node_id: identifier of the tree node the playout starts from.

    Returns:
        ``(score, playout_node)`` on success. If the rollout raises,
        ``node_id`` is flagged ``"invalid"`` in the tree and ``(0, None)``
        is returned instead.
    """
    # Look the path up once; it is used for the log line and the rollout.
    path = self.tree.get_path_to_node(node_id)
    self.logger.info("Playout on : {0}".format(path))

    st_time = time.time()
    try:
        playout_node = self.env.rollout(path)
    except Exception as e:
        # Deliberate best-effort: a failed rollout must not abort the
        # search, so the node is marked invalid and the failure is logged.
        # Lazy logger arguments keep the message correct even when
        # ``node_id`` is a tuple (eager ``%`` formatting would unpack it).
        self.logger.error(
            "Add node %s to not possible state: %s", node_id, e)
        self.tree.set_attribute(node_id, "invalid", True)
        return 0, None

    score = self.policy.evaluate(self.env._evaluate, [playout_node])

    self.logger.info(
        "param={0}\t score={1}\t exec time={2}".format(
            playout_node, score, time.time() - st_time))
    return score, playout_node

def BACKUP(self, node, reward):
Expand Down
17 changes: 8 additions & 9 deletions mosaic/strategy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,16 @@ def selection(self, parent, ids, vals, visits, state=None):

beta = (time.time() - self.policy_arg["start_time"]) / self.policy_arg["time_budget"]

self.logger.info("################################ Selection ##############################")
self.logger.info("vals", vals)
self.logger.info("visits", visits)
self.logger.info("probas", probas)
self.logger.info("c=", self.policy_arg["c"])
self.logger.info("beta=", beta)
# self.logger.info("## SELECTION ##")
# self.logger.info("vals", vals)
# self.logger.info("visits", visits)
# self.logger.info("probas", probas)
# self.logger.info("c=", self.policy_arg["c"])
# self.logger.info("beta=", beta)
res = [val + self.policy_arg["c"] * (prob) * math.sqrt(sum(visits)) / (vis + 1)
for vis, val, prob, prob_gen in zip(visits, vals, probas, probas_general)]
self.logger.info("Final selection policy ", res)
self.logger.info("Selected ", np.argmax(res))
self.logger.info("#########################################################################")
# self.logger.info("Final selection policy ", res)
# self.logger.info("Selected ", np.argmax(res))

return ids[np.argmax(res)]

Expand Down

0 comments on commit 400aae7

Please sign in to comment.