Skip to content

Commit

Permalink
Use a naive memory allocator
Browse files Browse the repository at this point in the history
  • Loading branch information
changlan authored and eric-haibin-lin committed Aug 18, 2019
1 parent c5061ae commit dcfb522
Showing 1 changed file with 110 additions and 107 deletions.
217 changes: 110 additions & 107 deletions src/rdma_van.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <rdma/rdma_cma.h>

#include <algorithm>
#include <map>
#include <queue>
#include <set>
#include <string>
Expand Down Expand Up @@ -50,60 +51,69 @@ static inline T align_ceil(T v, T align) {
return align_floor(v + align - 1, align);
}

class MRPool {
class SimpleMempool {
public:
explicit MRPool(const size_t s, struct ibv_pd *pd) : chunk_size(s), pd(pd) { CHECK(pd); }
~MRPool() { purge_memory(); }

bool purge_memory() {
bool ret = !(free_list.empty() && used_list.empty());
for (std::set<struct ibv_mr *>::iterator pos = free_list.begin(); pos != free_list.end();
++pos) {
struct ibv_mr *mr = *pos;
std::free(mr->addr);
CHECK_EQ(ibv_dereg_mr(mr), 0);
explicit SimpleMempool(struct ibv_pd *pd, size_t size = 0x80000000) {
char *p = reinterpret_cast<char *>(malloc(size));
CHECK(p);
CHECK(mr = ibv_reg_mr(pd, p, size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE));
free_list.emplace(size, p);
}

~SimpleMempool() {
free(mr->addr);
CHECK_EQ(ibv_dereg_mr(mr), 0);
}

char *Alloc(size_t size) {
std::lock_guard<std::mutex> lk(mu_);

if (size == 0) {
return nullptr;
}
free_list.clear();
for (std::set<struct ibv_mr *>::iterator pos = used_list.begin(); pos != used_list.end();
++pos) {
struct ibv_mr *mr = *pos;
std::free(mr->addr);
CHECK(ibv_dereg_mr(mr));

auto it = free_list.lower_bound(size);
CHECK_NE(free_list.end(), it);

char *ret = it->second;

CHECK_GE(it->first, size);

size_t space_left = it->first - size;

used_list.emplace(ret, size);
free_list.erase(it);

if (space_left) {
free_list.emplace(space_left, ret + size);
}
used_list.clear();

return ret;
}

size_t get_next_size() const { return 1; }
void set_next_size(const size_t) {}
size_t get_requested_size() const { return chunk_size; }

struct ibv_mr *malloc() {
void *addr;
struct ibv_mr *mr;
if (free_list.empty()) {
addr = std::malloc(chunk_size);
mr = ibv_reg_mr(pd, addr, chunk_size, IBV_ACCESS_LOCAL_WRITE);
CHECK(mr) << strerror(errno);
} else {
mr = *free_list.begin();
free_list.erase(free_list.begin());
void Free(char *addr) {
std::lock_guard<std::mutex> lk(mu_);

if (!addr) {
return;
}
used_list.insert(mr);
return mr;
}

void free(struct ibv_mr *const chunk) {
CHECK_EQ(used_list.count(chunk), 1);
CHECK_EQ(free_list.count(chunk), 0);
used_list.erase(chunk);
free_list.insert(chunk);
auto it = used_list.find(addr);
CHECK_NE(used_list.end(), it);

size_t size = it->second;
used_list.erase(it);
free_list.emplace(size, addr);
}

uint32_t LocalKey() const { return mr->lkey; }
uint32_t RemoteKey() const { return mr->rkey; }

private:
size_t chunk_size;
std::set<struct ibv_mr *> free_list, used_list;
struct ibv_pd *const pd;
std::mutex mu_;
std::multimap<size_t, char *> free_list;
std::unordered_map<char *, size_t> used_list;
struct ibv_mr *mr;
};

enum MessageTypes : uint32_t {
Expand Down Expand Up @@ -132,23 +142,23 @@ struct alignas(64) WRContext {
};

struct alignas(64) BufferContext {
struct ibv_mr *mr;
char *buffer;
size_t meta_len;
size_t data_num;
size_t data_len[kMaxDataFields];
};

struct alignas(64) LocalBufferContext {
size_t meta_len;
void *meta_buf;
char *meta_buf;
std::vector<SArray<char>> data;
};

struct alignas(64) MessageBuffer {
size_t meta_len;
void *meta_buf;
std::vector<SArray<char>> data;
std::vector<struct ibv_mr *> regions;
size_t data_len;
char *meta_buf;
char *data_buf;
};

struct RequestContext {
Expand Down Expand Up @@ -329,7 +339,7 @@ class RDMAVan : public Van {
i.second.reset();
}

message_mempool_.reset();
mempool_.reset();

cm_event_polling_thread_->join();
cm_event_polling_thread_.reset();
Expand Down Expand Up @@ -437,7 +447,7 @@ class RDMAVan : public Van {
if (remote_id == my_node_.id) {
LocalBufferContext *buf_ctx = new LocalBufferContext();
buf_ctx->meta_len = meta.ByteSize();
buf_ctx->meta_buf = malloc(buf_ctx->meta_len);
buf_ctx->meta_buf = mempool_->Alloc(buf_ctx->meta_len);
meta.SerializeToArray(buf_ctx->meta_buf, buf_ctx->meta_len);
buf_ctx->data = msg.data;
recv_buffers_.Push(
Expand All @@ -448,18 +458,29 @@ class RDMAVan : public Van {
CHECK_NE(endpoints_.find(remote_id), endpoints_.end());
RDMAEndpoint *endpoint = endpoints_[remote_id].get();

MessageBuffer *msg_buf = new MessageBuffer();
MessageBuffer *msg_buf =
reinterpret_cast<MessageBuffer *>(mempool_->Alloc(sizeof(MessageBuffer)));

CHECK(meta.ByteSize());
msg_buf->meta_len = meta.ByteSize();
msg_buf->meta_buf = malloc(msg_buf->meta_len);
msg_buf->data = msg.data;
msg_buf->data_len = msg.meta.data_size;
msg_buf->meta_buf = mempool_->Alloc(msg_buf->meta_len);
msg_buf->data_buf = mempool_->Alloc(msg_buf->data_len);
meta.SerializeToArray(msg_buf->meta_buf, msg_buf->meta_len);

struct ibv_mr *mr = message_mempool_->malloc();
CHECK(mr);
char *cur = reinterpret_cast<char *>(msg_buf->data_buf);
for (size_t i = 0; i < msg.data.size(); ++i) {
size_t size = msg.data[i].size();
memcpy(cur, msg.data[i].data(), size);
cur += size;
}

char *p = mempool_->Alloc(sizeof(RendezvousStart));
CHECK(p);

RendezvousStart *req = reinterpret_cast<RendezvousStart *>(mr->addr);
RendezvousStart *req = reinterpret_cast<RendezvousStart *>(p);
req->meta_len = meta.ByteSize();

for (size_t i = 0; i < msg.data.size(); ++i) {
req->data_len[i] = msg.data[i].size();
}
Expand All @@ -469,12 +490,12 @@ class RDMAVan : public Van {
struct ibv_sge sge;
sge.addr = reinterpret_cast<uint64_t>(req);
sge.length = sizeof(RendezvousStart);
sge.lkey = mr->lkey;
sge.lkey = mempool_->LocalKey();

struct ibv_send_wr wr, *bad_wr = nullptr;
memset(&wr, 0, sizeof(wr));

wr.wr_id = reinterpret_cast<uint64_t>(mr);
wr.wr_id = reinterpret_cast<uint64_t>(req);
wr.opcode = IBV_WR_SEND_WITH_IMM;
wr.next = nullptr;

Expand All @@ -501,10 +522,10 @@ class RDMAVan : public Van {
reinterpret_cast<LocalBufferContext *>(std::get<BufferContext *>(notification));
msg->meta.recver = my_node_.id;
msg->meta.sender = my_node_.id;
UnpackMeta(reinterpret_cast<char *>(buffer_ctx->meta_buf), buffer_ctx->meta_len, &msg->meta);
UnpackMeta(buffer_ctx->meta_buf, buffer_ctx->meta_len, &msg->meta);
msg->data = buffer_ctx->data;
int total_len = buffer_ctx->meta_len + msg->meta.data_size;
free(buffer_ctx->meta_buf);
mempool_->Free(buffer_ctx->meta_buf);
delete buffer_ctx;
return total_len;
}
Expand All @@ -516,7 +537,7 @@ class RDMAVan : public Van {
msg->meta.recver = my_node_.id;
msg->meta.sender = endpoint->node_id;

const char *cur = reinterpret_cast<char *>(buffer_ctx->mr->addr);
const char *cur = reinterpret_cast<char *>(buffer_ctx->buffer);
UnpackMeta(cur, buffer_ctx->meta_len, &msg->meta);
total_len += buffer_ctx->meta_len;
uint64_t data_num = buffer_ctx->data_num;
Expand All @@ -531,8 +552,8 @@ class RDMAVan : public Van {
total_len += len;
}

CHECK_EQ(ibv_dereg_mr(buffer_ctx->mr), 0);
free(buffer_ctx);
mempool_->Free(buffer_ctx->buffer);
mempool_->Free(reinterpret_cast<char *>(buffer_ctx));

return total_len;
}
Expand All @@ -545,7 +566,7 @@ class RDMAVan : public Van {
pd_ = ibv_alloc_pd(context_);
CHECK(pd_) << "Failed to allocate protection domain";

message_mempool_.reset(new MRPool(kMempoolChunkSize, pd_));
mempool_.reset(new SimpleMempool(pd_));

comp_event_channel_ = ibv_create_comp_channel(context_);
cq_ = ibv_create_cq(context_, kMaxConcurrentWorkRequest * 2, NULL, comp_event_channel_, 0);
Expand All @@ -571,17 +592,15 @@ class RDMAVan : public Van {
switch (wc[i].opcode) {
case IBV_WC_SEND: {
// LOG(INFO) << "opcode: IBV_WC_SEND";
struct ibv_mr *mr = reinterpret_cast<struct ibv_mr *>(wc[i].wr_id);
message_mempool_->free(mr);
char *p = reinterpret_cast<char *>(wc[i].wr_id);
mempool_->Free(p);
} break;
case IBV_WC_RDMA_WRITE: {
// LOG(INFO) << "opcode: IBV_WC_RDMA_WRITE";
MessageBuffer *msg_buf = reinterpret_cast<MessageBuffer *>(wc[i].wr_id);
free(msg_buf->meta_buf);
for (auto &mr : msg_buf->regions) {
CHECK_EQ(ibv_dereg_mr(mr), 0);
}
delete msg_buf;
mempool_->Free(msg_buf->meta_buf);
mempool_->Free(msg_buf->data_buf);
mempool_->Free(reinterpret_cast<char *>(msg_buf));
} break;
case IBV_WC_RECV_RDMA_WITH_IMM: {
// LOG(INFO) << "opcode: IBV_WC_RECV_RDMA_WITH_IMM";
Expand All @@ -606,7 +625,7 @@ class RDMAVan : public Van {
RendezvousStart *req = reinterpret_cast<RendezvousStart *>(mr->addr);

BufferContext *buf_ctx =
reinterpret_cast<BufferContext *>(malloc(sizeof(BufferContext)));
reinterpret_cast<BufferContext *>(mempool_->Alloc(sizeof(BufferContext)));

uint64_t len = req->meta_len;
buf_ctx->meta_len = len;
Expand All @@ -616,32 +635,31 @@ class RDMAVan : public Van {
len += req->data_len[i];
}

void *buffer = malloc(len);
CHECK(buffer) << "malloc for " << len << " bytes, data_num: " << req->data_num;
struct ibv_mr *rx_mr =
ibv_reg_mr(pd_, buffer, len, IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE);
CHECK(rx_mr);
buf_ctx->mr = rx_mr;
char *buffer = mempool_->Alloc(len);
CHECK(buffer) << "Alloc for " << len << " bytes, data_num: " << req->data_num;

buf_ctx->buffer = buffer;

uint64_t origin_addr = req->origin_addr;

struct ibv_mr *mr = message_mempool_->malloc();
CHECK(mr);
RendezvousReply *resp = reinterpret_cast<RendezvousReply *>(mr->addr);
char *p = mempool_->Alloc(sizeof(RendezvousReply));
CHECK(p);
RendezvousReply *resp = reinterpret_cast<RendezvousReply *>(p);

resp->addr = reinterpret_cast<uint64_t>(buffer);
resp->rkey = rx_mr->rkey;
resp->rkey = mempool_->RemoteKey();
resp->origin_addr = origin_addr;
resp->idx = addr_pool_.StoreAddress(buf_ctx);

struct ibv_sge sge;
sge.addr = reinterpret_cast<uint64_t>(resp);
sge.length = sizeof(RendezvousReply);
sge.lkey = mr->lkey;
sge.lkey = mempool_->LocalKey();

struct ibv_send_wr wr, *bad_wr = nullptr;
memset(&wr, 0, sizeof(wr));

wr.wr_id = reinterpret_cast<uint64_t>(mr);
wr.wr_id = reinterpret_cast<uint64_t>(resp);
wr.opcode = IBV_WR_SEND_WITH_IMM;
wr.next = nullptr;

Expand All @@ -664,30 +682,15 @@ class RDMAVan : public Van {

MessageBuffer *msg_buf = reinterpret_cast<MessageBuffer *>(origin_addr);

struct ibv_sge sge[msg_buf->data.size() + 1];

struct ibv_mr *meta_mr = ibv_reg_mr(pd_, msg_buf->meta_buf, msg_buf->meta_len, 0);
CHECK(meta_mr) << strerror(errno);
struct ibv_sge sge[2];

sge[0].addr = reinterpret_cast<uint64_t>(msg_buf->meta_buf);
sge[0].length = msg_buf->meta_len;
sge[0].lkey = meta_mr->lkey;

msg_buf->regions.push_back(meta_mr);

size_t num_sge = 1;
for (size_t i = 0; i < msg_buf->data.size(); ++i) {
void *p = msg_buf->data[i].ptr().get();
size_t len = msg_buf->data[i].size();
if (len > 0) {
struct ibv_mr *mr = ibv_reg_mr(pd_, p, len, 0);
sge[num_sge].addr = reinterpret_cast<uint64_t>(p);
sge[num_sge].length = len;
sge[num_sge].lkey = mr->lkey;
msg_buf->regions.push_back(mr);
++num_sge;
}
}
sge[0].lkey = mempool_->LocalKey();

sge[1].addr = reinterpret_cast<uint64_t>(msg_buf->data_buf);
sge[1].length = msg_buf->data_len;
sge[1].lkey = mempool_->LocalKey();

struct ibv_send_wr wr, *bad_wr = nullptr;
memset(&wr, 0, sizeof(wr));
Expand All @@ -700,7 +703,7 @@ class RDMAVan : public Van {

wr.send_flags = IBV_SEND_SIGNALED;
wr.sg_list = sge;
wr.num_sge = num_sge;
wr.num_sge = 1 + (msg_buf->data_len > 0 ? 1 : 0);

wr.wr.rdma.remote_addr = remote_addr;
wr.wr.rdma.rkey = rkey;
Expand Down Expand Up @@ -918,7 +921,7 @@ class RDMAVan : public Van {
}

AddressPool<BufferContext> addr_pool_;
std::unique_ptr<MRPool> message_mempool_;
std::unique_ptr<SimpleMempool> mempool_;

struct rdma_cm_id *listener_ = nullptr;
std::atomic<bool> should_stop_;
Expand Down

0 comments on commit dcfb522

Please sign in to comment.