Skip to content

Commit

Permalink
Merge pull request #42 from tfhe/rnx-and-zn-apis
Browse files Browse the repository at this point in the history
Rnx and zn apis
  • Loading branch information
ngama75 authored Aug 21, 2024
2 parents e7fc5ef + 42a2343 commit dbca776
Show file tree
Hide file tree
Showing 39 changed files with 4,874 additions and 1 deletion.
28 changes: 27 additions & 1 deletion spqlios/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,22 @@ set(SRCS_GENERIC
arithmetic/vec_znx_big.c
arithmetic/znx_small.c
arithmetic/module_api.c
arithmetic/zn_vmp_int8_ref.c
arithmetic/zn_vmp_int16_ref.c
arithmetic/zn_vmp_int32_ref.c
arithmetic/zn_vmp_ref.c
arithmetic/zn_api.c
arithmetic/zn_conversions_ref.c
arithmetic/zn_approxdecomp_ref.c
arithmetic/vec_rnx_api.c
arithmetic/vec_rnx_conversions_ref.c
arithmetic/vec_rnx_svp_ref.c
reim/reim_execute.c
cplx/cplx_execute.c
reim4/reim4_execute.c
arithmetic/vec_rnx_arithmetic.c
arithmetic/vec_rnx_approxdecomp_ref.c
arithmetic/vec_rnx_vmp_ref.c
)
# C or assembly source files compiled only on x86 targets
set(SRCS_X86
Expand Down Expand Up @@ -95,9 +108,16 @@ set(SRCS_AVX2
arithmetic/vec_znx_avx.c
coeffs/coeffs_arithmetic_avx.c
arithmetic/vec_znx_dft_avx2.c
arithmetic/zn_vmp_int8_avx.c
arithmetic/zn_vmp_int16_avx.c
arithmetic/zn_vmp_int32_avx.c
q120/q120_arithmetic_avx2.c
q120/q120_ntt_avx2.c
)
arithmetic/vec_rnx_arithmetic_avx.c
arithmetic/vec_rnx_approxdecomp_avx.c
arithmetic/vec_rnx_vmp_avx.c

)
set_source_files_properties(${SRCS_AVX2} PROPERTIES COMPILE_OPTIONS "-mbmi2;-mavx2")

