Skip to content

Commit

Permalink
PushPull API (dmlc#150)
Browse files Browse the repository at this point in the history
* Add PushPull API call

* Add test cases for PushPull

* Lint checker error fixes

* FIX: Merge conflict
  • Loading branch information
anandj91 authored and eric-haibin-lin committed Sep 1, 2019
1 parent 2c8ed25 commit 8e8545e
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 17 deletions.
4 changes: 3 additions & 1 deletion include/ps/internal/message.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ struct Meta {
/** \brief default constructor */
Meta() : head(kEmpty), app_id(kEmpty), customer_id(kEmpty),
timestamp(kEmpty), sender(kEmpty), recver(kEmpty),
request(false), push(false), simple_app(false) {}
request(false), push(false), pull(false), simple_app(false) {}
std::string DebugString() const {
std::stringstream ss;
if (sender == Node::kEmpty) {
Expand Down Expand Up @@ -183,6 +183,8 @@ struct Meta {
bool request;
/** \brief whether or not a push message */
bool push;
/** \brief whether or not a pull message */
bool pull;
/** \brief whether or not it's for SimpleApp */
bool simple_app;
/** \brief an string body */
Expand Down
121 changes: 107 additions & 14 deletions include/ps/kv_app.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class KVWorker : public SimpleApp {
* KVWorker<float> w;
* std::vector<Key> keys = {1, 3};
* std::vector<float> vals;
* ps.Pull(keys, &vals);
* w.Pull(keys, &vals);
* \endcode
*
* It's a non-blocking call. The actual pulling is finished,
Expand All @@ -149,7 +149,66 @@ class KVWorker : public SimpleApp {
std::vector<int>* lens = nullptr,
int cmd = 0,
const Callback& cb = nullptr) {
return Pull_(SArray<Key>(keys), vals, lens, cmd, cb);
SArray<Key> skeys(keys);
int ts = AddPullCB(skeys, vals, lens, cmd, cb);
KVPairs<Val> kvs;
kvs.keys = skeys;
Send(ts, false, true, cmd, kvs);
return ts;
}

/**
* \brief Pushes and Pulls a list of key-value pairs to and from the server
* nodes.
*
* This function pushes the values of the keys specified in \a keys to the
* server nodes and subsequently pulls and updates the values in \a vals.
*
* Sample usage: the following code pushes and pulls the values of keys
* \a 1 and \a 3 to and from the server nodes.
* \code
* KVWorker<float> w;
* std::vector<Key> keys = {1, 3};
* std::vector<float> vals;
* w.PushPull(keys, &vals);
* \endcode
*
* It's a non-blocking call. The actual pulling is finished,
* namely \a vals (and \a lens) is filled with pulled values, only
* if \ref Wait returns or the callback is called.
*
* @param keys a list of keys, must be unique and sorted in increasing order
* @param vals the according values
* @param outs the buffer for the pulled values. It can be 0 size.
* @param lens optional buffer for the value length. If set, it can be 0 size.
* @param cmd an optional command sent to the servers
* @param cb the callback which is called when the pull is finished.
* @return the timestamp of this request
*/
int PushPull(const std::vector<Key>& keys,
const std::vector<Val>& vals,
std::vector<Val>* outs,
std::vector<int>* lens = nullptr,
int cmd = 0,
const Callback& cb = nullptr) {
CHECK_NOTNULL(outs);
if (outs->empty())
outs->resize(vals.size());
else
CHECK_EQ(vals.size(), outs->size());

SArray<Key> skeys(keys);
SArray<Val> svals(vals);
auto souts = new SArray<Val>(outs->data(), outs->size());
SArray<int>* slens = lens ?
new SArray<int>(lens->data(), lens->size()) : nullptr;
int ts = ZPushPull(skeys, svals, souts, slens, cmd,
[this, cb, souts, slens]() {
delete souts;
delete slens;
if (cb) cb();
});
return ts;
}

/**
Expand Down Expand Up @@ -185,7 +244,7 @@ class KVWorker : public SimpleApp {
kvs.keys = keys;
kvs.vals = vals;
kvs.lens = lens;
Send(ts, true, cmd, kvs);
Send(ts, true, false, cmd, kvs);
return ts;
}

Expand All @@ -202,7 +261,35 @@ class KVWorker : public SimpleApp {
SArray<int>* lens = nullptr,
int cmd = 0,
const Callback& cb = nullptr) {
return Pull_(keys, vals, lens, cmd, cb);
int ts = AddPullCB(keys, vals, lens, cmd, cb);
KVPairs<Val> kvs;
kvs.keys = keys;
Send(ts, false, true, cmd, kvs);
return ts;
}

/**
* \brief zero-copy PushPull
*
* This function is similar to \ref PushPull except that all data
* will not be copied into system for better performance. It is the caller's
* responsibility to keep the content to be not changed before actually
* finished.
*/
int ZPushPull(const SArray<Key>& keys,
const SArray<Val>& vals,
SArray<Val>* outs,
SArray<int>* lens = nullptr,
int cmd = 0,
const Callback& cb = nullptr) {
int ts = AddPullCB(keys, outs, lens, cmd, cb);
KVPairs<Val> kvs;
kvs.keys = keys;
kvs.vals = vals;
if (lens)
kvs.lens = *lens;
Send(ts, true, true, cmd, kvs);
return ts;
}
using SlicedKVs = std::vector<std::pair<bool, KVPairs<Val>>>;
/**
Expand All @@ -228,7 +315,7 @@ class KVWorker : public SimpleApp {
* \brief internal pull, C/D can be either SArray or std::vector
*/
template <typename C, typename D>
int Pull_(const SArray<Key>& keys, C* vals, D* lens,
int AddPullCB(const SArray<Key>& keys, C* vals, D* lens,
int cmd, const Callback& cb);
/**
* \brief add a callback for a request. threadsafe.
Expand All @@ -250,9 +337,10 @@ class KVWorker : public SimpleApp {
* \brief send the kv list to all servers
* @param timestamp the timestamp of the request
* @param push whether or not it is a push request
* @param push whether or not it is a pull request
* @param cmd command
*/
void Send(int timestamp, bool push, int cmd, const KVPairs<Val>& kvs);
void Send(int timestamp, bool push, bool pull, int cmd, const KVPairs<Val>& kvs);
/** \brief internal receive handle */
void Process(const Message& msg);
/** \brief default kv slicer */
Expand All @@ -276,6 +364,8 @@ struct KVMeta {
int cmd;
/** \brief whether or not this is a push request */
bool push;
/** \brief whether or not this is a pull request */
bool pull;
/** \brief sender's node id */
int sender;
/** \brief the associated timestamp */
Expand Down Expand Up @@ -340,7 +430,7 @@ struct KVServerDefaultHandle {
const KVMeta& req_meta, const KVPairs<Val>& req_data, KVServer<Val>* server) {
size_t n = req_data.keys.size();
KVPairs<Val> res;
if (req_meta.push) {
if (!req_meta.pull) {
CHECK_EQ(n, req_data.vals.size());
} else {
res.keys = req_data.keys; res.vals.resize(n);
Expand All @@ -349,7 +439,8 @@ struct KVServerDefaultHandle {
Key key = req_data.keys[i];
if (req_meta.push) {
store[key] += req_data.vals[i];
} else {
}
if (req_meta.pull) {
res.vals[i] = store[key];
}
}
Expand All @@ -369,6 +460,7 @@ void KVServer<Val>::Process(const Message& msg) {
KVMeta meta;
meta.cmd = msg.meta.head;
meta.push = msg.meta.push;
meta.pull = msg.meta.pull;
meta.sender = msg.meta.sender;
meta.timestamp = msg.meta.timestamp;
meta.customer_id = msg.meta.customer_id;
Expand All @@ -395,6 +487,7 @@ void KVServer<Val>::Response(const KVMeta& req, const KVPairs<Val>& res) {
msg.meta.customer_id = req.customer_id;
msg.meta.request = false;
msg.meta.push = req.push;
msg.meta.pull = req.pull;
msg.meta.head = req.cmd;
msg.meta.timestamp = req.timestamp;
msg.meta.recver = req.sender;
Expand Down Expand Up @@ -466,7 +559,7 @@ void KVWorker<Val>::DefaultSlicer(
}

template <typename Val>
void KVWorker<Val>::Send(int timestamp, bool push, int cmd, const KVPairs<Val>& kvs) {
void KVWorker<Val>::Send(int timestamp, bool push, bool pull, int cmd, const KVPairs<Val>& kvs) {
// slice the message
SlicedKVs sliced;
slicer_(kvs, Postoffice::Get()->GetServerKeyRanges(), &sliced);
Expand All @@ -489,6 +582,7 @@ void KVWorker<Val>::Send(int timestamp, bool push, int cmd, const KVPairs<Val>&
msg.meta.customer_id = obj_->customer_id();
msg.meta.request = true;
msg.meta.push = push;
msg.meta.pull = pull;
msg.meta.head = cmd;
msg.meta.timestamp = timestamp;
msg.meta.recver = Postoffice::Get()->ServerRankToID(i);
Expand All @@ -512,7 +606,7 @@ void KVWorker<Val>::Process(const Message& msg) {
}
// store the data for pulling
int ts = msg.meta.timestamp;
if (!msg.meta.push && msg.data.size()) {
if (msg.meta.pull) {
CHECK_GE(msg.data.size(), (size_t)2);
KVPairs<Val> kvs;
kvs.keys = msg.data[0];
Expand Down Expand Up @@ -548,8 +642,9 @@ void KVWorker<Val>::RunCallback(int timestamp) {

template <typename Val>
template <typename C, typename D>
int KVWorker<Val>::Pull_(
const SArray<Key>& keys, C* vals, D* lens, int cmd, const Callback& cb) {
int KVWorker<Val>::AddPullCB(
const SArray<Key>& keys, C* vals, D* lens, int cmd,
const Callback& cb) {
int ts = obj_->NewRequest(kServerGroup);
AddCallback(ts, [this, ts, keys, vals, lens, cb]() mutable {
mu_.lock();
Expand Down Expand Up @@ -604,8 +699,6 @@ int KVWorker<Val>::Pull_(
if (cb) cb();
});

KVPairs<Val> kvs; kvs.keys = keys;
Send(ts, false, cmd, kvs);
return ts;
}

Expand Down
2 changes: 2 additions & 0 deletions src/meta.proto
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ message PBMeta {
optional int32 customer_id = 10;
// whether or not a push message
optional bool push = 5;
// whether or not a pull message
optional bool pull = 12;
// whether or not it's for SimpleApp
optional bool simple_app = 6 [default = false];
// message.data_size
Expand Down
2 changes: 2 additions & 0 deletions src/van.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ void Van::PackMeta(const Meta& meta, char** meta_buf, int* buf_size) {
if (meta.timestamp != Meta::kEmpty) pb.set_timestamp(meta.timestamp);
if (meta.body.size()) pb.set_body(meta.body);
pb.set_push(meta.push);
pb.set_pull(meta.pull);
pb.set_request(meta.request);
pb.set_simple_app(meta.simple_app);
pb.set_customer_id(meta.customer_id);
Expand Down Expand Up @@ -520,6 +521,7 @@ void Van::UnpackMeta(const char* meta_buf, int buf_size, Meta* meta) {
meta->timestamp = pb.has_timestamp() ? pb.timestamp() : Meta::kEmpty;
meta->request = pb.request();
meta->push = pb.push();
meta->pull = pb.pull();
meta->simple_app = pb.simple_app();
meta->body = pb.body();
meta->customer_id = pb.customer_id();
Expand Down
12 changes: 11 additions & 1 deletion tests/test_kv_app.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,22 @@ void RunWorker() {
std::vector<float> rets;
kv.Wait(kv.Pull(keys, &rets));

// pushpull
std::vector<float> outs;
for (int i = 0; i < repeat; ++i) {
// PushPull on the same keys should be called serially
kv.Wait(kv.PushPull(keys, vals, &outs));
}

float res = 0;
float res2 = 0;
for (int i = 0; i < num; ++i) {
res += std::fabs(rets[i] - vals[i] * repeat);
res2 += std::fabs(outs[i] - vals[i] * 2 * repeat);
}
CHECK_LT(res / repeat, 1e-5);
LL << "error: " << res / repeat;
CHECK_LT(res2 / (2 * repeat), 1e-5);
LL << "error: " << res / repeat << ", " << res2 / (2 * repeat);
}

int main(int argc, char *argv[]) {
Expand Down
11 changes: 10 additions & 1 deletion tests/test_kv_app_multi_workers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,21 @@ void RunWorker(int customer_id) {
std::vector<float> rets;
kv.Wait(kv.Pull(keys, &rets));

// pushpull
std::vector<float> outs;
for (int i = 0; i < repeat; ++i) {
kv.Wait(kv.PushPull(keys, vals, &outs));
}

float res = 0;
float res2 = 0;
for (int i = 0; i < num; ++i) {
res += fabs(rets[i] - vals[i] * repeat);
res += fabs(outs[i] - vals[i] * 2 * repeat);
}
CHECK_LT(res / repeat, 1e-5);
LL << "error: " << res / repeat;
CHECK_LT(res2 / (2 * repeat), 1e-5);
LL << "error: " << res / repeat << ", " << res2 / (2 * repeat);
// stop system
Finalize(customer_id, true);
}
Expand Down

0 comments on commit 8e8545e

Please sign in to comment.