Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rdma connection improvement(best result ever) #107

Merged
merged 6 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions infinistore/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ def parse_args():
help="IB or Ethernet, default IB",
type=str,
)
parser.add_argument(
"--steps",
required=False,
type=int,
default=32,
help="number of steps, default 32",
)
return parser.parse_args()


Expand Down Expand Up @@ -119,7 +126,7 @@ def run(args):

block_size = args.block_size * 1024 // 4
num_of_blocks = args.size * 1024 * 1024 // (args.block_size * 1024)
keys = [generate_random_string(250) for i in range(num_of_blocks)]
keys = [generate_random_string(150) for i in range(num_of_blocks)]
with infinistore.DisableTorchCaching():
src_tensor = torch.rand(
num_of_blocks * block_size, device=src_device, dtype=torch.float32
Expand All @@ -146,12 +153,13 @@ def run(args):
read_sum = 0.0

for _ in range(args.iteration):
random.seed(time.time())

if args.rdma:
remote_addrs = conn.allocate_rdma(keys, block_size * 4)

steps = int(
32
) # simulate we have <steps> layers, this steps should be less then MAX_WR_SIZE
steps = args.steps
        # simulate we have <steps> layers; this step count should be less than MAX_WR_SIZE
while len(blocks) % steps != 0 and steps > 1:
steps = int(steps / 2)
print(f"\nSimulate {steps} layers, running\n")
Expand Down
10 changes: 10 additions & 0 deletions infinistore/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import argparse
import logging
import subprocess
import os


# disable standard logging, we will use our own logger
Expand Down Expand Up @@ -125,6 +126,12 @@ def parse_args():
return parser.parse_args()


def prevent_oom(path=None):
    """Opt this process out of the kernel OOM killer.

    Writes "-1000" (the minimum oom_score_adj, meaning "never kill") to
    /proc/<pid>/oom_score_adj so the Linux OOM killer will not reap the
    cache server under memory pressure.

    :param path: override for the target file; defaults to the current
        process's /proc/<pid>/oom_score_adj. Exposed mainly for testing.
    :return: True if the write succeeded, False otherwise.

    Note: lowering oom_score_adj below 0 requires root/CAP_SYS_RESOURCE,
    and /proc only exists on Linux, so failure is expected in some
    environments — we degrade gracefully instead of crashing startup.
    """
    if path is None:
        path = f"/proc/{os.getpid()}/oom_score_adj"
    try:
        with open(path, "w") as f:
            f.write("-1000")
        return True
    except OSError:
        # Not fatal: the server still runs, just without OOM protection.
        return False


def main():
args = parse_args()
config = ServerConfig(
Expand Down Expand Up @@ -165,6 +172,9 @@ def main():
]
)

prevent_oom()
Logger.info("set oom_score_adj to -1000 to prevent OOM")

http_config = uvicorn.Config(
app, host="0.0.0.0", port=config.manage_port, loop="uvloop"
)
Expand Down
125 changes: 110 additions & 15 deletions src/infinistore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <uv.h>

#include <chrono>
#include <deque>
#include <iostream>
#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -93,6 +94,8 @@ struct Client {
// RDMA send buffer
char *send_buffer_ = NULL;
struct ibv_mr *send_mr_ = NULL;
int outstanding_rdma_writes_ = 0;
std::deque<std::pair<struct ibv_send_wr *, struct ibv_sge *>> outstanding_rdma_writes_queue_;