# C source files on float128 via libquadmath on x86 targets
Expand All @@ -110,6 +130,8 @@ set(SRCS_F128
set(HEADERSPUBLIC
commons.h
arithmetic/vec_znx_arithmetic.h
arithmetic/vec_rnx_arithmetic.h
arithmetic/zn_arithmetic.h
cplx/cplx_fft.h
reim/reim_fft.h
q120/q120_common.h
Expand All @@ -131,6 +153,10 @@ set(HEADERSPRIVATE
q120/q120_arithmetic_private.h
q120/q120_ntt_private.h
arithmetic/vec_znx_arithmetic.h
arithmetic/vec_rnx_arithmetic_private.h
arithmetic/vec_rnx_arithmetic_plugin.h
arithmetic/zn_arithmetic_private.h
arithmetic/zn_arithmetic_plugin.h
coeffs/coeffs_arithmetic.h
reim/reim_fft_core_template.h
)
Expand Down
318 changes: 318 additions & 0 deletions spqlios/arithmetic/vec_rnx_api.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
#include <string.h>

#include "vec_rnx_arithmetic_private.h"

/**
 * @brief fills module->precomp.fft64 for an FFT64 module.
 *
 * Reads module->m and creates the four precomputed items (fft, ifft,
 * fftvec_mul, fftvec_addmul). Each of them is released by
 * fft64_finalize_rnx_module_precomp.
 */
void fft64_init_rnx_module_precomp(MOD_RNX* module) {
  // Add here initialization of items that are in the precomp
  // forward and inverse transforms
  module->precomp.fft64.p_fft = new_reim_fft_precomp(module->m, 0);
  module->precomp.fft64.p_ifft = new_reim_ifft_precomp(module->m, 0);
  // vector multiplication and multiply-accumulate
  module->precomp.fft64.p_fftvec_mul = new_reim_fftvec_mul_precomp(module->m);
  module->precomp.fft64.p_fftvec_addmul = new_reim_fftvec_addmul_precomp(module->m);
}

/**
 * @brief releases the items stored in module->precomp.fft64.
 *
 * Calls the matching delete_* function on each of the four precomputed
 * items (fft, ifft, fftvec_mul, fftvec_addmul). The module structure itself
 * is not freed here.
 */
void fft64_finalize_rnx_module_precomp(MOD_RNX* module) {
// Add here deleters for items that are in the precomp
delete_reim_fft_precomp(module->precomp.fft64.p_fft);
delete_reim_ifft_precomp(module->precomp.fft64.p_ifft);
delete_reim_fftvec_mul_precomp(module->precomp.fft64.p_fftvec_mul);
delete_reim_fftvec_addmul_precomp(module->precomp.fft64.p_fftvec_addmul);
}

/**
 * @brief fills module->vtable for an FFT64 module.
 *
 * Every entry is first bound to its portable reference implementation.
 * Entries that have a vectorized variant are then rebound to it when the
 * CPU running the code supports the required instruction set.
 */
void fft64_init_rnx_module_vtable(MOD_RNX* module) {
  // Add function pointers here
  module->vtable.vec_rnx_add = vec_rnx_add_ref;
  module->vtable.vec_rnx_zero = vec_rnx_zero_ref;
  module->vtable.vec_rnx_copy = vec_rnx_copy_ref;
  module->vtable.vec_rnx_negate = vec_rnx_negate_ref;
  module->vtable.vec_rnx_sub = vec_rnx_sub_ref;
  module->vtable.vec_rnx_rotate = vec_rnx_rotate_ref;
  module->vtable.vec_rnx_automorphism = vec_rnx_automorphism_ref;
  module->vtable.vec_rnx_mul_xp_minus_one = vec_rnx_mul_xp_minus_one_ref;
  module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref;
  module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_ref;
  module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref;
  module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_ref;
  module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref;
  module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_ref;
  module->vtable.bytes_of_rnx_vmp_pmat = fft64_bytes_of_rnx_vmp_pmat;
  module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_ref;
  module->vtable.vec_rnx_to_znx32 = vec_rnx_to_znx32_ref;
  module->vtable.vec_rnx_from_znx32 = vec_rnx_from_znx32_ref;
  module->vtable.vec_rnx_to_tnx32 = vec_rnx_to_tnx32_ref;
  module->vtable.vec_rnx_from_tnx32 = vec_rnx_from_tnx32_ref;
  module->vtable.vec_rnx_to_tnxdbl = vec_rnx_to_tnxdbl_ref;
  module->vtable.bytes_of_rnx_svp_ppol = fft64_bytes_of_rnx_svp_ppol;
  module->vtable.rnx_svp_prepare = fft64_rnx_svp_prepare_ref;
  module->vtable.rnx_svp_apply = fft64_rnx_svp_apply_ref;

  // Add optimized function pointers here
  // The *_avx kernels belong to the sources built with -mavx2 (and -mbmi2),
  // so the compiler is free to emit AVX2 instructions in them. Testing only
  // for "avx" would select them on AVX-only CPUs, where they may raise an
  // illegal-instruction fault; require AVX2 at runtime instead.
  if (CPU_SUPPORTS("avx2")) {
    module->vtable.vec_rnx_add = vec_rnx_add_avx;
    module->vtable.vec_rnx_sub = vec_rnx_sub_avx;
    module->vtable.vec_rnx_negate = vec_rnx_negate_avx;
    module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx;
    module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_avx;
    module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx;
    module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_avx;
    module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx;
    module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_avx;
    module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_avx;
  }
}

/**
 * @brief initializes a caller-provided module structure in place.
 *
 * The structure is zeroed, then n, m = n/2 and mtype are recorded, and the
 * type-specific precomputations and vtable are installed. Only FFT64 is
 * handled; any other mtype ends in NOT_SUPPORTED().
 *
 * @param module structure to initialize (previous content is overwritten)
 * @param n      ring dimension
 * @param mtype  module type
 */
void init_rnx_module_info(MOD_RNX* module,  //
                          uint64_t n, RNX_MODULE_TYPE mtype) {
  memset(module, 0, sizeof *module);
  module->n = n;
  module->m = n / 2;
  module->mtype = mtype;

  if (mtype == FFT64) {
    fft64_init_rnx_module_precomp(module);
    fft64_init_rnx_module_vtable(module);
  } else {
    NOT_SUPPORTED();  // unknown mtype
  }
}

/**
 * @brief releases everything owned by a module, but not the module itself.
 *
 * If a custom payload is attached, it is handed to module->custom_deleter
 * first. The type-specific precomputations are released afterwards. Only
 * FFT64 is handled; any other mtype ends in NOT_SUPPORTED().
 */
void finalize_rnx_module_info(MOD_RNX* module) {
  if (module->custom) {
    module->custom_deleter(module->custom);
  }

  if (module->mtype == FFT64) {
    fft64_finalize_rnx_module_precomp(module);
    // fft64_finalize_rnx_module_vtable(module); // nothing to finalize
  } else {
    NOT_SUPPORTED();  // unknown mtype
  }
}

/**
 * @brief allocates and initializes a module (release with delete_rnx_module_info)
 *
 * @param nn    ring dimension, passed to init_rnx_module_info
 * @param mtype module type, passed to init_rnx_module_info
 * @return the initialized module, or NULL if the allocation failed
 */
EXPORT MOD_RNX* new_rnx_module_info(uint64_t nn, RNX_MODULE_TYPE mtype) {
  MOD_RNX* res = malloc(sizeof *res);
  // init_rnx_module_info starts with a memset on the structure: never hand it
  // a NULL pointer, report the allocation failure to the caller instead.
  if (res == NULL) {
    return NULL;
  }
  init_rnx_module_info(res, nn, mtype);
  return res;
}

/**
 * @brief finalizes and frees a module created by new_rnx_module_info.
 *
 * @param module_info module to destroy; NULL is accepted and ignored, like
 *                    free(NULL)
 */
EXPORT void delete_rnx_module_info(MOD_RNX* module_info) {
  // finalize_rnx_module_info dereferences its argument, so filter NULL here.
  if (module_info == NULL) {
    return;
  }
  finalize_rnx_module_info(module_info);
  free(module_info);
}

/** @brief returns the value of n stored in the module */
EXPORT uint64_t rnx_module_get_n(const MOD_RNX* module) {
  return module->n;
}

/**
 * @brief allocates a prepared matrix (release with delete_rnx_vmp_pmat)
 *
 * The buffer size is the one reported by bytes_of_rnx_vmp_pmat for the same
 * module and dimensions; the memory comes from spqlios_alloc.
 */
EXPORT RNX_VMP_PMAT* new_rnx_vmp_pmat(const MOD_RNX* module,             // N
                                      uint64_t nrows, uint64_t ncols) {  // dimensions
  const uint64_t pmat_bytes = bytes_of_rnx_vmp_pmat(module, nrows, ncols);
  return (RNX_VMP_PMAT*)spqlios_alloc(pmat_bytes);
}
/** @brief releases a prepared matrix by passing it to spqlios_free */
EXPORT void delete_rnx_vmp_pmat(RNX_VMP_PMAT* ptr) {
  spqlios_free(ptr);
}

//////////////// wrappers //////////////////

/**
 * @brief sets res = a + b
 *
 * Public entry point: forwards all arguments unchanged to the implementation
 * registered in module->vtable.vec_rnx_add.
 */
EXPORT void vec_rnx_add( //
const MOD_RNX* module, // N
double* res, uint64_t res_size, uint64_t res_sl, // res
const double* a, uint64_t a_size, uint64_t a_sl, // a
const double* b, uint64_t b_size, uint64_t b_sl // b
) {
module->vtable.vec_rnx_add(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl);
}

/**
 * @brief sets res = 0
 *
 * Public entry point: forwards all arguments unchanged to the implementation
 * registered in module->vtable.vec_rnx_zero.
 */
EXPORT void vec_rnx_zero( //
const MOD_RNX* module, // N
double* res, uint64_t res_size, uint64_t res_sl // res
) {
module->vtable.vec_rnx_zero(module, res, res_size, res_sl);
}

/**
 * @brief sets res = a
 *
 * Public entry point: forwards all arguments unchanged to the implementation
 * registered in module->vtable.vec_rnx_copy.
 */
EXPORT void vec_rnx_copy( //
const MOD_RNX* module, // N
double* res, uint64_t res_size, uint64_t res_sl, // res
const double* a, uint64_t a_size, uint64_t a_sl // a
) {
module->vtable.vec_rnx_copy(module, res, res_size, res_sl, a, a_size, a_sl);
}

/**
 * @brief sets res = -a
 *
 * Public entry point: forwards all arguments unchanged to the implementation
 * registered in module->vtable.vec_rnx_negate.
 */
EXPORT void vec_rnx_negate( //
const MOD_RNX* module, // N
double* res, uint64_t res_size, uint64_t res_sl, // res
const double* a, uint64_t a_size, uint64_t a_sl // a
) {
module->vtable.vec_rnx_negate(module, res, res_size, res_sl, a, a_size, a_sl);
}

/**
 * @brief sets res = a - b
 *
 * Public entry point: forwards all arguments unchanged to the implementation
 * registered in module->vtable.vec_rnx_sub.
 */
EXPORT void vec_rnx_sub( //
const MOD_RNX* module, // N
double* res, uint64_t res_size, uint64_t res_sl, // res
const double* a, uint64_t a_size, uint64_t a_sl, // a
const double* b, uint64_t b_size, uint64_t b_sl // b
) {
module->vtable.vec_rnx_sub(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl);
}

/**
 * @brief sets res = a . X^p
 *
 * Public entry point: forwards all arguments unchanged to the implementation
 * registered in module->vtable.vec_rnx_rotate.
 */
EXPORT void vec_rnx_rotate( //
const MOD_RNX* module, // N
const int64_t p, // rotation value
double* res, uint64_t res_size, uint64_t res_sl, // res
const double* a, uint64_t a_size, uint64_t a_sl // a
) {
module->vtable.vec_rnx_rotate(module, p, res, res_size, res_sl, a, a_size, a_sl);
}

/**
 * @brief sets res = a(X^p)
 *
 * Public entry point: forwards all arguments unchanged to the implementation
 * registered in module->vtable.vec_rnx_automorphism.
 */
EXPORT void vec_rnx_automorphism( //
const MOD_RNX* module, // N
int64_t p, // X -> X^p
double* res, uint64_t res_size, uint64_t res_sl, // res
const double* a, uint64_t a_size, uint64_t a_sl // a
) {
module->vtable.vec_rnx_automorphism(module, p, res, res_size, res_sl, a, a_size, a_sl);
}

/**
 * @brief forwards to module->vtable.vec_rnx_mul_xp_minus_one
 *
 * NOTE(review): judging by the name, this presumably sets res = a . (X^p - 1);
 * confirm against the declaration in the public header.
 */
EXPORT void vec_rnx_mul_xp_minus_one( //
const MOD_RNX* module, // N
const int64_t p, // rotation value
double* res, uint64_t res_size, uint64_t res_sl, // res
const double* a, uint64_t a_size, uint64_t a_sl // a
) {
module->vtable.vec_rnx_mul_xp_minus_one(module, p, res, res_size, res_sl, a, a_size, a_sl);
}
/**
 * @brief number of bytes in a RNX_VMP_PMAT (for manual allocation)
 *
 * Returns whatever module->vtable.bytes_of_rnx_vmp_pmat reports for the given
 * dimensions.
 */
EXPORT uint64_t bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N
uint64_t nrows, uint64_t ncols) { // dimensions
return module->vtable.bytes_of_rnx_vmp_pmat(module, nrows, ncols);
}

/**
 * @brief prepares a vmp matrix (contiguous row-major version)
 *
 * Public entry point: forwards all arguments unchanged to the implementation
 * registered in module->vtable.rnx_vmp_prepare_contiguous.
 * NOTE(review): tmp_space is presumably expected to hold at least
 * rnx_vmp_prepare_contiguous_tmp_bytes(module) bytes — confirm in the header.
 */
EXPORT void rnx_vmp_prepare_contiguous( //
const MOD_RNX* module, // N
RNX_VMP_PMAT* pmat, // output
const double* a, uint64_t nrows, uint64_t ncols, // a
uint8_t* tmp_space // scratch space
) {
module->vtable.rnx_vmp_prepare_contiguous(module, pmat, a, nrows, ncols, tmp_space);
}

/**
 * @brief number of scratch bytes necessary to prepare a matrix
 *
 * Returns whatever module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes reports.
 */
EXPORT uint64_t rnx_vmp_prepare_contiguous_tmp_bytes(const MOD_RNX* module) {
return module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes(module);
}

/**
 * @brief applies a vmp product res = a x pmat
 *
 * Public entry point: forwards all arguments unchanged to the implementation
 * registered in module->vtable.rnx_vmp_apply_tmp_a. The input buffer tmpa is
 * not const: it is documented as being overwritten by the call.
 * NOTE(review): tmp_space is presumably sized with
 * rnx_vmp_apply_tmp_a_tmp_bytes — confirm in the header.
 */
EXPORT void rnx_vmp_apply_tmp_a( //
const MOD_RNX* module, // N
double* res, uint64_t res_size, uint64_t res_sl, // res
double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten)
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
uint8_t* tmp_space // scratch space
) {
module->vtable.rnx_vmp_apply_tmp_a(module, res, res_size, res_sl, tmpa, a_size, a_sl, pmat, nrows, ncols, tmp_space);
}

