Skip to content

Commit

Permalink
[GPU-Plugin] Improved load balancing search (dmlc#2521)
Browse files Browse the repository at this point in the history
  • Loading branch information
RAMitchell authored Jul 16, 2017
1 parent 33ee7d1 commit c85bf98
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 106 deletions.
197 changes: 104 additions & 93 deletions plugin/updater_gpu/src/device_helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,12 @@ struct Timer {

void reset() { start = ClockT::now(); }
int64_t elapsed() const { return (ClockT::now() - start).count(); }
double elapsedSeconds() const {
return elapsed() * ((double)ClockT::period::num / ClockT::period::den);
}
void printElapsed(std::string label) {
// synchronize_n_devices(n_devices, dList);
printf("%s:\t %lld\n", label.c_str(), elapsed());
printf("%s:\t %fs\n", label.c_str(), elapsedSeconds());
reset();
}
};
Expand Down Expand Up @@ -650,116 +653,124 @@ struct BernoulliRng {

// Load balancing search

template <typename func_t>
class LauncherItr {
public:
int idx;
func_t f;
XGBOOST_DEVICE LauncherItr() : idx(0) {}
XGBOOST_DEVICE LauncherItr(int idx, func_t f) : idx(idx), f(f) {}
XGBOOST_DEVICE LauncherItr &operator=(int output) {
f(idx, output);
return *this;
}
};

template <typename func_t>

/**
* \class DiscardLambdaItr
*
* \brief Thrust compatible iterator type - discards algorithm output and
* launches device lambda with the index of the output and the algorithm output as arguments.
*
* \author Rory
* \date 7/9/2017
*/

class DiscardLambdaItr {
public:
// Required iterator traits
typedef DiscardLambdaItr self_type; ///< My own type
typedef ptrdiff_t
difference_type; ///< Type to express the result of subtracting
/// one iterator from another
typedef LauncherItr<func_t>
value_type; ///< The type of the element the iterator can point to
typedef value_type *pointer; ///< The type of a pointer to an element the
/// iterator can point to
typedef value_type reference; ///< The type of a reference to an element the
/// iterator can point to
typedef typename thrust::detail::iterator_facade_category<
thrust::any_system_tag, thrust::random_access_traversal_tag, value_type,
reference>::type iterator_category; ///< The iterator category
private:
difference_type offset;
func_t f;

public:
XGBOOST_DEVICE DiscardLambdaItr(func_t f) : offset(0), f(f) {}
XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, func_t f)
: offset(offset), f(f) {}
template <typename coordinate_t, typename segments_t, typename offset_t>
void FindMergePartitions(int device_idx, coordinate_t *d_tile_coordinates, int num_tiles,
int tile_size, segments_t segments, offset_t num_rows,
offset_t num_elements) {
dh::launch_n(device_idx, num_tiles + 1, [=] __device__(int idx) {
offset_t diagonal = idx * tile_size;
coordinate_t tile_coordinate;
cub::CountingInputIterator<offset_t> nonzero_indices(0);

// Search the merge path
// Cast to signed integer as this function can have negatives
cub::MergePathSearch(static_cast<int64_t>(diagonal), segments + 1,
nonzero_indices, static_cast<int64_t>(num_rows),
static_cast<int64_t>(num_elements), tile_coordinate);

// Output starting offset
d_tile_coordinates[idx] = tile_coordinate;
});
}

XGBOOST_DEVICE self_type operator+(const int &b) const {
return DiscardLambdaItr(offset + b, f);
}
XGBOOST_DEVICE self_type operator++() {
offset++;
return *this;
}
XGBOOST_DEVICE self_type operator++(int) {
self_type retval = *this;
offset++;
return retval;
}
XGBOOST_DEVICE self_type &operator+=(const int &b) {
offset += b;
return *this;
}
XGBOOST_DEVICE reference operator*() const {
return LauncherItr<func_t>(offset, f);
template <int TILE_SIZE, int ITEMS_PER_THREAD, int BLOCK_THREADS,
typename offset_t, typename coordinate_t, typename func_t,
typename segments_iter>
__global__ void LbsKernel(coordinate_t *d_coordinates,
segments_iter segment_end_offsets, func_t f,
offset_t num_segments) {
int tile = blockIdx.x;
coordinate_t tile_start_coord = d_coordinates[tile];
coordinate_t tile_end_coord = d_coordinates[tile + 1];
int64_t tile_num_rows = tile_end_coord.x - tile_start_coord.x;
int64_t tile_num_elements = tile_end_coord.y - tile_start_coord.y;

cub::CountingInputIterator<offset_t> tile_element_indices(tile_start_coord.y);
coordinate_t thread_start_coord;

typedef typename std::iterator_traits<segments_iter>::value_type segment_t;
__shared__ struct {
segment_t tile_segment_end_offsets[TILE_SIZE + 1];
segment_t output_segment[TILE_SIZE];
} temp_storage;

for (auto item : dh::block_stride_range(int(0), int(tile_num_rows + 1))) {
temp_storage.tile_segment_end_offsets[item] =
segment_end_offsets[min(tile_start_coord.x + item, num_segments - 1)];
}
__syncthreads();

int64_t diag = threadIdx.x * ITEMS_PER_THREAD;

// Cast to signed integer as this function can have negatives
cub::MergePathSearch(diag, // Diagonal
temp_storage.tile_segment_end_offsets, // List A
tile_element_indices, // List B
tile_num_rows, tile_num_elements, thread_start_coord);

coordinate_t thread_current_coord = thread_start_coord;
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) {
if (tile_element_indices[thread_current_coord.y] <
temp_storage.tile_segment_end_offsets[thread_current_coord.x]) {
temp_storage.output_segment[thread_current_coord.y] =
thread_current_coord.x + tile_start_coord.x;
++thread_current_coord.y;
} else {
++thread_current_coord.x;
}
}
__syncthreads();

XGBOOST_DEVICE reference operator[](int idx) {
self_type offset = (*this) + idx;
return *offset;
for (auto item : dh::block_stride_range(int(0), int(tile_num_elements))) {
f(tile_start_coord.y + item, temp_storage.output_segment[item]);
}
};
}

/**
* \fn template <typename func_t, typename segments_t> void TransformLbs(int device_idx, dh::CubMemory *temp_memory, int count, thrust::device_ptr<segments_t> segments, int num_segments, func_t f)
* \fn template <typename func_t, typename segments_iter, typename offset_t>
* void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
* segments_iter segments, offset_t num_segments, func_t f)
*
* \brief Load balancing search function. Reads a CSR type matrix description and allows a function
* to be executed on each element. Search 'modern GPU load balancing search for more
* information'.
* \brief Load balancing search function. Reads a CSR type matrix description
* and allows a function to be executed on each element. Search 'modern GPU load
* balancing search' for more information.
*
* \author Rory
* \date 7/9/2017
*
* \tparam segments_t Type of the segments t.
* \tparam func_t Type of the function t.
* \tparam segments_iter Type of the segments iterator.
* \tparam offset_t Type of the offset.
* \tparam segments_t Type of the segments t.
* \param device_idx Zero-based index of the device.
* \param [in,out] temp_memory Temporary memory allocator.
* \param count Number of elements.
* \param segments Device pointed to segments.
* \param segments Device pointer to segments.
* \param num_segments Number of segments.
* \param f Lambda to be executed on matrix elements.
*/

template <typename func_t, typename segments_t>
void TransformLbs(int device_idx, dh::CubMemory *temp_memory, int count,
thrust::device_ptr<segments_t> segments, int num_segments,
func_t f) {
safe_cuda(cudaSetDevice(device_idx));
auto counting = thrust::make_counting_iterator(0);

auto f_wrapper = [=] __device__(int idx, int upper_bound) {
f(idx, upper_bound - 1);
};

DiscardLambdaItr<decltype(f_wrapper)> itr(f_wrapper);

thrust::upper_bound(thrust::cuda::par(*temp_memory), segments,
segments + num_segments, counting, counting + count, itr);
template <typename func_t, typename segments_iter, typename offset_t>
void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
segments_iter segments, offset_t num_segments, func_t f) {
typedef typename cub::CubVector<offset_t, 2>::Type coordinate_t;
dh::safe_cuda(cudaSetDevice(device_idx));
const int BLOCK_THREADS = 256;
const int ITEMS_PER_THREAD = 1;
const int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD;
int num_tiles = dh::div_round_up(count + num_segments, BLOCK_THREADS);

temp_memory->LazyAllocate(sizeof(coordinate_t) * (num_tiles + 1));
coordinate_t *tmp_tile_coordinates =
reinterpret_cast<coordinate_t *>(temp_memory->d_temp_storage);

FindMergePartitions(device_idx, tmp_tile_coordinates, num_tiles, BLOCK_THREADS, segments,
num_segments, count);

LbsKernel<TILE_SIZE, ITEMS_PER_THREAD, BLOCK_THREADS, offset_t>
<<<num_tiles, BLOCK_THREADS>>>(tmp_tile_coordinates, segments + 1, f,
num_segments);
}

} // namespace dh
76 changes: 63 additions & 13 deletions plugin/updater_gpu/test/cpp/test_device_helpers.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,72 @@
#include "../../src/device_helpers.cuh"
#include "gtest/gtest.h"

static const std::vector<int> gidx = {0, 2, 5, 1, 3, 6, 0, 2, 0, 7};
static const std::vector<int> row_ptr = {0, 3, 6, 8, 10};
static const std::vector<int> lbs_seg_output = {0, 0, 0, 1, 1, 1, 2, 2, 3, 3};

thrust::device_vector<int> test_lbs() {
thrust::device_vector<int> device_gidx = gidx;
thrust::device_vector<int> device_row_ptr = row_ptr;
thrust::device_vector<int> device_output_row(gidx.size(), 0);
auto d_output_row = device_output_row.data();
void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,
thrust::host_vector<int> *row_ptr,
thrust::host_vector<xgboost::bst_uint> *rows) {
row_ptr->resize(num_rows + 1);
int sum = 0;
for (int i = 0; i <= num_rows; i++) {
(*row_ptr)[i] = sum;
sum += rand() % max_row_size; // NOLINT

if (i < num_rows) {
for (int j = (*row_ptr)[i]; j < sum; j++) {
(*rows).push_back(i);
}
}
}
}

void SpeedTest() {
int num_rows = 1000000;
int max_row_size = 100;
dh::CubMemory temp_memory;
thrust::host_vector<int> h_row_ptr;
thrust::host_vector<xgboost::bst_uint> h_rows;
CreateTestData(num_rows, max_row_size, &h_row_ptr, &h_rows);
thrust::device_vector<int> row_ptr = h_row_ptr;
thrust::device_vector<int> output_row(h_rows.size());
auto d_output_row = output_row.data();

dh::Timer t;
dh::TransformLbs(
0, &temp_memory, gidx.size(), device_row_ptr.data(), row_ptr.size() - 1,
[=] __device__(int idx, int ridx) { d_output_row[idx] = ridx; });
0, &temp_memory, h_rows.size(), dh::raw(row_ptr), row_ptr.size() - 1,
[=] __device__(size_t idx, size_t ridx) { d_output_row[idx] = ridx; });

dh::safe_cuda(cudaDeviceSynchronize());
return device_output_row;
double time = t.elapsedSeconds();
const int mb_size = 1048576;
size_t size = (sizeof(int) * h_rows.size()) / mb_size;
printf("size: %llumb, time: %fs, bandwidth: %fmb/s\n", size, time,
size / time);
}

TEST(lbs, Test) { ASSERT_TRUE(test_lbs() == lbs_seg_output); }
void TestLbs() {
srand(17);
dh::CubMemory temp_memory;

std::vector<int> test_rows = {4, 100, 1000};
std::vector<int> test_max_row_sizes = {4, 100, 1300};

for (auto num_rows : test_rows) {
for (auto max_row_size : test_max_row_sizes) {
thrust::host_vector<int> h_row_ptr;
thrust::host_vector<xgboost::bst_uint> h_rows;
CreateTestData(num_rows, max_row_size, &h_row_ptr, &h_rows);
thrust::device_vector<size_t> row_ptr = h_row_ptr;
thrust::device_vector<int> output_row(h_rows.size());
auto d_output_row = output_row.data();

dh::TransformLbs(0, &temp_memory, h_rows.size(), dh::raw(row_ptr),
row_ptr.size() - 1,
[=] __device__(size_t idx, size_t ridx) {
d_output_row[idx] = ridx;
});

dh::safe_cuda(cudaDeviceSynchronize());
ASSERT_TRUE(h_rows == output_row);
}
}
}
TEST(cub_lbs, Test) { TestLbs(); }

0 comments on commit c85bf98

Please sign in to comment.