// TCP send buffer
char *tcp_send_buffer_ = NULL;
Expand Down Expand Up @@ -271,7 +274,6 @@ void Client::cq_poll_handle(uv_poll_t *handle, int status, int events) {
}
else if (wc.opcode ==
                IBV_WC_RECV_RDMA_WITH_IMM) {  // write cache: we already have all data now.

// client should not use WRITE_WITH_IMM to notify.
// it should use COMMIT message to notify.
WARN("WRITE_WITH_IMM is not supported in server side");
Expand All @@ -280,6 +282,30 @@ void Client::cq_poll_handle(uv_poll_t *handle, int status, int events) {
return;
}
}
else if (wc.opcode == IBV_WC_RDMA_WRITE) {
            // some RDMA writes (read-cache WRs) have finished

DEBUG("RDMA_WRITE done wr_id: {}", wc.wr_id);
assert(outstanding_rdma_writes_ >= 0);
outstanding_rdma_writes_ -= MAX_WR_BATCH;

if (!outstanding_rdma_writes_queue_.empty()) {
auto item = outstanding_rdma_writes_queue_.front();
struct ibv_send_wr *wrs = item.first;
struct ibv_sge *sges = item.second;
ibv_send_wr *bad_wr = nullptr;
DEBUG("IBV POST SEND, wr_id: {}", wrs[0].wr_id);
int ret = ibv_post_send(qp_, &wrs[0], &bad_wr);
if (ret) {
ERROR("Failed to post RDMA write {}", strerror(ret));
throw std::runtime_error("Failed to post RDMA write");
}
outstanding_rdma_writes_ += MAX_WR_BATCH;
delete[] wrs;
delete[] sges;
outstanding_rdma_writes_queue_.pop_front();
}
}
else {
ERROR("Unexpected wc opcode: {}", (int)wc.opcode);
}
Expand All @@ -306,6 +332,15 @@ int Client::allocate_rdma(const RemoteMetaRequest *req) {
// FIXME: rdma write should have a msg to update committed to true

const auto *key = req->keys()->Get(key_idx);

if (kv_map.count(key->str()) != 0) {
// WARN("rdma_write: Key already exists: {}", key->str());
// put fake addr, and send to client
blocks.push_back(FAKE_REMOTE_BLOCK);
key_idx++;
return;
}

auto ptr =
boost::intrusive_ptr<PTR>(new PTR(addr, block_size, pool_idx, false));

Expand Down Expand Up @@ -370,7 +405,7 @@ int Client::prepare_recv_rdma_request(int buf_idx) {
}

int Client::read_rdma_cache(const RemoteMetaRequest *remote_meta_req) {
INFO("do rdma read...");
INFO("do rdma read... num of keys: {}", remote_meta_req->keys()->size());

if (remote_meta_req->keys()->size() != remote_meta_req->remote_addrs()->size()) {
ERROR("keys size and remote_addrs size mismatch");
Expand Down Expand Up @@ -400,38 +435,74 @@ int Client::read_rdma_cache(const RemoteMetaRequest *remote_meta_req) {
blocks.push_back({.lkey = mm->get_lkey(ptr->pool_idx), .local_addr = (uintptr_t)ptr->ptr});
}

const size_t max_wr = 16;
struct ibv_send_wr wrs[max_wr];
struct ibv_sge sges[max_wr];
const size_t max_wr = MAX_WR_BATCH;
struct ibv_send_wr local_wrs[max_wr];
struct ibv_sge local_sges[max_wr];

struct ibv_send_wr *wrs = local_wrs;
struct ibv_sge *sges = local_sges;

size_t num_wr = 0;
bool wr_full = false;

if (outstanding_rdma_writes_ + max_wr > MAX_RDMA_WRITE_WR) {
wr_full = true;
wrs = new struct ibv_send_wr[max_wr];
sges = new struct ibv_sge[max_wr];
}

for (size_t i = 0; i < remote_meta_req->keys()->size(); i++) {
sges[num_wr].addr = blocks[i].local_addr;
sges[num_wr].length = remote_meta_req->block_size();
sges[num_wr].lkey = blocks[i].lkey;

wrs[num_wr].wr_id = 1234;
wrs[num_wr].wr_id = i;
wrs[num_wr].opcode = (i == remote_meta_req->keys()->size() - 1) ? IBV_WR_RDMA_WRITE_WITH_IMM
: IBV_WR_RDMA_WRITE;
wrs[num_wr].sg_list = &sges[num_wr];
wrs[num_wr].num_sge = 1;
wrs[num_wr].send_flags = 0;
wrs[num_wr].wr.rdma.remote_addr = remote_meta_req->remote_addrs()->Get(i);
wrs[num_wr].wr.rdma.rkey = remote_meta_req->rkey();
wrs[num_wr].next = (num_wr == max_wr - 1 || i == remote_meta_req->keys()->size() - 1)
? nullptr
: &wrs[num_wr + 1];

wrs[num_wr].send_flags = (num_wr == max_wr - 1 || i == remote_meta_req->keys()->size() - 1)
? IBV_SEND_SIGNALED
: 0;

num_wr++;

// If we reach the maximum number of WRs, post them
if (num_wr == max_wr || i == remote_meta_req->keys()->size() - 1) {
struct ibv_send_wr *bad_wr = nullptr;
int ret = ibv_post_send(qp_, &wrs[0], &bad_wr);
if (ret) {
ERROR("Failed to post RDMA write {}", strerror(ret));
return -1;
if (!wr_full) {
struct ibv_send_wr *bad_wr = nullptr;
DEBUG("local write");
int ret = ibv_post_send(qp_, &wrs[0], &bad_wr);
if (ret) {
ERROR("Failed to post RDMA write {}", strerror(ret));
return -1;
}
outstanding_rdma_writes_ += max_wr;

// check if next iteration will exceed the limit
if (outstanding_rdma_writes_ + max_wr > MAX_RDMA_WRITE_WR) {
wr_full = true;
}
}
else {
// if WR queue is full, we need to put them into queue
WARN(
"WR queue full: push into queue, len: {}, first wr_id: {}, last wr_id: {}, "
"last op code: {} ",
num_wr, wrs[0].wr_id, wrs[num_wr - 1].wr_id, wrs[num_wr - 1].opcode);
outstanding_rdma_writes_queue_.push_back({&wrs[0], &sges[0]});
}

if (wr_full) {
wrs = new struct ibv_send_wr[max_wr];
sges = new struct ibv_sge[max_wr];
}

num_wr = 0; // Reset the counter for the next batch
}
}
Expand Down Expand Up @@ -625,6 +696,8 @@ int Client::write_cache(const LocalMetaRequest *meta_req) {
global_config.num_stream);

void *d_ptr;
int return_code = TASK_ACCEPTED;

cudaIpcMemHandle_t ipc_handle = *(cudaIpcMemHandle_t *)meta_req->ipc_handle()->data();

CHECK_CUDA(cudaSetDevice(meta_req->device()));
Expand Down Expand Up @@ -658,6 +731,17 @@ int Client::write_cache(const LocalMetaRequest *meta_req) {
auto block = meta_req->blocks()->Get(key_idx);
DEBUG("key: {}, local_addr: {}, size : {}", block->key()->str(), (uintptr_t)addr,
block_size);

// deduplicate the key
const auto &key = block->key()->str();
if (kv_map.count(key) != 0) {
            // this key could be committed or uncommitted; no matter which it is, we should skip
// it
WARN("local gpu write: Key already exists: {}, skip this key", key);
key_idx++;
return;
}

// we have global_config.num_stream streams, so we need to divide the tasks into streams
        // and launch cudaMemcpyAsync on them in an interleaved fashion
int stream_idx = key_idx % global_config.num_stream;
Expand All @@ -682,12 +766,21 @@ int Client::write_cache(const LocalMetaRequest *meta_req) {
std::chrono::high_resolution_clock::now() - start)
.count());

int return_code = TASK_ACCEPTED;
if (!ret) {
ERROR("Failed to allocate memory");
return_code = OUT_OF_MEMORY;
}

if (tasks->size() == 0) {
// all keys are duplicated, do not start the async tasks
send_resp(return_code, NULL, 0);
reset_client_read_state();
return 0;
}

for (int i = 0; i < global_config.num_stream; i++) {
CHECK_CUDA(cudaEventRecord(events[i], cuda_streams[i]));
}
remain_++;

auto *finished = new std::atomic<int>;
Expand All @@ -697,6 +790,9 @@ int Client::write_cache(const LocalMetaRequest *meta_req) {
wqueue_data_t *wqueue_data = new wqueue_data_t();
wqueue_data->client = this;
wqueue_data->d_ptr = d_ptr;

// streams share the same tasks, the tasks will be merged into kv_map when all streams are
// synced.
wqueue_data->tasks = tasks;
wqueue_data->finished = finished;
wqueue_data->task_id = i;
Expand All @@ -710,7 +806,6 @@ int Client::write_cache(const LocalMetaRequest *meta_req) {
}

send_resp(return_code, NULL, 0);

reset_client_read_state();

return 0;
Expand Down
Loading
Loading