/**
 * @brief returns the value reported by module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes
 *
 * NOTE(review): presumably the minimal size in bytes of the tmp_space to pass
 * to rnx_vmp_apply_tmp_a for the same sizes — confirm in the header.
 */
EXPORT uint64_t rnx_vmp_apply_tmp_a_tmp_bytes( //
const MOD_RNX* module, // N
uint64_t res_size, // res size
uint64_t a_size, // a size
uint64_t nrows, uint64_t ncols // prep matrix dims
) {
return module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes(module, res_size, a_size, nrows, ncols);
}

/**
 * @brief applies a vmp product res = a_dft x pmat
 *
 * Public entry point: forwards all arguments unchanged to the implementation
 * registered in module->vtable.rnx_vmp_apply_dft_to_dft.
 * NOTE(review): judging by the name, both the input a_dft and the output res
 * are in DFT representation — confirm in the header. The minimal size of
 * tmp_space is given by rnx_vmp_apply_dft_to_dft_tmp_bytes.
 */
EXPORT void rnx_vmp_apply_dft_to_dft( //
const MOD_RNX* module, // N
double* res, uint64_t res_size, uint64_t res_sl, // res
const double* a_dft, uint64_t a_size, uint64_t a_sl, // a
const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix
uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes)
) {
module->vtable.rnx_vmp_apply_dft_to_dft(module, res, res_size, res_sl, a_dft, a_size, a_sl, pmat, nrows, ncols,
tmp_space);
}

