Commit 3a5e778

get python bolt api working for l2 (code still ugly)

dblalock committed May 3, 2017
1 parent e30d6c5 commit 3a5e778

Showing 5 changed files with 274 additions and 78 deletions.
88 changes: 77 additions & 11 deletions cpp/src/quantize/bolt.cpp
@@ -254,13 +254,60 @@ void BoltEncoder::lut_dot(const RowVector<float>& q) {
    lut_dot(q.data(), static_cast<int>(q.size()));
}

template<int Reduction=Reductions::DistL2, class dist_t>
template<int NBytes>
// void _naive_bolt_scan(const uint8_t* codes, const uint8_t* lut_ptr,
void _naive_bolt_scan(const uint8_t* codes, const ColMatrix<uint8_t>& luts,
    uint16_t* dists_out, int64_t nblocks)
{
    // static constexpr int ncodebooks = 2 * NBytes;
    static constexpr int ncentroids = 16;

    auto lut_ptr = luts.data();
    for (int b = 0; b < nblocks; b++) {
        auto dist_ptr = dists_out + b * 32;
        auto codes_ptr = codes + b * NBytes * 32;
        for (int i = 0; i < 32; i++) {
            // int dist = dist_ptr[i];

            int dist_true = 0;
            for (int m = 0; m < NBytes; m++) {
                uint8_t code = codes_ptr[i + 32 * m]; // TODO uncomment !!!!
                // uint8_t code = codes_ptr[i * m + 32];

                uint8_t low_bits = code & 0x0F;
                uint8_t high_bits = (code & 0xF0) >> 4;

                int lut_idx_0 = (2 * m) * ncentroids + low_bits;
                int lut_idx_1 = (2 * m + 1) * ncentroids + high_bits;
                int d0 = lut_ptr[lut_idx_0];
                int d1 = lut_ptr[lut_idx_1];

                if (b == 0 && i < 32) {
                    printf("%3d -> %3d, %3d -> %3d", low_bits, d0, high_bits, d1);
                    // std::cout << "%d -> %d, %d -> %d\n"
                }
                // int d0 = luts(low_bits, 2 * m);
                // int d1 = luts(high_bits, 2 * m + 1);

                dist_true += d0 + d1;
            }
            if (b == 0 && i < 32) {
                printf(" = %4d\n", dist_true);
            }
            dist_ptr[i] = dist_true;
        }
    }
}
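
Each Bolt code byte packs two 4-bit centroid indices: the low nibble selects a
centroid in codebook 2m, the high nibble one in codebook 2m + 1, so every byte
contributes two lookups into a 16-entry-per-codebook table. A minimal
standalone sketch of that indexing arithmetic (the decode_one_byte helper and
its flat LUT layout are illustrative assumptions, not part of this commit):

#include <cstdint>
#include <cstdio>

// Sum the two LUT entries selected by one packed code byte. `lut` holds
// 16 uint8 distances per codebook, codebook-major, matching the
// (codebook * 16 + centroid) indexing in _naive_bolt_scan above.
static int decode_one_byte(uint8_t code, int m, const uint8_t* lut) {
    uint8_t low_bits = code & 0x0F;          // centroid index for codebook 2m
    uint8_t high_bits = (code & 0xF0) >> 4;  // centroid index for codebook 2m + 1
    return lut[(2 * m) * 16 + low_bits] + lut[(2 * m + 1) * 16 + high_bits];
}

int main() {
    uint8_t lut[4 * 16] = {0};  // 4 codebooks, i.e. NBytes = 2
    lut[0 * 16 + 7] = 3;        // codebook 0, centroid 7
    lut[1 * 16 + 11] = 5;       // codebook 1, centroid 11
    printf("%d\n", decode_one_byte(0xB7, 0, lut));  // 0xB7 -> lo 7, hi 11 -> 8
    return 0;
}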

// template<int Reduction=Reductions::DistL2, class dist_t>
template<int Reduction=Reductions::DistL2>
void query(const float* q, int len, int nbytes,
    const RowMatrix<float>& centroids,
    const RowVector<float>& offsets, float scaleby,
    const RowMatrix<uint8_t>& codes,
    // ColMatrix<float> _centroids, RowMatrix<uint8_t> _codes,
    int64_t ncodes, ColMatrix<uint8_t>& lut_tmp, dist_t* dists)
    // int64_t ncodes, ColMatrix<uint8_t>& lut_tmp, dist_t* dists)
    int64_t ncodes, ColMatrix<uint8_t>& lut_tmp, uint16_t* dists)
{
    int ncodebooks = nbytes * 2;
    assert(nbytes > 0);
@@ -281,18 +328,35 @@ void query(const float* q, int len, int nbytes,
    auto codes_ptr = codes.data();
    assert(codes_ptr != nullptr);

    // TODO rm
    RowMatrix<uint16_t> unpacked_codes(32, 4); // just first few rows
    for (int i = 0; i < unpacked_codes.rows(); i++) {
        unpacked_codes(i, 0) = codes(i, 0) & 0x0F;
        unpacked_codes(i, 1) = (codes(i, 0) & 0xF0) >> 4;
        unpacked_codes(i, 2) = codes(i, 1) & 0x0F;
        unpacked_codes(i, 3) = (codes(i, 1) & 0xF0) >> 4;
    }

    // create lookup table and then scan with it
    switch (nbytes) {
    case 2:
        // bolt_lut<2, Reduction>(q, len, _centroids.data(), lut_ptr);
        bolt_lut<2, Reduction>(q, len, centroids.data(), offsets.data(), scaleby,
            lut_ptr);

        // ya, these both match the cpp...
        // std::cout << "behold, my lut is:\n";
        // std::cout << lut_tmp.cast<uint16_t>();
        // std::cout << "\nmy initial codes are:\n";
        // std::cout << codes.topRows<20>().cast<uint16_t>() << "\n";
        std::cout << "\nmy initial unpacked codes are:\n";
        std::cout << unpacked_codes << "\n";

        // TODO rm
        // lut<Reduction>(q, len, nbytes, centroids, offsets, scaleby, lut_tmp);
        // TODO uncomment
        // bolt_scan<2, true>(codes.data(), lut_ptr, dists, nblocks);

        // _naive_bolt_scan<2>(codes.data(), lut_ptr, dists, nblocks);
        _naive_bolt_scan<2>(codes.data(), lut_tmp, dists, nblocks);

        bolt_lut<2, Reduction>(q, len, centroids.data(), offsets.data(), scaleby,
            lut_ptr);
        bolt_scan<2, true>(codes.data(), lut_ptr, dists, nblocks);
        break;
    case 8:
        bolt_lut<8, Reduction>(q, len, centroids.data(), offsets.data(), scaleby,
@@ -319,6 +383,7 @@ void query(const float* q, int len, int nbytes,
    }
}
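
Note the memory layout both _naive_bolt_scan and bolt_scan assume: codes are
grouped into blocks of 32 rows, and within a block they are stored
column-major, i.e. byte 0 of all 32 rows, then byte 1, and so on, which is why
the naive scan reads codes_ptr[i + 32 * m]. A small sketch of the resulting
offset arithmetic (illustrative only, not part of this commit):

#include <cstdint>
#include <cstdio>

// Byte offset of code byte m of row n in the blocked, column-major layout:
// block b = n / 32 starts at b * nbytes * 32; within the block, byte m of
// all 32 rows is contiguous, so row i of the block lands at 32 * m + i.
int64_t code_byte_offset(int64_t n, int m, int nbytes) {
    int64_t b = n / 32;
    int i = (int)(n % 32);
    return b * nbytes * 32 + 32 * m + i;
}

int main() {
    // row 40, byte 1, nbytes = 2: block 1 starts at 64, byte-1 column at +32,
    // row 8 within the block -> 64 + 32 + 8 = 104
    printf("%lld\n", (long long)code_byte_offset(40, 1, 2));
    return 0;
}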


// TODO version that writes to argout array to avoid unnecessary copy
// TODO allow specifying safe (no overflow) scan
template<int Reduction=Reductions::DistL2>
@@ -330,6 +395,7 @@ RowVector<uint16_t> query_all(const float* q, int len, int nbytes,
    ColMatrix<uint8_t>& lut_tmp)
{
    RowVector<uint16_t> dists(codes.rows()); // need 32B alignment, so can't use stl vector
    dists.setZero();
    query(q, len, nbytes, centroids, offsets, scaleby, codes, ncodes,
        lut_tmp, dists.data());

@@ -339,14 +405,14 @@ RowVector<uint16_t> query_all(const float* q, int len, int nbytes,
        std::cout << dists(i) << " ";
    }
    std::cout << "]\n";

    return dists;

    // copy computed distances into a vector to return; would be nice
    // to write directly into the vector, but can't guarantee 32B alignment
    // without some hacks
    // vector<uint16_t> ret(ncodes);

    // ret.reserve(_ncodes);
    // std::memcpy(ret.data(), dists.data(), ncodes * sizeof(uint16_t));
    // printf("ncodes, ret.size(), %lld, %lld\n", _ncodes, ret.size());
121 changes: 69 additions & 52 deletions cpp/src/quantize/bolt.hpp
@@ -10,6 +10,7 @@

#include <assert.h>
#include <sys/types.h>
#include <math.h>
#include "immintrin.h" // this is what defines all the simd funcs + _MM_SHUFFLE

#include <iostream> // TODO rm after debug
@@ -37,73 +38,89 @@ namespace {
 * first byte of 32 codes at once, then the second byte, etc.
 * @tparam NBytes Byte length of Bolt encoding for each row
 */
template<int NBytes>
template<int NBytes, bool RowMajor=false>
void bolt_encode(const float* X, int64_t nrows, int ncols,
    const float* centroids, uint8_t* out)
{
    static constexpr int lut_sz = 16;
    static constexpr int packet_width = 8; // objs per simd register
    static constexpr int nstripes = lut_sz / packet_width;
    static constexpr int ncodebooks = 2 * NBytes;
    static constexpr int block_rows = 32;
    static_assert(NBytes > 0, "Code length <= 0 is not valid");
    const int64_t nblocks = ceil(nrows / (float)block_rows);
    const int subvect_len = ncols / ncodebooks;
    const int trailing_subvect_len = ncols % ncodebooks;
    assert(trailing_subvect_len == 0); // TODO remove this constraint

    __m256 accumulators[lut_sz / packet_width];

    for (int64_t n = 0; n < nrows; n++) { // for each row of X
        auto x_ptr = X + n * ncols;
    // for (int64_t n = 0; n < nrows; n++) { // for each row of X

        auto centroids_ptr = centroids;
        for (int m = 0; m < ncodebooks; m++) { // for each codebook
            for (int i = 0; i < nstripes; i++) {
                accumulators[i] = _mm256_setzero_ps();
            }
            // compute distances to each of the centroids, which we assume
            // are in column major order; this takes 2 packets per col
            for (int j = 0; j < subvect_len; j++) { // for each encoded dim
                auto x_j_broadcast = _mm256_set1_ps(*x_ptr);
                for (int i = 0; i < nstripes; i++) { // for upper and lower 8
                    auto centroids_half_col = _mm256_load_ps((float*)centroids_ptr);
                    centroids_ptr += packet_width;
                    auto diff = _mm256_sub_ps(x_j_broadcast, centroids_half_col);
                    accumulators[i] = fma(diff, diff, accumulators[i]);
    auto x_ptr = X;
    for (int b = 0; b < nblocks; b++) { // for each block
        // handle nrows not a multiple of 32
        int limit = (b == (nblocks - 1)) ? (nrows % 32) : block_rows;
        for (int n = 0; n < limit; n++) { // for each row in block
            // auto x_ptr = X + n * ncols;

            auto centroids_ptr = centroids;
            for (int m = 0; m < ncodebooks; m++) { // for each codebook
                for (int i = 0; i < nstripes; i++) {
                    accumulators[i] = _mm256_setzero_ps();
                }
                // compute distances to each of the centroids, which we assume
                // are in column major order; this takes 2 packets per col
                for (int j = 0; j < subvect_len; j++) { // for each encoded dim
                    auto x_j_broadcast = _mm256_set1_ps(*x_ptr);
                    for (int i = 0; i < nstripes; i++) { // for upper and lower 8
                        auto centroids_half_col = _mm256_load_ps((float*)centroids_ptr);
                        centroids_ptr += packet_width;
                        auto diff = _mm256_sub_ps(x_j_broadcast, centroids_half_col);
                        accumulators[i] = fma(diff, diff, accumulators[i]);
                    }
                    x_ptr++;
                }
                x_ptr++;
            }

            // convert the floats to ints
            // XXX distances *must* be >> 0 for this to preserve correctness
            auto dists_int32_low = _mm256_cvtps_epi32(accumulators[0]);
            auto dists_int32_high = _mm256_cvtps_epi32(accumulators[1]);

            // find the minimum value
            auto dists = _mm256_min_epi32(dists_int32_low, dists_int32_high);
            auto min_broadcast = broadcast_min(dists);
            // int32_t min_val = predux_min(dists);

            // mask where the minimum happens
            auto mask_low = _mm256_cmpeq_epi32(dists_int32_low, min_broadcast);
            auto mask_high = _mm256_cmpeq_epi32(dists_int32_high, min_broadcast);

            // find first int where mask is set
            uint32_t mask0 = _mm256_movemask_epi8(mask_low); // extracts MSBs
            uint32_t mask1 = _mm256_movemask_epi8(mask_high);
            uint64_t mask = mask0 + (static_cast<uint64_t>(mask1) << 32);
            uint8_t min_idx = __tzcnt_u64(mask) >> 2; // div by 4 since 4B objs

            if (m % 2) {
                // odds -> store in upper 4 bits
                out[m / 2] |= min_idx << 4;
            } else {
                // evens -> store in lower 4 bits; we don't actually need to
                // mask because odd iter will clobber upper 4 bits anyway
                out[m / 2] = min_idx;
            }
        }
        out += NBytes;
    }
                // convert the floats to ints
                // XXX distances *must* be >> 0 for this to preserve correctness
                auto dists_int32_low = _mm256_cvtps_epi32(accumulators[0]);
                auto dists_int32_high = _mm256_cvtps_epi32(accumulators[1]);

                // find the minimum value
                auto dists = _mm256_min_epi32(dists_int32_low, dists_int32_high);
                auto min_broadcast = broadcast_min(dists);
                // int32_t min_val = predux_min(dists);

                // mask where the minimum happens
                auto mask_low = _mm256_cmpeq_epi32(dists_int32_low, min_broadcast);
                auto mask_high = _mm256_cmpeq_epi32(dists_int32_high, min_broadcast);

                // find first int where mask is set
                uint32_t mask0 = _mm256_movemask_epi8(mask_low); // extracts MSBs
                uint32_t mask1 = _mm256_movemask_epi8(mask_high);
                uint64_t mask = mask0 + (static_cast<uint64_t>(mask1) << 32);
                uint8_t min_idx = __tzcnt_u64(mask) >> 2; // div by 4 since 4B objs

                int out_idx;
                if (RowMajor) {
                    out_idx = m / 2;
                } else {
                    out_idx = block_rows * (m / 2) + n;
                }
                if (m % 2) {
                    // odds -> store in upper 4 bits
                    out[out_idx] |= min_idx << 4;
                } else {
                    // evens -> store in lower 4 bits; we don't actually need to
                    // mask because odd iter will clobber upper 4 bits anyway
                    out[out_idx] = min_idx;
                }
            } // m
            if (RowMajor) { out += NBytes; }
        } // n within block
        if (!RowMajor) { out += NBytes * block_rows; }
    } // block
}
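
The encoder finds each subvector's nearest centroid without a scalar argmin
loop: the 16 squared distances live in two 8-lane registers,
_mm256_cmpeq_epi32 against the broadcast minimum marks the winning lanes,
_mm256_movemask_epi8 collapses each 4-byte lane into 4 mask bits, and tzcnt of
the combined 64-bit mask, divided by 4, gives the index of the first minimal
lane. A scalar restatement of that trick (a sketch for clarity, not the
commit's code):

#include <cstdint>
#include <cstdio>

// Emulate the SIMD argmin extraction: a lane equal to the minimum
// contributes 4 set bits to the movemask, so the position of the first
// set bit, divided by 4, is the index of the first minimal distance.
static uint8_t argmin16(const int32_t* dists) {
    int32_t min_val = dists[0];
    for (int i = 1; i < 16; i++) {
        if (dists[i] < min_val) min_val = dists[i];
    }
    uint64_t mask = 0;  // 4 bits per 32-bit lane, like movemask on cmpeq
    for (int i = 0; i < 16; i++) {
        if (dists[i] == min_val) mask |= 0xFULL << (4 * i);
    }
    int tz = 0;  // emulate tzcnt; mask is never 0 since the min exists
    while (!((mask >> tz) & 1)) tz++;
    return (uint8_t)(tz >> 2);  // div by 4 since 4B objs
}

int main() {
    int32_t d[16];
    for (int i = 0; i < 16; i++) d[i] = 100 - i;  // unique min at lane 15
    printf("%u\n", argmin16(d));  // prints 15
    return 0;
}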


Expand Down Expand Up @@ -227,7 +244,7 @@ void bolt_lut(const float* q, int len, const float* centroids,
    __m256 scaleby_vect = _mm256_set1_ps(scaleby); // TODO uncomment this!!!
    // __m256 scaleby_vect = _mm256_set1_ps(1.);

    std::cout << "cpp scaleby: " << scaleby << "\n";
    // std::cout << "cpp scaleby: " << scaleby << "\n";

    for (int m = 0; m < ncodebooks; m++) { // for each codebook
        for (int i = 0; i < nstripes; i++) {
Expand Down Expand Up @@ -270,7 +287,7 @@ void bolt_lut(const float* q, int len, const float* centroids,

// TODO uncomment this (with the real offsets + scale)!!!
// // scale dists and add offsets
std::cout << "cpp offset: " << offsets[m] << "\n";
// std::cout << "cpp offset: " << offsets[m] << "\n";
auto offset_vect = _mm256_set1_ps(offsets[m]);
// auto offset_vect = _mm256_set1_ps(0);
auto dist0 = fma(accumulators[0], scaleby_vect, offset_vect);
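
For the L2 reduction, bolt_lut's job is to emit, per codebook, the 16
quantized distances from the query subvector to that codebook's centroids,
mapped through dist * scaleby + offsets[m] as in the fma above. A naive scalar
sketch under simplified assumptions: plain row-major centroid storage and
saturating uint8 quantization are assumed here, whereas the real kernel uses
an interleaved SIMD-friendly centroid layout:

#include <cstdint>
#include <algorithm>
#include <vector>

// Naive stand-in for bolt_lut with Reduction = DistL2. Assumes centroids
// are stored as [codebook][centroid][dim], which is NOT the layout the
// SIMD kernel expects; the scale/offset mapping mirrors the fma above.
std::vector<uint8_t> naive_bolt_lut_l2(const float* q, int len, int ncodebooks,
                                       const float* centroids,
                                       const float* offsets, float scaleby) {
    int sublen = len / ncodebooks;
    std::vector<uint8_t> lut(ncodebooks * 16);
    for (int m = 0; m < ncodebooks; m++) {
        for (int c = 0; c < 16; c++) {
            float d = 0;
            for (int j = 0; j < sublen; j++) {
                float diff = q[m * sublen + j]
                    - centroids[(m * 16 + c) * sublen + j];
                d += diff * diff;
            }
            float v = d * scaleby + offsets[m];  // quantize, saturating to [0, 255]
            lut[m * 16 + c] = (uint8_t)std::min(std::max(v, 0.0f), 255.0f);
        }
    }
    return lut;
}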
1 change: 0 additions & 1 deletion cpp/test/quantize/test_bolt.cpp
@@ -291,7 +291,6 @@ TEST_CASE("bolt_encode", "[mcq][bolt]") {
            check_encoding(nrows, enc.codes());
        }
    }

}

TEST_CASE("bolt_scan", "[mcq][bolt]") {
