Skip to content

Commit

Permalink
Additional C++ templates for fast sa_decode: additional overload for …
Browse files Browse the repository at this point in the history
…::accum() (facebookresearch#2445)

Summary:
Pull Request resolved: facebookresearch#2445

Add overloads for ::accum() to process 3 vectors per call. It is faster than processing 2 vectors per call in certain cases, at least for the AVX2 code.

Reviewed By: mdouze

Differential Revision: D39176425

fbshipit-source-id: bb39bb1f7a77442d32f20cb29281ec2e2ed2600c
  • Loading branch information
alexanderguzhva authored and facebook-github-bot committed Sep 5, 2022
1 parent abb46ac commit b4924aa
Show file tree
Hide file tree
Showing 8 changed files with 1,429 additions and 1 deletion.
19 changes: 19 additions & 0 deletions faiss/cppcontrib/SaDecodeKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,25 @@
// const float weight1,
// float* const __restrict outputAccum);
// }
// * And one more overload for ::accum that decodes and accumulates
// three vectors per call. Sometimes, it makes sense, at least for AVX2.
// The method signature is the following:
// {
// static void accum(
// const float* const __restrict pqCoarseCentroids0,
// const float* const __restrict pqFineCentroids0,
// const uint8_t* const __restrict code0,
// const float weight0,
// const float* const __restrict pqCoarseCentroids1,
// const float* const __restrict pqFineCentroids1,
// const uint8_t* const __restrict code1,
// const float weight1,
// const float* const __restrict pqCoarseCentroids2,
// const float* const __restrict pqFineCentroids2,
// const uint8_t* const __restrict code2,
// const float weight2,
// float* const __restrict outputAccum);
// }
// The provided version is not multithreaded.
//
// Currently, an AVX2+FMA implementation is available. AVX512 version is also
Expand Down
266 changes: 266 additions & 0 deletions faiss/cppcontrib/sa_decode/Level2-avx2-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,81 @@ struct Index2LevelDecoderImpl<

// clang-format on
}