/**
 * @brief minimal size of the tmp_space
 *
 * Returns whatever module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes reports
 * for the given sizes.
 */
EXPORT uint64_t rnx_vmp_apply_dft_to_dft_tmp_bytes( //
const MOD_RNX* module, // N
uint64_t res_size, // res
uint64_t a_size, // a
uint64_t nrows, uint64_t ncols // prep matrix
) {
return module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes(module, res_size, a_size, nrows, ncols);
}

/**
 * @brief returns the value reported by module->vtable.bytes_of_rnx_svp_ppol
 *
 * NOTE(review): presumably the number of bytes in a RNX_SVP_PPOL (for manual
 * allocation) — confirm in the header.
 */
EXPORT uint64_t bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return module->vtable.bytes_of_rnx_svp_ppol(module); }

/**
 * @brief forwards to module->vtable.rnx_svp_prepare
 *
 * ppol is the output and pol the input.
 * NOTE(review): presumably converts the polynomial pol into its prepared form
 * for later use by rnx_svp_apply — confirm in the header.
 */
EXPORT void rnx_svp_prepare(const MOD_RNX* module, // N
RNX_SVP_PPOL* ppol, // output
const double* pol // a
) {
module->vtable.rnx_svp_prepare(module, ppol, pol);
}

/**
 * @brief forwards to module->vtable.rnx_svp_apply
 *
 * res is the output; ppol is a prepared polynomial and a the input vector.
 * NOTE(review): presumably multiplies each polynomial of a by ppol — confirm
 * in the header.
 */
