Skip to content

Commit

Permalink
backup
Browse files Browse the repository at this point in the history
  • Loading branch information
geligeli committed Jan 16, 2021
1 parent 984ee0b commit 648dad9
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 55 deletions.
45 changes: 44 additions & 1 deletion .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -611,10 +611,53 @@ build:release_gpu_linux_cuda_10_1 --action_env CUDA_TOOLKIT_PATH="/usr/local/cud
build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDA_VERSION="10"
build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDNN_VERSION="7"

# build:asan --strip=never
# build:asan --copt -fsanitize=address
# build:asan --copt -DADDRESS_SANITIZER
# build:asan --copt -O1
# build:asan --copt -g
# build:asan --copt -fno-omit-frame-pointer
# build:asan --linkopt -fsanitize=address

# build:asan --crosstool_top //tools/lrte:toolchain
# build:asan --compiler clang
# --config asan: Address sanitizer
# Usage: bazel build --config asan <target>
# Binaries are left unstripped so sanitizer reports can be symbolized.
build:asan --strip=never
build:asan --copt -fsanitize=address
build:asan --copt -DADDRESS_SANITIZER
build:asan --copt -O1
build:asan --copt -g
build:asan --copt -fno-omit-frame-pointer
# The sanitizer flag has to be passed at link time too; once is enough.
build:asan --linkopt -fsanitize=address

# --config tsan: Thread sanitizer
# Usage: bazel build --config tsan <target>
# Toolchain selection below is intentionally left disabled.
# build:tsan --crosstool_top //tools/lrte:toolchain
# build:tsan --compiler clang
# Binaries are left unstripped so sanitizer reports can be symbolized.
build:tsan --strip=never
build:tsan --copt -fsanitize=thread
build:tsan --copt -DTHREAD_SANITIZER
# NOTE(review): presumably consumed by dynamic-annotations headers in
# dependencies -- confirm which libraries actually read these macros.
build:tsan --copt -DDYNAMIC_ANNOTATIONS_ENABLED=1
build:tsan --copt -DDYNAMIC_ANNOTATIONS_EXTERNAL_IMPL=1
build:tsan --copt -O1
build:tsan --copt -fno-omit-frame-pointer
# The sanitizer flag has to be passed at link time as well.
build:tsan --linkopt -fsanitize=thread

# --config msan: Memory sanitizer
# Usage: bazel build --config msan <target>
# build:msan --crosstool_top //tools/lrte:toolchain
# build:msan --compiler clang
# Binaries are left unstripped so sanitizer reports can be symbolized.
build:msan --strip=never
build:msan --copt -fsanitize=memory
# Define the MSan macro, not the ASan one: code guarded by ADDRESS_SANITIZER
# expects the ASan runtime, which is not linked in this configuration.
build:msan --copt -DMEMORY_SANITIZER
build:msan --copt -O1
build:msan --copt -fno-omit-frame-pointer
# The sanitizer flag has to be passed at link time as well.
build:msan --linkopt -fsanitize=memory

# --config ubsan: Undefined Behavior Sanitizer
# Usage: bazel build --config ubsan <target>
# build:ubsan --crosstool_top //tools/lrte:toolchain
# build:ubsan --compiler clang
# Binaries are left unstripped so sanitizer reports can be symbolized.
build:ubsan --strip=never
build:ubsan --copt -fsanitize=undefined
build:ubsan --copt -O1
build:ubsan --copt -fno-omit-frame-pointer
build:ubsan --linkopt -fsanitize=undefined
# NOTE(review): explicit runtime library in addition to -fsanitize=undefined;
# confirm it is required (and available) for the toolchain in use.
build:ubsan --linkopt -lubsan
101 changes: 48 additions & 53 deletions tensorflow/cc/examples/mcts.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,14 @@ class SnakeMctsAdapter {
}

double value() const {
if (move_queued_) {
return 0;
}

switch (state_.game_state()) {
case GameState::P1_WIN:
return 1;
case GameState::P2_WIN:
return -1;
default:
case GameState::DRAW:
return 0;
default:
}
}

Expand All @@ -62,7 +59,19 @@ class SnakeMctsAdapter {
return state_.p1_view().valid_move(d);
}

bool is_terminal() const { return state_.is_terminal(); }
// A position is terminal when the underlying board reports it as such, or
// when valid_actions() has no bit set.
bool is_terminal() const {
  if (state_.is_terminal()) {
    return true;
  }
  return valid_actions().count() == 0;
}

// Dumps the board via state_.print(), then reports on stdout whether a move
// is currently queued for player 1 (and which one).
void print() const {
  state_.print();
  if (!move_queued_) {
    std::cout << "no move queued for p1" << std::endl;
    return;
  }
  std::cout << "p1 queue move: " << Direction_Name(p1_queued_move_)
            << std::endl;
}

