diff --git a/.bazelrc b/.bazelrc index 331da99127cd4c..25eeff429f4c81 100644 --- a/.bazelrc +++ b/.bazelrc @@ -611,10 +611,53 @@ build:release_gpu_linux_cuda_10_1 --action_env CUDA_TOOLKIT_PATH="/usr/local/cud build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDA_VERSION="10" build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDNN_VERSION="7" +# build:asan --strip=never +# build:asan --copt -fsanitize=address +# build:asan --copt -DADDRESS_SANITIZER +# build:asan --copt -O1 +# build:asan --copt -g +# build:asan --copt -fno-omit-frame-pointer +# build:asan --linkopt -fsanitize=address + +# build:asan --crosstool_top //tools/lrte:toolchain +# build:asan --compiler clang build:asan --strip=never build:asan --copt -fsanitize=address build:asan --copt -DADDRESS_SANITIZER build:asan --copt -O1 build:asan --copt -g build:asan --copt -fno-omit-frame-pointer -build:asan --linkopt -fsanitize=address \ No newline at end of file +build:asan --linkopt -fsanitize=address + +# Thread sanitizer +# bazel build --config tsan +# build:tsan --crosstool_top //tools/lrte:toolchain +# build:tsan --compiler clang +build:tsan --strip=never +build:tsan --copt -fsanitize=thread +build:tsan --copt -DTHREAD_SANITIZER +build:tsan --copt -DDYNAMIC_ANNOTATIONS_ENABLED=1 +build:tsan --copt -DDYNAMIC_ANNOTATIONS_EXTERNAL_IMPL=1 +build:tsan --copt -O1 +build:tsan --copt -fno-omit-frame-pointer +build:tsan --linkopt -fsanitize=thread + +# --config msan: Memory sanitizer +# build:msan --crosstool_top //tools/lrte:toolchain +# build:msan --compiler clang +build:msan --strip=never +build:msan --copt -fsanitize=memory +build:msan --copt -DADDRESS_SANITIZER +build:msan --copt -O1 +build:msan --copt -fno-omit-frame-pointer +build:msan --linkopt -fsanitize=memory + +# --config ubsan: Undefined Behavior Sanitizer +# build:ubsan --crosstool_top //tools/lrte:toolchain +# build:ubsan --compiler clang +build:ubsan --strip=never +build:ubsan --copt -fsanitize=undefined +build:ubsan --copt -O1 +build:ubsan --copt -fno-omit-frame-pointer +build:ubsan --linkopt -fsanitize=undefined +build:ubsan --linkopt -lubsan \ No newline at end of file diff --git a/tensorflow/cc/examples/mcts.h b/tensorflow/cc/examples/mcts.h index c50e3174ca0629..64da2b286edb49 100644 --- a/tensorflow/cc/examples/mcts.h +++ b/tensorflow/cc/examples/mcts.h @@ -33,17 +33,14 @@ class SnakeMctsAdapter { } double value() const { - if (move_queued_) { - return 0; - } - switch (state_.game_state()) { case GameState::P1_WIN: return 1; case GameState::P2_WIN: return -1; - default: + case GameState::DRAW: return 0; + default: } } @@ -62,7 +59,19 @@ class SnakeMctsAdapter { return state_.p1_view().valid_move(d); } - bool is_terminal() const { return state_.is_terminal(); } + bool is_terminal() const { + return state_.is_terminal() || (valid_actions().count() == 0); + } + + void print() const { + state_.print(); + if (move_queued_) { + std::cout << "p1 queue move: " << Direction_Name(p1_queued_move_) + << std::endl; + } else { + std::cout << "no move queued for p1" << std::endl; + } + } private: SnakeBoard16 state_; @@ -72,15 +81,18 @@ class SnakeMctsAdapter { class Node { public: - explicit Node(SnakeMctsAdapter state, Node* parent) + explicit Node(SnakeMctsAdapter state) : Node(state, Direction::UP, nullptr) {} + + explicit Node(SnakeMctsAdapter state, Direction action, Node* parent) : state_(state), parent_(parent), + action_(action), is_terminal_(state_.is_terminal()), valid_actions_(state_.valid_actions()) {} - SnakeMctsAdapter state_; - Node* parent_; - Direction action_; + const SnakeMctsAdapter state_; + Node* const parent_; + const Direction action_; bool is_terminal_; bool is_fully_expanded() const { return num_children_expanded == valid_actions_.count(); @@ -99,21 +111,29 @@ class Node { } auto clone_state = state_; clone_state.execute(static_cast(i)); - children_[i] = new Node(clone_state, this); + children_[i] = new Node(clone_state, static_cast(i), this); ++num_children_expanded; return children_[i]; } + state_.print(); + CHECK(false) << "should never happen " << num_children_expanded; return nullptr; } const std::bitset valid_actions_; int num_children_expanded = 0; - std::array children_; + std::array children_ = {}; + ~Node() { + for (Node* n : children_) { + if (n) { + delete n; + } + } + } }; double random_rollout(const SnakeMctsAdapter& state) { auto s = state; - std::random_device - rd; // Will be used to obtain a seed for the random number engine + std::random_device rd; std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with rd() std::uniform_int_distribution<> distrib(Direction_MIN, Direction_MAX); @@ -123,6 +143,7 @@ double random_rollout(const SnakeMctsAdapter& state) { s.execute(candidate_d); } } + LOG(ERROR) << s.value(); return s.value(); } @@ -131,13 +152,13 @@ class Mcts { Mcts(std::function rollout) : rollout_(rollout) {} Direction Search(SnakeMctsAdapter state) { - root_ = new Node(state, nullptr); + root_ = std::make_unique(state); - for (int i = 0; i < 100; ++i) { + for (int i = 0; i < 10000; ++i) { execute_round(); } - auto* best_child = get_best_child(root_, 0.0); + auto* best_child = get_best_child(root_.get(), 0.0); return best_child->action_; } @@ -150,7 +171,7 @@ class Mcts { } void execute_round() { - Node* node = select_node(root_); + Node* node = select_node(root_.get()); double reward = rollout_(node->state_); backpropogate(node, reward); } @@ -160,57 +181,31 @@ class Mcts { if (node->is_fully_expanded()) { node = get_best_child(node, exploration_constant_); } else { - return expand(node); + return node->expand(); } } return node; } - Node* expand(Node* node) { - return nullptr; - /* actions = node.state.getPossibleActions() - for action in actions: - if action not in node.children: - newNode = treeNode(node.state.takeAction(action), node) - node.children[action] = newNode - if len(actions) == len(node.children): - node.isFullyExpanded = True - return newNode - - raise Exception("Should never reach here") */ - } - - Node* get_best_child(Node* node, double exploration_value) { + Node* get_best_child(Node* node, double exploration_value) const { return *std::max_element(node->children_.begin(), node->children_.end(), [exploration_value](Node* a, Node* b) { + if (a == nullptr && b == nullptr) { + return false; + } + if (a == nullptr || b == nullptr) { + return a == nullptr; + } return a->ucb(exploration_value) < b->ucb(exploration_value); }); - // bestValue = float("-inf") - /* bestNodes = [] - for child in node.children.values(): - nodeValue = node.state.getCurrentPlayer() * child.totalReward / -child.numVisits + explorationValue * math.sqrt( 2 * math.log(node.numVisits) / -child.numVisits) - if nodeValue > bestValue: bestValue = nodeValue bestNodes = -[child] elif nodeValue == bestValue: bestNodes.append(child) return -random.choice(bestNodes) */ } private: float exploration_constant_ = 2.0; std::function rollout_; - Node* root_; + std::unique_ptr root_; }; -/* -def randomPolicy(state): - while not state.isTerminal(): - try: - action = random.choice(state.getPossibleActions()) - except IndexError: - raise Exception("Non-terminal state has no possible actions: " + -str(state)) state = state.takeAction(action) return state.getReward() -*/ } // namespace snake diff --git a/tensorflow/cc/examples/mcts_test.cc b/tensorflow/cc/examples/mcts_test.cc index a42b5c08f59373..5caa31e2bb4f16 100644 --- a/tensorflow/cc/examples/mcts_test.cc +++ b/tensorflow/cc/examples/mcts_test.cc @@ -26,7 +26,11 @@ TEST(Mcts, MctsTest) { SnakeMctsAdapter adapter(state); Mcts mcts(random_rollout); - mcts.Search(adapter); + auto d = mcts.Search(adapter); + LOG(ERROR) << Direction_Name(d); + + std::vector a = {-1.0, -0.89}; + LOG(ERROR) << *std::max_element(a.begin(), a.end()); } } // namespace snake \ No newline at end of file diff --git a/tensorflow/cc/examples/snake.h b/tensorflow/cc/examples/snake.h index e5f347eb047e1b..cb248d3136320d 100644 --- a/tensorflow/cc/examples/snake.h +++ b/tensorflow/cc/examples/snake.h @@ -120,6 +120,14 @@ class SnakeBoard { bool valid_move(Direction d) const { return board.is_unoccupied(player.peek(d)); } + bool has_move() const { + for (int i = Direction_MIN; i < Direction_ARRAYSIZE; ++i) { + if (board.is_unoccupied(player.peek(static_cast(i)))) { + return true; + } + } + return false; + } }; PlayerView p1_view() const { return {p1_, p2_, *this}; }