EXPORT void rnx_svp_apply( //
const MOD_RNX* module, // N
double* res, uint64_t res_size, uint64_t res_sl, // output
const RNX_SVP_PPOL* ppol, // prepared pol
const double* a, uint64_t a_size, uint64_t a_sl // a
) {
module->vtable.rnx_svp_apply(module, // N
res, res_size, res_sl, // output
ppol, // prepared pol
a, a_size, a_sl);
}

/**
 * @brief forwards to module->vtable.rnx_approxdecomp_from_tnxdbl
 *
 * The input a is passed as a bare pointer, without size or stride arguments.
 * NOTE(review): presumably writes into res the approximate gadget
 * decomposition (output base 2^K, per the gadget) of a — confirm in the header.
 */
EXPORT void rnx_approxdecomp_from_tnxdbl( //
const MOD_RNX* module, // N
const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K
double* res, uint64_t res_size, uint64_t res_sl, // res
const double* a) { // a
module->vtable.rnx_approxdecomp_from_tnxdbl(module, gadget, res, res_size, res_sl, a);
}

/**
 * @brief forwards to module->vtable.vec_rnx_to_znx32
 *
 * Input a is a double buffer, output res an int32_t buffer.
 * NOTE(review): the exact conversion (rounding, range) is defined by the
 * implementation — confirm in the header.
 */