// process 3 samples
static void accum(
const float* const __restrict pqCoarseCentroids0,
const float* const __restrict pqFineCentroids0,
const uint8_t* const __restrict code0,
const float weight0,
const float* const __restrict pqCoarseCentroids1,
const float* const __restrict pqFineCentroids1,
const uint8_t* const __restrict code1,
const float weight1,
const float* const __restrict pqCoarseCentroids2,
const float* const __restrict pqFineCentroids2,
const uint8_t* const __restrict code2,
const float weight2,
float* const __restrict outputAccum) {
// coarse quantizer
const uint8_t* const __restrict coarse0 = code0;
const uint8_t* const __restrict coarse1 = code1;
const uint8_t* const __restrict coarse2 = code2;

// fine quantizer
const uint8_t* const __restrict fine0 = code0 + N_COARSE_ELEMENTS_BYTES;
const uint8_t* const __restrict fine1 = code1 + N_COARSE_ELEMENTS_BYTES;
const uint8_t* const __restrict fine2 = code2 + N_COARSE_ELEMENTS_BYTES;

// clang-format off

// process chunks, 4 float
// but 8 floats per loop

const intptr_t coarseCode0 = detail::UintReader<DIM, COARSE_SIZE, COARSE_BITS, coarseCentroidIdx>::get(coarse0);
const intptr_t fineCode0a = detail::UintReader<DIM, FINE_SIZE, FINE_BITS, fineCentroidIdx + 0>::get(fine0);
const intptr_t fineCode0b = detail::UintReader<DIM, FINE_SIZE, FINE_BITS, fineCentroidIdx + 1>::get(fine0);
const intptr_t coarseCode1 = detail::UintReader<DIM, COARSE_SIZE, COARSE_BITS, coarseCentroidIdx>::get(coarse1);
const intptr_t fineCode1a = detail::UintReader<DIM, FINE_SIZE, FINE_BITS, fineCentroidIdx + 0>::get(fine1);
const intptr_t fineCode1b = detail::UintReader<DIM, FINE_SIZE, FINE_BITS, fineCentroidIdx + 1>::get(fine1);
const intptr_t coarseCode2 = detail::UintReader<DIM, COARSE_SIZE, COARSE_BITS, coarseCentroidIdx>::get(coarse2);
const intptr_t fineCode2a = detail::UintReader<DIM, FINE_SIZE, FINE_BITS, fineCentroidIdx + 0>::get(fine2);
const intptr_t fineCode2b = detail::UintReader<DIM, FINE_SIZE, FINE_BITS, fineCentroidIdx + 1>::get(fine2);

__m256 existingValue = _mm256_loadu_ps(outputAccum + CPOS);

existingValue = elementaryBlock4x2bAccum(
pqCoarseCentroids0 + (coarseCentroidIdx * COARSE_TABLE_BYTES + coarseCode0) * COARSE_SIZE + coarseCentroidOffset,
pqFineCentroids0 + ((fineCentroidIdx + 0) * FINE_TABLE_BYTES + fineCode0a) * FINE_SIZE + fineCentroidOffset,
pqFineCentroids0 + ((fineCentroidIdx + 1) * FINE_TABLE_BYTES + fineCode0b) * FINE_SIZE + fineCentroidOffset,
weight0,
existingValue);

existingValue = elementaryBlock4x2bAccum(
pqCoarseCentroids1 + (coarseCentroidIdx * COARSE_TABLE_BYTES + coarseCode1) * COARSE_SIZE + coarseCentroidOffset,
pqFineCentroids1 + ((fineCentroidIdx + 0) * FINE_TABLE_BYTES + fineCode1a) * FINE_SIZE + fineCentroidOffset,
pqFineCentroids1 + ((fineCentroidIdx + 1) * FINE_TABLE_BYTES + fineCode1b) * FINE_SIZE + fineCentroidOffset,
weight1,
existingValue);

existingValue = elementaryBlock4x2bAccum(
pqCoarseCentroids2 + (coarseCentroidIdx * COARSE_TABLE_BYTES + coarseCode2) * COARSE_SIZE + coarseCentroidOffset,
pqFineCentroids2 + ((fineCentroidIdx + 0) * FINE_TABLE_BYTES + fineCode2a) * FINE_SIZE + fineCentroidOffset,
pqFineCentroids2 + ((fineCentroidIdx + 1) * FINE_TABLE_BYTES + fineCode2b) * FINE_SIZE + fineCentroidOffset,
weight2,
existingValue);

_mm256_storeu_ps(outputAccum + CPOS, existingValue);

// next
Index2LevelDecoderImpl<DIM, COARSE_SIZE, FINE_SIZE, COARSE_BITS, FINE_BITS, CPOS + 8>::accum(
pqCoarseCentroids0, pqFineCentroids0, code0, weight0,
pqCoarseCentroids1, pqFineCentroids1, code1, weight1,
pqCoarseCentroids2, pqFineCentroids2, code2, weight2,
outputAccum);

// clang-format on
}
};

template <
Expand Down Expand Up @@ -497,6 +572,74 @@ struct Index2LevelDecoderImpl<

// clang-format on
}

