Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Integer Rings in RNS Representation #787

Merged
merged 39 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
8d6ae85
initial integer RNS, labrador RNS and test
yshekel Feb 23, 2025
24ad59d
compare rns and direct multiplication
yshekel Feb 24, 2025
0118ecc
implement additional RNS ops
yshekel Feb 24, 2025
307a1b3
implement has_inverse() for ZqRns
yshekel Feb 24, 2025
b164890
instantiate dispatchers for rns type
yshekel Feb 24, 2025
fc04e98
register cpu vec ops for scalar_rns_t
yshekel Feb 24, 2025
f1a163c
remove unnecessary ring tests
yshekel Feb 24, 2025
577364f
temporary: disable NTT for labrador until implemented for RNS
yshekel Feb 24, 2025
232f6c2
implement ntt for rns type
yshekel Feb 24, 2025
acd3c48
shared memory support for ring rns type
yshekel Feb 24, 2025
b5ee22e
inline rns ops
yshekel Feb 24, 2025
764ae36
greyhound WIP
yshekel Feb 26, 2025
7ad80d1
bench
yshekel Feb 26, 2025
8cab57d
direct to rns conversion
yshekel Feb 27, 2025
b68a6be
temporary crt implementation for labrador
yshekel Feb 27, 2025
583c14d
generic crt implementation for conversion
yshekel Feb 27, 2025
2a85507
C++ APIs for rns<-->direct conversions
svpolonsky Mar 2, 2025
b443421
test rns conversions
svpolonsky Mar 2, 2025
e985ffd
refined cpu rns conversion
svpolonsky Mar 2, 2025
c3e66ee
removed greyhound stuff
svpolonsky Mar 2, 2025
55347a9
refined rns type
svpolonsky Mar 2, 2025
d6ab1ca
fixed inverse to return zero for non invertible elements
svpolonsky Mar 2, 2025
bab89af
Example/sumcheck (#781)
krakhit Feb 24, 2025
6aa125b
Handle rou errors on the Rust side (#784)
nonam3e Feb 24, 2025
c33715a
[Fix] ModArith neg when b is u32 (#785)
LeonHibnik Feb 25, 2025
a4eaa7e
Idan/improve sumcheck docs (#778)
idanfr-ingo Feb 25, 2025
f2e3897
[FIX] Make sumcheck tests able to run without cuda backend
idanfr-ingo Feb 25, 2025
b850ba0
Fix/pow misaligned (#791)
nonam3e Feb 26, 2025
7894ec0
Fix poseidon2 example to support any backend (#792)
yshekel Feb 26, 2025
572d5f5
Parallelize-sumcheck (#789)
mickeyasa Feb 27, 2025
094b838
docs: fix broken links in libraries and poseidon sections (#752)
youyyytrok Mar 2, 2025
cf725cc
Update UnivariatePolynomial trait with arithmetics for generics (#793)
krakhit Mar 2, 2025
dbed012
Merge remote-tracking branch 'origin/main' into yshekel/ring_rns
yshekel Mar 2, 2025
cb20581
reverted unrelated changes for this PR
yshekel Mar 3, 2025
3c751c1
refined rings config
yshekel Mar 3, 2025
7f7baf7
cleanup ntt test
yshekel Mar 3, 2025
30a2c86
compilation issue
yshekel Mar 3, 2025
07765ee
rns print element wise
yshekel Mar 3, 2025
6189683
Merge remote-tracking branch 'origin/main' into yshekel/ring_rns
yshekel Mar 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions icicle/backend/cpu/src/field/cpu_ntt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,24 @@
using namespace field_config;
using namespace icicle;

// Initializes the CPU NTT domain by forwarding (device, primitive_root, config) to
// ntt_cpu::CpuNttDomain<...>::cpu_ntt_init_domain and returning its error code unchanged.
// NOTE(review): this span is rendered diff output without +/- markers. The non-template signature and the
// CpuNttDomain<scalar_t> call look like the removed lines; the `template <typename S = scalar_t>` signature and
// the CpuNttDomain<S> call look like their replacements — confirm against the merged file.
eIcicleError
cpu_ntt_init_domain(const Device& device, const scalar_t& primitive_root, const NTTInitDomainConfig& config)
template <typename S = scalar_t>
eIcicleError cpu_ntt_init_domain(const Device& device, const S& primitive_root, const NTTInitDomainConfig& config)
{
auto err = ntt_cpu::CpuNttDomain<scalar_t>::cpu_ntt_init_domain(device, primitive_root, config);
auto err = ntt_cpu::CpuNttDomain<S>::cpu_ntt_init_domain(device, primitive_root, config);
return err;
}

// Releases the CPU NTT domain by forwarding `device` to ntt_cpu::CpuNttDomain<...>::cpu_ntt_release_domain and
// returning its error code unchanged. `dummy` is never read in the body; it only carries the scalar type S in the
// signature.
// NOTE(review): rendered diff output — the CpuNttDomain<scalar_t> line looks like the removed version and the
// CpuNttDomain<S> line like its replacement; confirm against the merged file.
template <typename S = scalar_t>
eIcicleError cpu_ntt_release_domain(const Device& device, const S& dummy)
{
auto err = ntt_cpu::CpuNttDomain<scalar_t>::cpu_ntt_release_domain(device);
auto err = ntt_cpu::CpuNttDomain<S>::cpu_ntt_release_domain(device);
return err;
}

// Fetches a root of unity from the CPU NTT domain by forwarding (device, logn, rou) to
// ntt_cpu::CpuNttDomain<...>::get_root_of_unity_from_domain and returning its error code unchanged.
// `rou` is an output pointer handed straight to the domain.
// NOTE(review): rendered diff output — the CpuNttDomain<scalar_t> line looks like the removed version and the
// CpuNttDomain<S> line like its replacement; confirm against the merged file.
template <typename S = scalar_t>
eIcicleError cpu_get_root_of_unity_from_domain(const Device& device, uint64_t logn, S* rou)
{
auto err = ntt_cpu::CpuNttDomain<scalar_t>::get_root_of_unity_from_domain(device, logn, rou);
auto err = ntt_cpu::CpuNttDomain<S>::get_root_of_unity_from_domain(device, logn, rou);
return err;
}

Expand All @@ -32,11 +32,18 @@ cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, const N
return err;
}

// Register the CPU NTT entry points for the base scalar type under the "CPU" device name.
// NOTE(review): rendered diff output — the REGISTER_NTT_INIT_DOMAIN_BACKEND line without a template argument looks
// like the removed line and the `<scalar_t>` one like its replacement; confirm against the merged file.
REGISTER_NTT_INIT_DOMAIN_BACKEND("CPU", (cpu_ntt_init_domain));
REGISTER_NTT_INIT_DOMAIN_BACKEND("CPU", (cpu_ntt_init_domain<scalar_t>));
REGISTER_NTT_RELEASE_DOMAIN_BACKEND("CPU", cpu_ntt_release_domain<scalar_t>);
REGISTER_NTT_GET_ROU_FROM_DOMAIN_BACKEND("CPU", cpu_get_root_of_unity_from_domain<scalar_t>);
REGISTER_NTT_BACKEND("CPU", (cpu_ntt<scalar_t, scalar_t>));

// Extension-field NTT (scalar_t domain, extension_t elements) is registered only when EXT_FIELD is defined.
#ifdef EXT_FIELD
REGISTER_NTT_EXT_FIELD_BACKEND("CPU", (cpu_ntt<scalar_t, extension_t>));
#endif // EXT_FIELD

// RNS ring type: when RING is defined, register domain init/release, root-of-unity lookup and the NTT itself,
// all instantiated for scalar_rns_t.
#ifdef RING
REGISTER_NTT_INIT_DOMAIN_RING_RNS_BACKEND("CPU", (cpu_ntt_init_domain<scalar_rns_t>));
REGISTER_NTT_RELEASE_DOMAIN_RING_RNS_BACKEND("CPU", cpu_ntt_release_domain<scalar_rns_t>);
REGISTER_NTT_GET_ROU_FROM_DOMAIN_RING_RNS_BACKEND("CPU", cpu_get_root_of_unity_from_domain<scalar_rns_t>);
REGISTER_NTT_RING_RNS_BACKEND("CPU", (cpu_ntt<scalar_rns_t, scalar_rns_t>));
#endif // RING
64 changes: 54 additions & 10 deletions icicle/backend/cpu/src/field/cpu_vec_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -733,9 +733,6 @@ eIcicleError cpu_matrix_transpose(
}

REGISTER_MATRIX_TRANSPOSE_BACKEND("CPU", cpu_matrix_transpose<scalar_t>);
#ifdef EXT_FIELD
REGISTER_MATRIX_TRANSPOSE_EXT_FIELD_BACKEND("CPU", cpu_matrix_transpose<extension_t>);
#endif // EXT_FIELD

/*********************************** BIT REVERSE ***********************************/
template <typename T>
Expand Down Expand Up @@ -765,9 +762,6 @@ cpu_bit_reverse(const Device& device, const T* vec_in, uint64_t size, const VecO
}

REGISTER_BIT_REVERSE_BACKEND("CPU", cpu_bit_reverse<scalar_t>);
#ifdef EXT_FIELD
REGISTER_BIT_REVERSE_EXT_FIELD_BACKEND("CPU", cpu_bit_reverse<extension_t>);
#endif // EXT_FIELD

/*********************************** SLICE ***********************************/

Expand Down Expand Up @@ -802,9 +796,6 @@ eIcicleError cpu_slice(
}

REGISTER_SLICE_BACKEND("CPU", cpu_slice<scalar_t>);
#ifdef EXT_FIELD
REGISTER_SLICE_EXT_FIELD_BACKEND("CPU", cpu_slice<extension_t>);
#endif // EXT_FIELD

/*********************************** Highest non-zero idx ***********************************/
template <typename T>
Expand Down Expand Up @@ -993,6 +984,9 @@ eIcicleError cpu_poly_divide(
REGISTER_POLYNOMIAL_DIVISION("CPU", cpu_poly_divide<scalar_t>);

#ifdef EXT_FIELD
REGISTER_MATRIX_TRANSPOSE_EXT_FIELD_BACKEND("CPU", cpu_matrix_transpose<extension_t>);
REGISTER_BIT_REVERSE_EXT_FIELD_BACKEND("CPU", cpu_bit_reverse<extension_t>);
REGISTER_SLICE_EXT_FIELD_BACKEND("CPU", cpu_slice<extension_t>);
REGISTER_VECTOR_ADD_EXT_FIELD_BACKEND("CPU", cpu_vector_add<extension_t>);
REGISTER_VECTOR_ACCUMULATE_EXT_FIELD_BACKEND("CPU", cpu_vector_accumulate<extension_t>);
REGISTER_VECTOR_SUB_EXT_FIELD_BACKEND("CPU", cpu_vector_sub<extension_t>);
Expand All @@ -1006,4 +1000,54 @@ REGISTER_SCALAR_MUL_VEC_EXT_FIELD_BACKEND("CPU", cpu_scalar_mul<extension_t>);
REGISTER_SCALAR_ADD_VEC_EXT_FIELD_BACKEND("CPU", cpu_scalar_add<extension_t>);
REGISTER_SCALAR_SUB_VEC_EXT_FIELD_BACKEND("CPU", cpu_scalar_sub<extension_t>);
REGISTER_EXECUTE_PROGRAM_EXT_FIELD_BACKEND("CPU", cpu_execute_program<extension_t>);
#endif // EXT_FIELD
#endif // EXT_FIELD

#ifdef RING
// Register the CPU vector-op implementations for the RNS ring element type (scalar_rns_t) under the "CPU"
// device name. Each line instantiates the corresponding cpu_* function template for scalar_rns_t.
REGISTER_MATRIX_TRANSPOSE_RING_RNS_BACKEND("CPU", cpu_matrix_transpose<scalar_rns_t>);
REGISTER_BIT_REVERSE_RING_RNS_BACKEND("CPU", cpu_bit_reverse<scalar_rns_t>);
REGISTER_SLICE_RING_RNS_BACKEND("CPU", cpu_slice<scalar_rns_t>);
REGISTER_VECTOR_ADD_RING_RNS_BACKEND("CPU", cpu_vector_add<scalar_rns_t>);
REGISTER_VECTOR_ACCUMULATE_RING_RNS_BACKEND("CPU", cpu_vector_accumulate<scalar_rns_t>);
REGISTER_VECTOR_SUB_RING_RNS_BACKEND("CPU", cpu_vector_sub<scalar_rns_t>);
// cpu_vector_mul takes two template arguments; the parentheses keep the comma from splitting the macro argument.
REGISTER_VECTOR_MUL_RING_RNS_BACKEND("CPU", (cpu_vector_mul<scalar_rns_t, scalar_rns_t>));
REGISTER_VECTOR_DIV_RING_RNS_BACKEND("CPU", cpu_vector_div<scalar_rns_t>);
REGISTER_CONVERT_MONTGOMERY_RING_RNS_BACKEND("CPU", cpu_convert_montgomery<scalar_rns_t>);
REGISTER_VECTOR_SUM_RING_RNS_BACKEND("CPU", cpu_vector_sum<scalar_rns_t>);
REGISTER_VECTOR_PRODUCT_RING_RNS_BACKEND("CPU", cpu_vector_product<scalar_rns_t>);
REGISTER_SCALAR_MUL_VEC_RING_RNS_BACKEND("CPU", cpu_scalar_mul<scalar_rns_t>);
REGISTER_SCALAR_ADD_VEC_RING_RNS_BACKEND("CPU", cpu_scalar_add<scalar_rns_t>);
REGISTER_SCALAR_SUB_VEC_RING_RNS_BACKEND("CPU", cpu_scalar_sub<scalar_rns_t>);

/**
 * @brief CPU conversion of ring elements between the direct and the RNS representation.
 *
 * The `size * config.batch_size` elements are cut into contiguous chunks, at most one per worker reported by
 * `get_nof_workers(config)`. Each chunk becomes one Taskflow task that converts its elements one at a time,
 * and the function blocks until every task has finished.
 *
 * @tparam SrcType  Element type read from `input`.
 * @tparam DstType  Element type written to `output`.
 * @tparam into_rns When true, elements go through `DstType::convert_direct_to_rns`; otherwise through
 *                  `SrcType::convert_rns_to_direct`.
 * @param device Not used by this implementation.
 * @param input  Source elements; indices [0, size * config.batch_size) are read.
 * @param size   Number of elements per batch entry.
 * @param config Supplies `batch_size` and is passed to `get_nof_workers`.
 * @param output Destination elements; the same index range as `input` is written.
 * @return eIcicleError::SUCCESS (the only value this function returns).
 */
template <typename SrcType, typename DstType, bool into_rns>
eIcicleError
cpu_convert_rns(const Device& device, const SrcType* input, uint64_t size, const VecOpsConfig& config, DstType* output)
{
  tf::Taskflow taskflow;
  tf::Executor executor;

  const uint64_t nof_elements = size * config.batch_size;
  const int nof_workers = get_nof_workers(config);
  // Ceiling division, so the chunks together cover every element.
  const uint64_t chunk_len = (nof_elements + nof_workers - 1) / nof_workers;

  uint64_t chunk_begin = 0;
  while (chunk_begin < nof_elements) {
    // The last chunk may be shorter than chunk_len.
    const uint64_t chunk_end = std::min(chunk_begin + chunk_len, nof_elements);

    taskflow.emplace([input, output, chunk_begin, chunk_end]() {
      for (uint64_t i = chunk_begin; i < chunk_end; ++i) {
        const auto* src_limbs = &input[i].limbs_storage;
        auto* dst_limbs = &output[i].limbs_storage;
        if constexpr (into_rns) {
          DstType::convert_direct_to_rns(src_limbs, dst_limbs);
        } else {
          SrcType::convert_rns_to_direct(src_limbs, dst_limbs);
        }
      }
    });

    chunk_begin += chunk_len;
  }

  // Run all chunk tasks and wait for them before returning.
  executor.run(taskflow).wait();
  taskflow.clear();
  return eIcicleError::SUCCESS;
}
// Register both conversion directions for the "CPU" device: direct -> RNS (scalar_t in, scalar_rns_t out) and
// RNS -> direct (scalar_rns_t in, scalar_t out). Parentheses protect the commas in the template argument lists.
REGISTER_CONVERT_TO_RNS_BACKEND("CPU", (cpu_convert_rns<scalar_t, scalar_rns_t, true /*into rns*/>));
REGISTER_CONVERT_FROM_RNS_BACKEND("CPU", (cpu_convert_rns<scalar_rns_t, scalar_t, false /*from rns*/>));
#endif // RING
1 change: 1 addition & 0 deletions icicle/cmake/target_editor.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ function(handle_ring TARGET)
target_sources(${TARGET} PRIVATE
src/fields/ffi_extern.cpp
src/vec_ops.cpp
src/rings/rns_vec_ops.cpp
src/matrix_ops.cpp
)
endfunction()
Expand Down
64 changes: 64 additions & 0 deletions icicle/include/icicle/backend/ntt_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,68 @@ namespace icicle {
return true; \
}(); \
}

#ifdef RING
/*************************** NTT ***************************/
// Callable signature a backend must provide to implement NTT on RNS ring elements (scalar_rns_t).
// NOTE(review): `size` is an `int` here, while the CPU function registered through this hook appears to take
// `uint64_t size` — confirm that large sizes cannot be truncated on the way through the dispatcher.
using NttRingRnsImpl = std::function<eIcicleError(
const Device& device,
const scalar_rns_t* input,
int size,
NTTDir dir,
const NTTConfig<scalar_rns_t>& config,
scalar_rns_t* output)>;

// Registration hook called by REGISTER_NTT_RING_RNS_BACKEND; only declared here.
void register_ring_rns_ntt(const std::string& deviceType, NttRingRnsImpl impl);

// Defines a static bool inside an unnamed namespace whose initializer lambda calls
// register_ring_rns_ntt(DEVICE_TYPE, FUNC), so the registration happens when that variable is initialized.
// UNIQUE is defined elsewhere; presumably it yields a distinct identifier per use — TODO confirm.
#define REGISTER_NTT_RING_RNS_BACKEND(DEVICE_TYPE, FUNC) \
namespace { \
static bool UNIQUE(_reg_ntt_ring_rns) = []() -> bool { \
register_ring_rns_ntt(DEVICE_TYPE, FUNC); \
return true; \
}(); \
}

/*************************** INIT DOMAIN ***************************/
// Callable signature a backend must provide to initialize the NTT domain for RNS ring elements, given a
// primitive root of type scalar_rns_t and an NTTInitDomainConfig.
using NttInitDomainRingRnsImpl = std::function<eIcicleError(
const Device& device, const scalar_rns_t& primitive_root, const NTTInitDomainConfig& config)>;

// Registration hook called by REGISTER_NTT_INIT_DOMAIN_RING_RNS_BACKEND; only declared here.
void register_ring_rns_ntt_init_domain(const std::string& deviceType, NttInitDomainRingRnsImpl);

// Defines a static bool inside an unnamed namespace whose initializer lambda calls
// register_ring_rns_ntt_init_domain(DEVICE_TYPE, FUNC), so the registration happens when that variable is
// initialized. UNIQUE is defined elsewhere; presumably it yields a distinct identifier per use — TODO confirm.
#define REGISTER_NTT_INIT_DOMAIN_RING_RNS_BACKEND(DEVICE_TYPE, FUNC) \
namespace { \
static bool UNIQUE(_reg_ntt_init_domain_ring_rns) = []() -> bool { \
register_ring_rns_ntt_init_domain(DEVICE_TYPE, FUNC); \
return true; \
}(); \
}

/*************************** RELEASE DOMAIN ***************************/
// Note: 'phantom' is a workaround: the function is required per field, and the argument is only there to
// differentiate by type when calling.
// Callable signature a backend must provide to release the NTT domain for RNS ring elements.
using NttReleaseDomainRingRnsImpl = std::function<eIcicleError(const Device& device, const scalar_rns_t& phantom)>;

// Registration hook called by REGISTER_NTT_RELEASE_DOMAIN_RING_RNS_BACKEND; only declared here.
void register_ring_rns_ntt_release_domain(const std::string& deviceType, NttReleaseDomainRingRnsImpl);

// Defines a static bool inside an unnamed namespace whose initializer lambda calls
// register_ring_rns_ntt_release_domain(DEVICE_TYPE, FUNC), so the registration happens when that variable is
// initialized. UNIQUE is defined elsewhere; presumably it yields a distinct identifier per use — TODO confirm.
#define REGISTER_NTT_RELEASE_DOMAIN_RING_RNS_BACKEND(DEVICE_TYPE, FUNC) \
namespace { \
static bool UNIQUE(_reg_ntt_release_domain_ring_rns) = []() -> bool { \
register_ring_rns_ntt_release_domain(DEVICE_TYPE, FUNC); \
return true; \
}(); \
}

/*************************** GET ROU FROM DOMAIN ***************************/
// Callable signature a backend must provide to read a root of unity from the RNS NTT domain; the result is
// written through the `rou` pointer and `logn` selects which root is requested.
using NttGetRouFromDomainRingRnsImpl =
std::function<eIcicleError(const Device& device, uint64_t logn, scalar_rns_t* rou)>;

// Registration hook called by REGISTER_NTT_GET_ROU_FROM_DOMAIN_RING_RNS_BACKEND; only declared here.
void register_ring_rns_ntt_get_rou_from_domain(const std::string& deviceType, NttGetRouFromDomainRingRnsImpl);

// Defines a static bool inside an unnamed namespace whose initializer lambda calls
// register_ring_rns_ntt_get_rou_from_domain(DEVICE_TYPE, FUNC), so the registration happens when that variable
// is initialized. UNIQUE is defined elsewhere; presumably it yields a distinct identifier per use — TODO confirm.
#define REGISTER_NTT_GET_ROU_FROM_DOMAIN_RING_RNS_BACKEND(DEVICE_TYPE, FUNC) \
namespace { \
static bool UNIQUE(_reg_ntt_get_rou_from_domain_ring_rns) = []() -> bool { \
register_ring_rns_ntt_get_rou_from_domain(DEVICE_TYPE, FUNC); \
return true; \
}(); \
}
#endif // RING
} // namespace icicle
Loading
Loading