forked from dmlc/ps-lite
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_kv_app_multi_workers.cc
80 lines (71 loc) · 1.78 KB
/
test_kv_app_multi_workers.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <thread>
#include <vector>

#include "ps/ps.h"
using namespace ps;
void StartServer() {
if (!IsServer()) return;
auto server = new KVServer<float>(0);
server->set_request_handle(KVServerDefaultHandle<float>());
RegisterExitCallback([server](){ delete server; });
}
/**
 * Worker-side body of the test, executed once per customer (thread).
 *
 * Steps:
 *  1. Push a fixed pseudo-random value vector `repeat` times, keeping only a
 *     bounded number of pushes in flight.
 *  2. Pull the keys back.
 *  3. Issue `repeat` PushPull calls.
 *  4. Check the values:
 *     - the pulled values must equal `repeat * vals`;
 *     - the last PushPull result must equal `2 * repeat * vals`.
 *
 * NOTE(review): these expectations presume the server sums pushed values and
 * that no other worker node pushes to the same keys — confirm against the
 * launch script.
 *
 * @param customer_id ps-lite customer id for this thread. It is also added to
 *        every key, so different customers use disjoint key sets.
 */
void RunWorker(int customer_id) {
  Start(customer_id);
  if (!IsWorker()) {
    return;
  }
  KVWorker<float> kv(0, customer_id);

  // init: `num` keys spaced evenly over the key space, offset by customer_id.
  const int num = 10000;
  std::vector<Key> keys(num);
  std::vector<float> vals(num);
  const int rank = MyRank();
  srand(rank + 7);  // seed depends only on rank, so values are reproducible
  for (int i = 0; i < num; ++i) {
    keys[i] = kMaxKey / num * i + customer_id;
    vals[i] = (rand() % 1000);
  }

  // push
  const int repeat = 50;
  // Throttle window: once more than this many pushes have been issued, wait
  // on the one issued `kMaxInFlight` calls ago. Pushing too frequently without
  // waiting leads to huge memory usage.
  const int kMaxInFlight = 10;
  std::vector<int> ts;
  for (int i = 0; i < repeat; ++i) {
    ts.push_back(kv.Push(keys, vals));
    if (i > kMaxInFlight) kv.Wait(ts[ts.size() - kMaxInFlight]);
  }
  for (int t : ts) kv.Wait(t);

  // pull
  std::vector<float> rets;
  kv.Wait(kv.Pull(keys, &rets));

  // pushpull; `outs` holds the result of the last call.
  std::vector<float> outs;
  for (int i = 0; i < repeat; ++i) {
    kv.Wait(kv.PushPull(keys, vals, &outs));
  }

  float res = 0;   // accumulated |pulled - expected|
  float res2 = 0;  // accumulated |pushpulled - expected|
  for (int i = 0; i < num; ++i) {
    res += std::fabs(rets[i] - vals[i] * repeat);
    // Must accumulate into res2 (not res). Otherwise the PushPull check below
    // is vacuous and the logged PushPull error is always 0.
    res2 += std::fabs(outs[i] - vals[i] * 2 * repeat);
  }
  CHECK_LT(res / repeat, 1e-5);
  CHECK_LT(res2 / (2 * repeat), 1e-5);
  LL << "error: " << res / repeat << ", " << res2 / (2 * repeat);

  // stop system
  Finalize(customer_id, true);
}
int main(int argc, char *argv[]) {
// start system
bool isWorker = (strcmp(argv[1], "worker") == 0);
if (!isWorker) {
Start(0);
// setup server nodes
StartServer();
Finalize(0, true);
return 0;
}
// run worker nodes
std::thread t0(RunWorker, 0);
std::thread t1(RunWorker, 1);
t0.join();
t1.join();
return 0;
}