// process 3 samples
static void accum(
const float* const __restrict pqCoarseCentroids0,
const float* const __restrict pqFineCentroids0,
const uint8_t* const __restrict code0,
const float weight0,
const float* const __restrict pqCoarseCentroids1,
const float* const __restrict pqFineCentroids1,
const uint8_t* const __restrict code1,
const float weight1,
const float* const __restrict pqCoarseCentroids2,
const float* const __restrict pqFineCentroids2,
const uint8_t* const __restrict code2,
const float weight2,
float* const __restrict outputAccum) {
// coarse quantizer
const uint8_t* const __restrict coarse0 = code0;
const uint8_t* const __restrict coarse1 = code1;
const uint8_t* const __restrict coarse2 = code2;

// fine quantizer
const uint8_t* const __restrict fine0 = code0 + N_COARSE_ELEMENTS_BYTES;
const uint8_t* const __restrict fine1 = code1 + N_COARSE_ELEMENTS_BYTES;
const uint8_t* const __restrict fine2 = code2 + N_COARSE_ELEMENTS_BYTES;

// clang-format off

// process chunks, 8 float

const intptr_t coarseCode0 = detail::UintReader<DIM, COARSE_SIZE, COARSE_BITS, coarseCentroidIdx>::get(coarse0);
const intptr_t fineCode0 = detail::UintReader<DIM, FINE_SIZE, FINE_BITS, fineCentroidIdx>::get(fine0);
const intptr_t coarseCode1 = detail::UintReader<DIM, COARSE_SIZE, COARSE_BITS, coarseCentroidIdx>::get(coarse1);
const intptr_t fineCode1 = detail::UintReader<DIM, FINE_SIZE, FINE_BITS, fineCentroidIdx>::get(fine1);
const intptr_t coarseCode2 = detail::UintReader<DIM, COARSE_SIZE, COARSE_BITS, coarseCentroidIdx>::get(coarse2);
const intptr_t fineCode2 = detail::UintReader<DIM, FINE_SIZE, FINE_BITS, fineCentroidIdx>::get(fine2);

__m256 existingValue = _mm256_loadu_ps(outputAccum + CPOS);

existingValue = elementaryBlock8x1bAccum(
pqCoarseCentroids0 + (coarseCentroidIdx * COARSE_TABLE_BYTES + coarseCode0) * COARSE_SIZE + coarseCentroidOffset,
pqFineCentroids0 + (fineCentroidIdx * FINE_TABLE_BYTES + fineCode0) * FINE_SIZE + fineCentroidOffset,
weight0,
existingValue);

existingValue = elementaryBlock8x1bAccum(
pqCoarseCentroids1 + (coarseCentroidIdx * COARSE_TABLE_BYTES + coarseCode1) * COARSE_SIZE + coarseCentroidOffset,
pqFineCentroids1 + (fineCentroidIdx * FINE_TABLE_BYTES + fineCode1) * FINE_SIZE + fineCentroidOffset,
weight1,
existingValue);

existingValue = elementaryBlock8x1bAccum(
pqCoarseCentroids2 + (coarseCentroidIdx * COARSE_TABLE_BYTES + coarseCode2) * COARSE_SIZE + coarseCentroidOffset,
pqFineCentroids2 + (fineCentroidIdx * FINE_TABLE_BYTES + fineCode2) * FINE_SIZE + fineCentroidOffset,
weight2,
existingValue);

_mm256_storeu_ps(outputAccum + CPOS, existingValue);

// next
Index2LevelDecoderImpl<DIM, COARSE_SIZE, FINE_SIZE, COARSE_BITS, FINE_BITS, CPOS + 8>::accum(
pqCoarseCentroids0, pqFineCentroids0, code0, weight0,
pqCoarseCentroids1, pqFineCentroids1, code1, weight1,
pqCoarseCentroids2, pqFineCentroids2, code2, weight2,
outputAccum);

// clang-format on
}
};

template <
Expand Down Expand Up @@ -660,6 +803,74 @@ struct Index2LevelDecoderImpl<

// clang-format on
}