EXPORT void vec_rnx_to_znx32( //
const MOD_RNX* module, // N
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
const double* a, uint64_t a_size, uint64_t a_sl // a
) {
module->vtable.vec_rnx_to_znx32(module, res, res_size, res_sl, a, a_size, a_sl);
}

/**
 * @brief forwards to module->vtable.vec_rnx_from_znx32
 *
 * Input a is an int32_t buffer, output res a double buffer.
 * NOTE(review): the exact conversion is defined by the implementation —
 * confirm in the header.
 */
EXPORT void vec_rnx_from_znx32( //
const MOD_RNX* module, // N
double* res, uint64_t res_size, uint64_t res_sl, // res
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
) {
module->vtable.vec_rnx_from_znx32(module, res, res_size, res_sl, a, a_size, a_sl);
}

/**
 * @brief forwards to module->vtable.vec_rnx_to_tnx32
 *
 * Input a is a double buffer, output res an int32_t buffer.
 * NOTE(review): the exact conversion (scaling, reduction) is defined by the
 * implementation — confirm in the header.
 */
EXPORT void vec_rnx_to_tnx32( //
const MOD_RNX* module, // N
int32_t* res, uint64_t res_size, uint64_t res_sl, // res
const double* a, uint64_t a_size, uint64_t a_sl // a
) {
module->vtable.vec_rnx_to_tnx32(module, res, res_size, res_sl, a, a_size, a_sl);
}

/**
 * @brief forwards to module->vtable.vec_rnx_from_tnx32
 *
 * Input a is an int32_t buffer, output res a double buffer.
 * NOTE(review): the exact conversion is defined by the implementation —
 * confirm in the header.
 */
EXPORT void vec_rnx_from_tnx32( //
const MOD_RNX* module, // N
double* res, uint64_t res_size, uint64_t res_sl, // res
const int32_t* a, uint64_t a_size, uint64_t a_sl // a
) {
module->vtable.vec_rnx_from_tnx32(module, res, res_size, res_sl, a, a_size, a_sl);
}

/**
 * @brief forwards to module->vtable.vec_rnx_to_tnxdbl
 *
 * Both the input a and the output res are double buffers.
 * NOTE(review): the exact conversion is defined by the implementation —
 * confirm in the header.
 */
EXPORT void vec_rnx_to_tnxdbl( //
const MOD_RNX* module, // N
double* res, uint64_t res_size, uint64_t res_sl, // res
const double* a, uint64_t a_size, uint64_t a_sl // a
) {
module->vtable.vec_rnx_to_tnxdbl(module, res, res_size, res_sl, a, a_size, a_sl);
}
Loading

0 comments on commit dbca776

Please sign in to comment.