private:
SnakeBoard16 state_;
Expand All @@ -72,15 +81,18 @@ class SnakeMctsAdapter {

class Node {
public:
explicit Node(SnakeMctsAdapter state, Node* parent)
explicit Node(SnakeMctsAdapter state) : Node(state, Direction::UP, nullptr) {}

explicit Node(SnakeMctsAdapter state, Direction action, Node* parent)
: state_(state),
parent_(parent),
action_(action),
is_terminal_(state_.is_terminal()),
valid_actions_(state_.valid_actions()) {}

SnakeMctsAdapter state_;
Node* parent_;
Direction action_;
const SnakeMctsAdapter state_;
Node* const parent_;
const Direction action_;
bool is_terminal_;
bool is_fully_expanded() const {
return num_children_expanded == valid_actions_.count();
Expand All @@ -99,21 +111,29 @@ class Node {
}
auto clone_state = state_;
clone_state.execute(static_cast<Direction>(i));
children_[i] = new Node(clone_state, this);
children_[i] = new Node(clone_state, static_cast<Direction>(i), this);
++num_children_expanded;
return children_[i];
}
state_.print();
CHECK(false) << "should never happen " << num_children_expanded;
return nullptr;
}
const std::bitset<Direction_ARRAYSIZE> valid_actions_;
int num_children_expanded = 0;
std::array<Node*, Direction_ARRAYSIZE> children_;
std::array<Node*, Direction_ARRAYSIZE> children_ = {};
// Destroys every child subtree reachable through children_. Deleting a null
// pointer is a no-op, so empty slots need no explicit check.
~Node() {
  for (Node* child : children_) {
    delete child;
  }
}
};

double random_rollout(const SnakeMctsAdapter& state) {
auto s = state;
std::random_device
rd; // Will be used to obtain a seed for the random number engine
std::random_device rd;
std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with rd()
std::uniform_int_distribution<> distrib(Direction_MIN, Direction_MAX);

Expand All @@ -123,6 +143,7 @@ double random_rollout(const SnakeMctsAdapter& state) {
s.execute(candidate_d);
}
}
LOG(ERROR) << s.value();
return s.value();
}

Expand All @@ -131,13 +152,13 @@ class Mcts {
Mcts(std::function<double(const SnakeMctsAdapter&)> rollout)
: rollout_(rollout) {}
Direction Search(SnakeMctsAdapter state) {
root_ = new Node(state, nullptr);
root_ = std::make_unique<Node>(state);

for (int i = 0; i < 100; ++i) {
for (int i = 0; i < 10000; ++i) {
execute_round();
}

auto* best_child = get_best_child(root_, 0.0);
auto* best_child = get_best_child(root_.get(), 0.0);
return best_child->action_;
}

Expand All @@ -150,7 +171,7 @@ class Mcts {
}

void execute_round() {
Node* node = select_node(root_);
Node* node = select_node(root_.get());
double reward = rollout_(node->state_);
backpropogate(node, reward);
}
Expand All @@ -160,57 +181,31 @@ class Mcts {
if (node->is_fully_expanded()) {
node = get_best_child(node, exploration_constant_);
} else {
return expand(node);
return node->expand();
}
}
return node;
}

Node* expand(Node* node) {
return nullptr;
/* actions = node.state.getPossibleActions()
for action in actions:
if action not in node.children:
newNode = treeNode(node.state.takeAction(action), node)
node.children[action] = newNode
if len(actions) == len(node.children):
node.isFullyExpanded = True
return newNode
raise Exception("Should never reach here") */
}

Node* get_best_child(Node* node, double exploration_value) {
Node* get_best_child(Node* node, double exploration_value) const {
return *std::max_element(node->children_.begin(), node->children_.end(),
[exploration_value](Node* a, Node* b) {
if (a == nullptr && b == nullptr) {
return false;
}
if (a == nullptr || b == nullptr) {
return a == nullptr;
}
return a->ucb(exploration_value) <
b->ucb(exploration_value);
});
// bestValue = float("-inf")
/* bestNodes = []
for child in node.children.values():
nodeValue = node.state.getCurrentPlayer() * child.totalReward /
child.numVisits + explorationValue * math.sqrt( 2 * math.log(node.numVisits) /
child.numVisits)
if nodeValue > bestValue: bestValue = nodeValue bestNodes =
[child] elif nodeValue == bestValue: bestNodes.append(child) return
random.choice(bestNodes) */
}

private:
float exploration_constant_ = 2.0;
std::function<double(const SnakeMctsAdapter&)> rollout_;
Node* root_;
std::unique_ptr<Node> root_;
};
/*
def randomPolicy(state):
while not state.isTerminal():
try:
action = random.choice(state.getPossibleActions())
except IndexError:
raise Exception("Non-terminal state has no possible actions: " +
str(state)) state = state.takeAction(action) return state.getReward()
*/

} // namespace snake

Expand Down
6 changes: 5 additions & 1 deletion tensorflow/cc/examples/mcts_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ TEST(Mcts, MctsTest) {
SnakeMctsAdapter adapter(state);
Mcts mcts(random_rollout);

mcts.Search(adapter);
auto d = mcts.Search(adapter);
LOG(ERROR) << Direction_Name(d);

std::vector<float> a = {-1.0, -0.89};
LOG(ERROR) << *std::max_element(a.begin(), a.end());
}

} // namespace snake
8 changes: 8 additions & 0 deletions tensorflow/cc/examples/snake.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,14 @@ class SnakeBoard {
// True when the cell obtained from player.peek(d) is reported as unoccupied
// by the board.
bool valid_move(Direction d) const {
  const bool target_is_free = board.is_unoccupied(player.peek(d));
  return target_is_free;
}
bool has_move() const {
for (int i = Direction_MIN; i < Direction_ARRAYSIZE; ++i) {
if (board.is_unoccupied(player.peek(static_cast<Direction>(i)))) {
return true;
}
}
return false;
}
};

PlayerView p1_view() const { return {p1_, p2_, *this}; }
Expand Down

0 comments on commit 648dad9

Please sign in to comment.