// process 3 samples
static void accum(
const float* const __restrict pqCoarseCentroids0,
const float* const __restrict pqFineCentroids0,
const uint8_t* const __restrict code0,
const float weight0,
const float* const __restrict pqCoarseCentroids1,
const float* const __restrict pqFineCentroids1,
const uint8_t* const __restrict code1,
const float weight1,
const float* const __restrict pqCoarseCentroids2,
const float* const __restrict pqFineCentroids2,
const uint8_t* const __restrict code2,
const float weight2,
float* const __restrict outputAccum) {
// coarse quantizer
const uint8_t* const __restrict coarse0 = code0;
const uint8_t* const __restrict coarse1 = code1;
const uint8_t* const __restrict coarse2 = code2;

// fine quantizer
const uint8_t* const __restrict fine0 = code0 + N_COARSE_ELEMENTS_BYTES;
const uint8_t* const __restrict fine1 = code1 + N_COARSE_ELEMENTS_BYTES;
const uint8_t* const __restrict fine2 = code2 + N_COARSE_ELEMENTS_BYTES;

// clang-format off

// process chunks, 4 float

const intptr_t coarseCode0 = detail::UintReader<DIM, COARSE_SIZE, COARSE_BITS, coarseCentroidIdx>::get(coarse0);
const intptr_t fineCode0 = detail::UintReader<DIM, FINE_SIZE, FINE_BITS, fineCentroidIdx>::get(fine0);
const intptr_t coarseCode1 = detail::UintReader<DIM, COARSE_SIZE, COARSE_BITS, coarseCentroidIdx>::get(coarse1);
const intptr_t fineCode1 = detail::UintReader<DIM, FINE_SIZE, FINE_BITS, fineCentroidIdx>::get(fine1);
const intptr_t coarseCode2 = detail::UintReader<DIM, COARSE_SIZE, COARSE_BITS, coarseCentroidIdx>::get(coarse2);
const intptr_t fineCode2 = detail::UintReader<DIM, FINE_SIZE, FINE_BITS, fineCentroidIdx>::get(fine2);

__m128 existingValue = _mm_loadu_ps(outputAccum + CPOS);

existingValue = elementaryBlock4x1bAccum(
pqCoarseCentroids0 + (coarseCentroidIdx * COARSE_TABLE_BYTES + coarseCode0) * COARSE_SIZE + coarseCentroidOffset,
pqFineCentroids0 + (fineCentroidIdx * FINE_TABLE_BYTES + fineCode0) * FINE_SIZE + fineCentroidOffset,
weight0,
existingValue);

existingValue = elementaryBlock4x1bAccum(
pqCoarseCentroids1 + (coarseCentroidIdx * COARSE_TABLE_BYTES + coarseCode1) * COARSE_SIZE + coarseCentroidOffset,
pqFineCentroids1 + (fineCentroidIdx * FINE_TABLE_BYTES + fineCode1) * FINE_SIZE + fineCentroidOffset,
weight1,
existingValue);

existingValue = elementaryBlock4x1bAccum(
pqCoarseCentroids2 + (coarseCentroidIdx * COARSE_TABLE_BYTES + coarseCode2) * COARSE_SIZE + coarseCentroidOffset,
pqFineCentroids2 + (fineCentroidIdx * FINE_TABLE_BYTES + fineCode2) * FINE_SIZE + fineCentroidOffset,
weight2,
existingValue);

_mm_storeu_ps(outputAccum + CPOS, existingValue);

// next
Index2LevelDecoderImpl<DIM, COARSE_SIZE, FINE_SIZE, COARSE_BITS, FINE_BITS, CPOS + 4>::accum(
pqCoarseCentroids0, pqFineCentroids0, code0, weight0,
pqCoarseCentroids1, pqFineCentroids1, code1, weight1,
pqCoarseCentroids2, pqFineCentroids2, code2, weight2,
outputAccum);

// clang-format on
}
};

// This partial specialization is expected to do nothing.
Expand Down Expand Up @@ -712,6 +923,22 @@ struct Index2LevelDecoderImpl<
const float weight1,
float* const __restrict outputAccum) {}

// process 3 samples
static void accum(
const float* const __restrict pqCoarseCentroids0,
const float* const __restrict pqFineCentroids0,
const uint8_t* const __restrict code0,
const float weight0,
const float* const __restrict pqCoarseCentroids1,
const float* const __restrict pqFineCentroids1,
const uint8_t* const __restrict code1,
const float weight1,
const float* const __restrict pqCoarseCentroids2,
const float* const __restrict pqFineCentroids2,
const uint8_t* const __restrict code2,
const float weight2,
float* const __restrict outputAccum) {}

// clang-format on
};
} // namespace
Expand Down Expand Up @@ -802,6 +1029,45 @@ struct Index2LevelDecoder {
weight1,
outputAccum);
}

// process 3 samples
// Performs outputAccum += weight0 * decoded(code0) + weight1 *
// decoded(code1) + weight2 * decoded(code2)
static void accum(
const float* const __restrict pqCoarseCentroids0,
const float* const __restrict pqFineCentroids0,
const uint8_t* const __restrict code0,
const float weight0,
const float* const __restrict pqCoarseCentroids1,
const float* const __restrict pqFineCentroids1,
const uint8_t* const __restrict code1,
const float weight1,
const float* const __restrict pqCoarseCentroids2,
const float* const __restrict pqFineCentroids2,
const uint8_t* const __restrict code2,
const float weight2,
float* const __restrict outputAccum) {
Index2LevelDecoderImpl<
DIM,
COARSE_SIZE,
FINE_SIZE,
COARSE_BITS,
FINE_BITS,
0>::
accum(pqCoarseCentroids0,
pqFineCentroids0,
code0,
weight0,
pqCoarseCentroids1,
pqFineCentroids1,
code1,
weight1,
pqCoarseCentroids2,
pqFineCentroids2,
code2,
weight2,
outputAccum);
}
};

} // namespace cppcontrib
Expand Down
Loading

0 comments on commit b4924aa

Please sign in to comment.