-
Notifications
You must be signed in to change notification settings - Fork 31
/
ecfr.h
105 lines (97 loc) · 2.98 KB
/
ecfr.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#ifndef _ECFR_H_
#define _ECFR_H_
#include <pthread.h>  // pthread_t
#include <stdlib.h>   // struct drand48_data (SVID/glibc reentrant rand48 state)

#include <memory>
class ECFRNode;
class ECFRThread;
class BettingAbstraction;
class Buckets;
class CardAbstraction;
class CFRConfig;
class Node;
class Reader;
class Writer;

// A node of the tree that ECFR works on. Each node stores a few scalar
// properties (terminal / showdown flags, street, acting player, number of
// successors, last bet-to amount) and owns:
//   * an array of successor nodes (succs_),
//   * a regrets array (double) and a sumprobs array (int).
//
// Every scalar member has an in-class default initializer, so a
// default-constructed node is in a well-defined state: not terminal, not a
// showdown, all counts zero, no successors and no allocated arrays. The
// previous `ECFRNode(void) {}` left these values indeterminate.
class ECFRNode {
 public:
  ECFRNode() = default;
  // Builds a node from a betting-tree Node; defined out of line.
  ECFRNode(Node *node, const Buckets &buckets);
  ~ECFRNode() = default;

  bool Terminal() const { return terminal_; }
  bool Showdown() const { return showdown_; }
  int Street() const { return st_; }
  int PlayerActing() const { return player_acting_; }
  int NumSuccs() const { return num_succs_; }
  int LastBetTo() const { return last_bet_to_; }

  // Returns the i-th successor without bounds checking. Presumably
  // 0 <= i < NumSuccs(); the caller must ensure succs_ was allocated.
  ECFRNode *IthSucc(int i) const { return succs_[i].get(); }

  // Raw access to the owned arrays; nullptr if they were never allocated
  // (e.g. on a default-constructed node).
  double *Regrets() { return regrets_.get(); }
  int *Sumprobs() { return sumprobs_.get(); }

 private:
  bool terminal_ = false;
  bool showdown_ = false;
  int st_ = 0;
  int player_acting_ = 0;
  int num_succs_ = 0;
  int last_bet_to_ = 0;
  std::unique_ptr<std::unique_ptr<ECFRNode> []> succs_;
  std::unique_ptr<double []> regrets_;
  std::unique_ptr<int []> sumprobs_;
};
// Worker object for one ECFR thread. The constructor, Run, RunThread, Join,
// Process and Deal are all defined out of line.
//
// Ownership: `buckets_` is a reference, and `root_`, `board_table_` and
// `total_its_` are raw pointers. None of them are owned, so the objects they
// refer to must outlive this ECFRThread. The unique_ptr<int[]> members are
// owned scratch arrays.
//
// `rand_buf_` (struct drand48_data) and `pthread_id_` (pthread_t) are stored
// by value, so <stdlib.h> and <pthread.h> must be visible before this
// declaration.
//
// `it_`, `p_`, `p1_outcome_` and `pthread_id_` now carry default
// initializers. Previously they were indeterminate until the out-of-line
// code assigned them.
class ECFRThread {
 public:
  ECFRThread(const CFRConfig &cfr_config, const Buckets &buckets, ECFRNode *root, int seed,
             int batch_size, const int *board_table, int num_raw_boards,
             unsigned long long int *total_its, int thread_index, int num_threads);
  ~ECFRThread() = default;

  void Run();
  void RunThread();
  void Join();

 private:
  double Process(ECFRNode *node);
  void Deal();

  const Buckets &buckets_;
  ECFRNode *root_;
  int batch_size_;
  const int *board_table_;
  int num_raw_boards_;
  // NOTE(review): points at a counter owned elsewhere. If several threads
  // update it, confirm the .cpp synchronizes access.
  unsigned long long int *total_its_;
  int thread_index_;
  int num_threads_;
  std::unique_ptr<int []> canon_bds_;
  std::unique_ptr<int []> hole_cards_;
  std::unique_ptr<int []> hi_cards_;
  std::unique_ptr<int []> lo_cards_;
  std::unique_ptr<int []> hvs_;
  std::unique_ptr<int []> hand_buckets_;
  struct drand48_data rand_buf_;
  int it_ = 0;
  int p_ = 0;
  int p1_outcome_ = 0;
  pthread_t pthread_id_{};
};
class ECFR {
public:
ECFR(const CardAbstraction &ca, const BettingAbstraction &ba, const CFRConfig &cc,
const Buckets &buckets, int num_threads);
~ECFR(void) {}
void Run(int start_batch_index, int end_batch_index, int batch_size, int save_interval);
private:
void ReadRegrets(ECFRNode *node, std::unique_ptr<Reader> *readers);
void ReadSumprobs(ECFRNode *node, std::unique_ptr<Reader> *readers);
void Read(int batch_index);
void WriteRegrets(ECFRNode *node, std::unique_ptr<Writer> *writers);
void WriteSumprobs(ECFRNode *node, std::unique_ptr<Writer> *writers);
void Write(int batch_index);
void Run(void);
void RunBatch(int batch_index, int batch_size);
const CardAbstraction &card_abstraction_;
const BettingAbstraction &betting_abstraction_;
const CFRConfig &cfr_config_;
const Buckets &buckets_;
std::unique_ptr<ECFRNode> root_;
std::unique_ptr<int []> board_table_;
int num_raw_boards_;
struct drand48_data rand_buf_;
int num_cfr_threads_;
std::unique_ptr<std::unique_ptr<ECFRThread> []> cfr_threads_;
unsigned long long int total_its_;
};
#endif