diff --git a/cpp/src/quantize/bolt.cpp b/cpp/src/quantize/bolt.cpp
index 7cc447b7..c19fd443 100644
--- a/cpp/src/quantize/bolt.cpp
+++ b/cpp/src/quantize/bolt.cpp
@@ -254,13 +254,60 @@ void BoltEncoder::lut_dot(const RowVector<float>& q) {
     lut_dot(q.data(), static_cast<int>(q.size()));
 }
 
-template <int Reduction, class dist_t>
+template <int NBytes>
+// void _naive_bolt_scan(const uint8_t* codes, const uint8_t* lut_ptr,
+void _naive_bolt_scan(const uint8_t* codes, const ColMatrix<uint8_t>& luts,
+    uint16_t* dists_out, int64_t nblocks)
+{
+    // static constexpr int ncodebooks = 2 * NBytes;
+    static constexpr int ncentroids = 16;
+
+    auto lut_ptr = luts.data();
+    for (int b = 0; b < nblocks; b++) {
+        auto dist_ptr = dists_out + b * 32;
+        auto codes_ptr = codes + b * NBytes * 32;
+        for (int i = 0; i < 32; i++) {
+            // int dist = dist_ptr[i];
+
+            int dist_true = 0;
+            for (int m = 0; m < NBytes; m++) {
+                uint8_t code = codes_ptr[i + 32 * m];
+                // uint8_t code = codes_ptr[i * m + 32];
+
+                uint8_t low_bits = code & 0x0F;
+                uint8_t high_bits = (code & 0xF0) >> 4;
+
+                int lut_idx_0 = (2 * m) * ncentroids + low_bits;
+                int lut_idx_1 = (2 * m + 1) * ncentroids + high_bits;
+                int d0 = lut_ptr[lut_idx_0];
+                int d1 = lut_ptr[lut_idx_1];
+
+                if (b == 0 && i < 32) {
+                    printf("%3d -> %3d, %3d -> %3d", low_bits, d0, high_bits, d1);
+                    // std::cout << "%d -> %d, %d -> %d\n"
+                }
+                // int d0 = luts(low_bits, 2 * m);
+                // int d1 = luts(high_bits, 2 * m + 1);
+
+                dist_true += d0 + d1;
+            }
+            if (b == 0 && i < 32) {
+                printf(" = %4d\n", dist_true);
+            }
+            dist_ptr[i] = dist_true;
+        }
+    }
+}
+
+// template <int Reduction, class dist_t>
+template <int Reduction>
 void query(const float* q, int len, int nbytes,
     const RowMatrix<float>& centroids,
     const RowVector<float>& offsets, float scaleby,
     const RowMatrix<uint8_t>& codes,
     // ColMatrix<float> _centroids, RowMatrix<uint8_t> _codes,
-    int64_t ncodes, ColMatrix<uint8_t>& lut_tmp, dist_t* dists)
+    // int64_t ncodes, ColMatrix<uint8_t>& lut_tmp, dist_t* dists)
+    int64_t ncodes, ColMatrix<uint8_t>& lut_tmp, uint16_t* dists)
 {
     int ncodebooks = nbytes * 2;
     assert(nbytes > 0);
@@ -281,18 +328,35 @@ void query(const float* q, int len, int nbytes,
     auto codes_ptr = codes.data();
     assert(codes_ptr != nullptr);
 
+    // TODO rm
+    RowMatrix<int> unpacked_codes(32, 4); // just first few rows
+    for (int i = 0; i < unpacked_codes.rows(); i++) {
+        unpacked_codes(i, 0) = codes(i, 0) & 0x0F;
+        unpacked_codes(i, 1) = (codes(i, 0) & 0xF0) >> 4;
+        unpacked_codes(i, 2) = codes(i, 1) & 0x0F;
+        unpacked_codes(i, 3) = (codes(i, 1) & 0xF0) >> 4;
+    }
+
     // create lookup table and then scan with it
     switch (nbytes) {
         case 2:
-            // bolt_lut<2, Reduction>(q, len, _centroids.data(), lut_ptr);
+            bolt_lut<2, Reduction>(q, len, centroids.data(), offsets.data(), scaleby,
+                lut_ptr);
+            // ya, these both match the cpp...
+            // std::cout << "behold, my lut is:\n";
+            // std::cout << lut_tmp.cast<int>();
+//            std::cout << "\nmy initial codes are:\n";
+//            std::cout << codes.topRows<20>().cast<int>() << "\n";
+            std::cout << "\nmy initial unpacked codes are:\n";
+            std::cout << unpacked_codes << "\n";
 
-            // TODO rm
-//            lut(q, len, nbytes, centroids, offsets, scaleby, lut_tmp);
+            // TODO uncomment
+            // bolt_scan<2, true>(codes.data(), lut_ptr, dists, nblocks);
+
+            // _naive_bolt_scan<2>(codes.data(), lut_ptr, dists, nblocks);
+            _naive_bolt_scan<2>(codes.data(), lut_tmp, dists, nblocks);
 
-            bolt_lut<2, Reduction>(q, len, centroids.data(), offsets.data(), scaleby,
-                lut_ptr);
-            bolt_scan<2, true>(codes.data(), lut_ptr, dists, nblocks);
             break;
         case 8:
             bolt_lut<8, Reduction>(q, len, centroids.data(), offsets.data(), scaleby,
@@ -319,6 +383,7 @@ void query(const float* q, int len, int nbytes,
     }
 }
 
+
 // TODO version that writes to argout array to avoid unnecessary copy
 // TODO allow specifying safe (no overflow) scan
 template <int Reduction>
@@ -330,6 +395,7 @@ RowVector<uint16_t> query_all(const float* q, int len, int nbytes,
     ColMatrix<uint8_t>& lut_tmp)
 {
     RowVector<uint16_t> dists(codes.rows()); // need 32B alignment, so can't use stl vector
+    dists.setZero();
 
     query<Reduction>(q, len, nbytes, centroids, offsets, scaleby, codes, ncodes,
         lut_tmp, dists.data());
 
@@ -339,14 +405,14 @@ RowVector<uint16_t> query_all(const float* q, int len, int nbytes,
         std::cout << dists(i) << " ";
     }
     std::cout << "]\n";
-    
+
     return dists;
-    
+
     // copy computed distances into a vector to return; would be nice
     // to write directly into the vector, but can't guarantee 32B alignment
     // without some hacks
     // vector<uint16_t> ret(ncodes);
-    
+
     // ret.reserve(_ncodes);
     // std::memcpy(ret.data(), dists.data(), ncodes * sizeof(uint16_t));
     // printf("ncodes, ret.size(), %lld, %lld\n", _ncodes, ret.size());
diff --git a/cpp/src/quantize/bolt.hpp b/cpp/src/quantize/bolt.hpp
index 781de80e..b7fd8be7 100644
--- a/cpp/src/quantize/bolt.hpp
+++ b/cpp/src/quantize/bolt.hpp
@@ -10,6 +10,7 @@
 #include <assert.h>
 #include <sys/types.h>
+#include <math.h>
 
 #include "immintrin.h" // this is what defines all the simd funcs + _MM_SHUFFLE
 
 #include <iostream> // TODO rm after debug
@@ -37,7 +38,7 @@ namespace {
  * first byte of 32 codes at once, then the second byte, etc.
 * @tparam NBytes Byte length of Bolt encoding for each row
 */
-template <int NBytes>
+template <int NBytes, bool RowMajor = false>
 void bolt_encode(const float* X, int64_t nrows, int ncols,
     const float* centroids, uint8_t* out)
 {
@@ -45,65 +46,81 @@ void bolt_encode(const float* X, int64_t nrows, int ncols,
     static constexpr int packet_width = 8; // objs per simd register
     static constexpr int nstripes = lut_sz / packet_width;
     static constexpr int ncodebooks = 2 * NBytes;
+    static constexpr int block_rows = 32;
     static_assert(NBytes > 0, "Code length <= 0 is not valid");
+    const int64_t nblocks = ceil(nrows / (float)block_rows);
     const int subvect_len = ncols / ncodebooks;
     const int trailing_subvect_len = ncols % ncodebooks;
     assert(trailing_subvect_len == 0); // TODO remove this constraint
 
     __m256 accumulators[lut_sz / packet_width];
 
-    for (int64_t n = 0; n < nrows; n++) { // for each row of X
-        auto x_ptr = X + n * ncols;
-
-        auto centroids_ptr = centroids;
-        for (int m = 0; m < ncodebooks; m++) { // for each codebook
-            for (int i = 0; i < nstripes; i++) {
-                accumulators[i] = _mm256_setzero_ps();
-            }
-            // compute distances to each of the centroids, which we assume
-            // are in column major order; this takes 2 packets per col
-            for (int j = 0; j < subvect_len; j++) { // for each encoded dim
-                auto x_j_broadcast = _mm256_set1_ps(*x_ptr);
-                for (int i = 0; i < nstripes; i++) { // for upper and lower 8
-                    auto centroids_half_col = _mm256_load_ps((float*)centroids_ptr);
-                    centroids_ptr += packet_width;
-                    auto diff = _mm256_sub_ps(x_j_broadcast, centroids_half_col);
-                    accumulators[i] = fma(diff, diff, accumulators[i]);
-                }
-                x_ptr++;
-            }
+    // for (int64_t n = 0; n < nrows; n++) { // for each row of X
+
+    auto x_ptr = X;
+    for (int b = 0; b < nblocks; b++) { // for each block
+        // handle nrows not a multiple of 32; guard the modulo so that an
+        // exact multiple of 32 doesn't yield an empty final block
+        int limit = (b == nblocks - 1) && (nrows % block_rows) ?
+            (nrows % block_rows) : block_rows;
+        for (int n = 0; n < limit; n++) { // for each row in block
+            // auto x_ptr = X + n * ncols;
+
+            auto centroids_ptr = centroids;
+            for (int m = 0; m < ncodebooks; m++) { // for each codebook
+                for (int i = 0; i < nstripes; i++) {
+                    accumulators[i] = _mm256_setzero_ps();
+                }
+                // compute distances to each of the centroids, which we assume
+                // are in column major order; this takes 2 packets per col
+                for (int j = 0; j < subvect_len; j++) { // for each encoded dim
+                    auto x_j_broadcast = _mm256_set1_ps(*x_ptr);
+                    for (int i = 0; i < nstripes; i++) { // for upper and lower 8
+                        auto centroids_half_col = _mm256_load_ps((float*)centroids_ptr);
+                        centroids_ptr += packet_width;
+                        auto diff = _mm256_sub_ps(x_j_broadcast, centroids_half_col);
+                        accumulators[i] = fma(diff, diff, accumulators[i]);
+                    }
+                    x_ptr++;
+                }
 
-            // convert the floats to ints
-            // XXX distances *must* be >> 0 for this to preserve correctness
-            auto dists_int32_low = _mm256_cvtps_epi32(accumulators[0]);
-            auto dists_int32_high = _mm256_cvtps_epi32(accumulators[1]);
-
-            // find the minimum value
-            auto dists = _mm256_min_epi32(dists_int32_low, dists_int32_high);
-            auto min_broadcast = broadcast_min(dists);
-            // int32_t min_val = predux_min(dists);
-
-            // mask where the minimum happens
-            auto mask_low = _mm256_cmpeq_epi32(dists_int32_low, min_broadcast);
-            auto mask_high = _mm256_cmpeq_epi32(dists_int32_high, min_broadcast);
-
-            // find first int where mask is set
-            uint32_t mask0 = _mm256_movemask_epi8(mask_low); // extracts MSBs
-            uint32_t mask1 = _mm256_movemask_epi8(mask_high);
-            uint64_t mask = mask0 + (static_cast<uint64_t>(mask1) << 32);
-            uint8_t min_idx = __tzcnt_u64(mask) >> 2; // div by 4 since 4B objs
-
-            if (m % 2) {
-                // odds -> store in upper 4 bits
-                out[m / 2] |= min_idx << 4;
-            } else {
-                // evens -> store in lower 4 bits; we don't actually need to
-                // mask because odd iter will clobber upper 4 bits anyway
-                out[m / 2] = min_idx;
-            }
-        }
-        out += NBytes;
-    }
+                // convert the floats to ints
+                // XXX distances *must* be >> 0 for this to preserve correctness
+                auto dists_int32_low = _mm256_cvtps_epi32(accumulators[0]);
+                auto dists_int32_high = _mm256_cvtps_epi32(accumulators[1]);
+
+                // find the minimum value
+                auto dists = _mm256_min_epi32(dists_int32_low, dists_int32_high);
+                auto min_broadcast = broadcast_min(dists);
+                // int32_t min_val = predux_min(dists);
+
+                // mask where the minimum happens
+                auto mask_low = _mm256_cmpeq_epi32(dists_int32_low, min_broadcast);
+                auto mask_high = _mm256_cmpeq_epi32(dists_int32_high, min_broadcast);
+
+                // find first int where mask is set
+                uint32_t mask0 = _mm256_movemask_epi8(mask_low); // extracts MSBs
+                uint32_t mask1 = _mm256_movemask_epi8(mask_high);
+                uint64_t mask = mask0 + (static_cast<uint64_t>(mask1) << 32);
+                uint8_t min_idx = __tzcnt_u64(mask) >> 2; // div by 4 since 4B objs
+
+                int out_idx;
+                if (RowMajor) {
+                    out_idx = m / 2;
+                } else {
+                    out_idx = block_rows * (m / 2) + n;
+                }
+                if (m % 2) {
+                    // odds -> store in upper 4 bits
+                    out[out_idx] |= min_idx << 4;
+                } else {
+                    // evens -> store in lower 4 bits; we don't actually need to
+                    // mask because odd iter will clobber upper 4 bits anyway
+                    out[out_idx] = min_idx;
+                }
+            } // m
+            if (RowMajor) { out += NBytes; }
+        } // n within block
+        if (!RowMajor) { out += NBytes * block_rows; }
+    } // block
 }
 
@@ -227,7 +244,7 @@ void bolt_lut(const float* q, int len, const float* centroids,
     __m256 scaleby_vect = _mm256_set1_ps(scaleby);
     // TODO uncomment this!!!
// __m256 scaleby_vect = _mm256_set1_ps(1.); - std::cout << "cpp scaleby: " << scaleby << "\n"; + // std::cout << "cpp scaleby: " << scaleby << "\n"; for (int m = 0; m < ncodebooks; m++) { // for each codebook for (int i = 0; i < nstripes; i++) { @@ -270,7 +287,7 @@ void bolt_lut(const float* q, int len, const float* centroids, // TODO uncomment this (with the real offsets + scale)!!! // // scale dists and add offsets - std::cout << "cpp offset: " << offsets[m] << "\n"; + // std::cout << "cpp offset: " << offsets[m] << "\n"; auto offset_vect = _mm256_set1_ps(offsets[m]); // auto offset_vect = _mm256_set1_ps(0); auto dist0 = fma(accumulators[0], scaleby_vect, offset_vect); diff --git a/cpp/test/quantize/test_bolt.cpp b/cpp/test/quantize/test_bolt.cpp index 5e543f26..5eae30ef 100644 --- a/cpp/test/quantize/test_bolt.cpp +++ b/cpp/test/quantize/test_bolt.cpp @@ -291,7 +291,6 @@ TEST_CASE("bolt_encode", "[mcq][bolt]") { check_encoding(nrows, enc.codes()); } } - } TEST_CASE("bolt_scan", "[mcq][bolt]") { diff --git a/python/bolt/bolt_api.py b/python/bolt/bolt_api.py index b883989a..d4410de2 100644 --- a/python/bolt/bolt_api.py +++ b/python/bolt/bolt_api.py @@ -69,6 +69,7 @@ def _ensure_num_cols_multiple_of(X, multiple_of): def kmeans(X, k, max_iter=16, init='kmc2'): X = X.astype(np.float32) + np.random.seed(123) # if k is huge, initialize centers with cartesian product of centroids # in two subspaces @@ -299,13 +300,19 @@ def set_data(self, X): col = raw_Xenc[:, in_j] cpp_Xenc[:, out_j] = np.bitwise_and(col, 255 - 15) >> 4 - # # yep, these are the same (well, *almost* always...fp errors?) - # print "python X enc" - # print self.X_enc.shape - # print self.X_enc[:20] + + # XXX are these supposed to be the same? I don't think so... + + + # yep, these are the same (well, *almost* always...fp errors?) 
+ print "python X enc" + print self.X_enc.shape + print self.X_enc[:20] # print "cpp X enc" # print cpp_Xenc.shape # print cpp_Xenc[:20] + # print "raw cpp X_enc" + # print raw_Xenc[:20] self.X_enc += enc_offsets @@ -348,19 +355,24 @@ def dists_sq(self, q): self._encoder.lut_l2(q) lut_cpp = self._encoder.get_lut() - print "py, cpp lut:" # within +/- 1 using naive lut impl in cpp - print lut_py - print lut_cpp + # print "py, cpp lut:" # within +/- 1 using naive lut impl in cpp + # print lut_py + # print lut_cpp # return self._dists(lut) dists_py = self._dists(lut) - dists_cpp = self._encoder.dists_sq(q) + dists_cpp = self._encoder.dists_sq(q)[:len(dists_py)] # strip padding + + # print "py, cpp initial dists:" + # print dists_py[:20] + # print dists_cpp[:20] - print "py, cpp dists:" - print dists_py[:20] - print dists_cpp[:20] + print "py, cpp final dists:" + print dists_py[-20:] + print dists_cpp[-20:] return dists_py + # return dists_cpp def dot_prods(self, q): lut = _fit_pq_lut(q, centroids=self.centroids, @@ -380,6 +392,8 @@ def __init__(self, nbytes=32, reduction=Reductions.SQUARED_EUCLIDEAN): self.reduction = reduction def _preproc(self, X): + # TODO rows of X also needs to have variance >> 1 to avoid + # everything going to 0 when bolt_encode converts to ints in argmin one_d = len(X.shape) == 1 if one_d: X = X.reshape((1, -1)) @@ -396,7 +410,7 @@ def fit(self, X, just_train=False, Q=None): ncentroids = 16 self.DEBUG = False - self.DEBUG = True + # self.DEBUG = True X = self._preproc(X) self._ndims_ = X.shape[1] @@ -462,6 +476,7 @@ def fit(self, X, just_train=False, Q=None): if not just_train: self._encoder_.set_data(X) + self._n = len(X) return self @@ -469,6 +484,7 @@ def set_data(self, X): """set data to actually encode; separate from fit() because fit() could use different training data than what we actully compress""" self._encoder_.set_data(self._preproc(X)) + self._n = len(X) def _set_dot(self): if self.reduction != Reductions.DOT_PRODUCT: @@ -486,11 +502,11 @@ def _set_sq(self): def dot(self, q): self._set_dot() - return self._encoder_.dot_prods(self._preproc(q)) + return self._encoder_.dot_prods(self._preproc(q))[:self._n] def dists_sq(self, q): self._set_sq() - return self._encoder_.dists_sq(self._preproc(q)) + return self._encoder_.dists_sq(self._preproc(q))[:self._n] def knn_dot(self, q, k): self._set_dot() diff --git a/tests/test_encoder.py b/tests/test_encoder.py new file mode 100644 index 00000000..ad33e4e1 --- /dev/null +++ b/tests/test_encoder.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python + +import numpy as np +from scipy.stats import pearsonr as corr +from sklearn.datasets import load_digits + +import bolt + + +# ================================================================ utils + +def _dists_sq(X, q): + diffs = X - q + return np.sum(diffs * diffs, axis=-1) + + +def _dists_l1(X, q): + diffs = np.abs(X - q) + return np.sum(diffs, axis=-1) + + +def _create_randn_encoder(nbytes=16, Ntrain=100, Ntest=20, D=64): + enc = bolt.Encoder(nbytes=16) + X_train = np.random.randn(Ntrain, D) + X_test = np.random.randn(Ntest, D) + enc.fit(X_train, just_train=True) + enc.set_data(X_test) + return enc + + +# ================================================================ tests + +def test_smoketest(): + """Test that `bolt.Encoder`'s methods don't crash""" + + D = 64 + enc = _create_randn_encoder(D=D) + + Nqueries = 5 + Q = np.random.randn(Nqueries, D) + [enc.dot(q) for q in Q] + [enc.dists_sq(q) for q in Q] + for k in [1, 3]: + [enc.knn_l2(q, k) for q in Q] + [enc.knn_dot(q, k) for 
+
+    # assert False  # yep, this makes it fail, so actually running this test
+
+
+def test_basic():
+    # np.set_printoptions(precision=3)
+    np.set_printoptions(formatter={'float_kind': lambda x: '{:.3f}'.format(x)})
+
+    X, _ = load_digits(return_X_y=True)
+    X = X[:, :16]
+
+    num_queries = 20
+    Q = X[-num_queries:]
+    X = X[:-num_queries]
+
+    enc = bolt.Encoder(nbytes=2, reduction=bolt.Reductions.SQUARED_EUCLIDEAN)
+    enc.fit(X)
+
+    # dot_corrs = np.empty(len(Q))
+    l2_corrs = np.empty(len(Q))
+    # for i, q in enumerate(Q[:1]):
+    for i, q in enumerate(Q):
+        # dots_true = np.dot(X, q)
+        # dots_bolt = enc.dot(q)
+        # dot_corrs[i] = corr(dots_true, dots_bolt)[0]
+
+        # print "dots true, dots bolt:"
+        # print dots_true
+        # print dots_bolt
+
+        # l2_true = _dists_sq(X, q)
+        l2_true = _dists_sq(X, q).astype(np.int)
+        l2_bolt = enc.dists_sq(q)
+
+        # l2_bolt = l2_bolt[:len(l2_true)]  # bolt result includes padding
+
+        # l2_corrs[i] = corr(l2_true, l2_bolt)[0]  # TODO uncomment
+        l2_corrs[i] = corr(l2_true[:20], l2_bolt[:20])[0]
+
+        # print "dists sq true, dists bolt:"
+        # print l2_true[:20]
+        # print l2_bolt[:20]
+
+    # print "mean dot product correlation, squared l2 dist correlation:"
+    print "squared l2 dist correlation:"
+    # print np.mean(dot_corrs[0])
+    # print np.mean(l2_corrs[0])
+    # print np.mean(dot_corrs)
+    print np.mean(l2_corrs)
+
+
+if __name__ == '__main__':
+    test_basic()
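
For reference while reading `_naive_bolt_scan` and the nibble packing in `bolt_encode`: the sketch below re-derives the same arithmetic in plain NumPy. It is illustrative only and not part of the patch; the function names (`pack_nibbles`, `reference_bolt_scan`) and array shapes are assumptions. The layout it encodes mirrors the C++ above: each byte packs two 4-bit centroid indices (even codebook in the low nibble, odd codebook in the high nibble), codes are stored in 32-row blocks with byte m of all 32 rows contiguous (`codes_ptr[i + 32 * m]`), and the LUT holds 16 entries per codebook in codebook-major order (`lut_ptr[c * 16 + idx]`), so each distance is just 2 * NBytes table lookups.

import numpy as np

def pack_nibbles(idxs):
    # idxs: (nrows, ncodebooks) centroid indices in [0, 16). Even codebooks
    # go in the low nibble, odd ones in the high nibble, matching the
    # out[out_idx] writes in bolt_encode. Note this yields the RowMajor
    # layout (each row's NBytes contiguous); the blocked (!RowMajor) layout
    # additionally groups byte m of 32 consecutive rows together.
    return (idxs[:, 0::2] | (idxs[:, 1::2] << 4)).astype(np.uint8)

def reference_bolt_scan(codes, luts, nbytes, nblocks):
    # codes: flat uint8 array of nblocks blocks, each 32 * nbytes bytes,
    #   with byte m of all 32 rows contiguous within a block.
    # luts: (2 * nbytes, 16) array; row c holds codebook c's 16 entries,
    #   i.e. the same flat order the C++ reads via lut_ptr.
    # (int64 accumulation here just avoids overflow in the reference; the
    # C++ stores the result as uint16_t.)
    ncentroids = 16
    lut = luts.reshape(-1)
    dists = np.zeros(32 * nblocks, dtype=np.int64)
    for b in range(nblocks):
        block = codes[b * nbytes * 32:(b + 1) * nbytes * 32]
        for i in range(32):                  # row within block
            total = 0
            for m in range(nbytes):          # one packed byte -> 2 codebooks
                code = int(block[i + 32 * m])
                low, high = code & 0x0F, (code & 0xF0) >> 4
                total += int(lut[(2 * m) * ncentroids + low])
                total += int(lut[(2 * m + 1) * ncentroids + high])
            dists[b * 32 + i] = total
    return dists

With nbytes=2 this computes, for each of the 32 * nblocks rows, the sum of four table lookups, which is exactly what the AVX2 bolt_scan vectorizes; comparing bolt_scan's output against a reference like this on random codes and LUTs is a cheap sanity check for the layout questions the debug prints above are probing.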