Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom combine functions for sumcheck (ReturningValueProgram) in rust #798

Merged
merged 2 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion icicle/include/icicle/program/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ namespace icicle {
}

// run over the DFG held by program_parameters and generate the program
void generate_program(std::vector<Symbol<S>>& program_parameters)
virtual void generate_program(std::vector<Symbol<S>>& program_parameters)
{
// run over the graph and allocate location for all constants
Operation<S>::reset_visit();
Expand Down
8 changes: 7 additions & 1 deletion icicle/include/icicle/program/returning_value_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ namespace icicle {
this->set_as_inputs(program_parameters);
program_parameters[nof_inputs] = program_func(program_parameters); // place the output after the all inputs
this->generate_program(program_parameters);
m_poly_degree = program_parameters[nof_inputs].m_operation->m_poly_degree;
}

// Generate a program based on a PreDefinedPrograms
Expand All @@ -42,6 +41,13 @@ namespace icicle {
}
}

// Call base generate_program as well as updating the required polynomial degree
void generate_program(std::vector<Symbol<S>>& program_parameters) override
{
Program<S>::generate_program(program_parameters);
m_poly_degree = program_parameters.back().m_operation->m_poly_degree;
}

int get_polynomial_degree() const { return m_poly_degree; }

protected:
Expand Down
61 changes: 26 additions & 35 deletions icicle/src/program/program_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,24 @@ typedef Symbol<scalar_t>* SymbolHandle;
typedef Program<scalar_t>* ProgramHandle;
typedef ReturningValueProgram<scalar_t>* ReturningValueProgramHandle;

template <typename S>
eIcicleError ffi_generate_program(Program<S>* program, Symbol<S>** parameters_ptr, int nof_parameters)
{
if (program == nullptr) { return eIcicleError::ALLOCATION_FAILED; }

std::vector<Symbol<S>> parameters_vec;
parameters_vec.reserve(nof_parameters);
for (int i = 0; i < nof_parameters; i++) {
if (parameters_ptr[i] == nullptr) { return eIcicleError::INVALID_ARGUMENT; }
parameters_vec.push_back(*parameters_ptr[i]);
}
program->m_nof_parameters = nof_parameters;
program->generate_program(parameters_vec);

ReleasePool<Symbol<S>>::instance().clear();
return eIcicleError::SUCCESS;
}

extern "C" {
// Program functions
ProgramHandle CONCAT_EXPAND(FIELD, create_predefined_program)(PreDefinedPrograms pre_def)
Expand All @@ -31,21 +49,7 @@ eIcicleError
CONCAT_EXPAND(FIELD, generate_program)(SymbolHandle* parameters_ptr, int nof_parameters, ProgramHandle* program)
{
*program = create_empty_program<scalar_t>();
std::vector<Symbol<scalar_t>> parameters_vec;
parameters_vec.reserve(nof_parameters);

for (int i = 0; i < nof_parameters; i++) {
if (parameters_ptr[i] == nullptr) { return eIcicleError::INVALID_ARGUMENT; }
parameters_vec.push_back(*parameters_ptr[i]);
}

(*program)->m_nof_parameters = nof_parameters;
(*program)->generate_program(parameters_vec);

ReleasePool<Symbol<scalar_t>>& pool = ReleasePool<Symbol<scalar_t>>::instance();
pool.clear();

return eIcicleError::SUCCESS;
return ffi_generate_program(*program, parameters_ptr, nof_parameters);
}

ReturningValueProgramHandle CONCAT_EXPAND(FIELD, create_predefined_returning_value_program)(PreDefinedPrograms pre_def)
Expand All @@ -54,10 +58,10 @@ ReturningValueProgramHandle CONCAT_EXPAND(FIELD, create_predefined_returning_val
}

eIcicleError CONCAT_EXPAND(FIELD, generate_returning_value_program)(
SymbolHandle* parameters_ptr, int nof_parameters, ReturningValueProgramHandle* returning_program)
SymbolHandle* parameters_ptr, int nof_parameters, ReturningValueProgramHandle* program)
{
ProgramHandle program = *returning_program;
return CONCAT_EXPAND(FIELD, generate_program)(parameters_ptr, nof_parameters, &program);
*program = create_empty_returning_value_program<scalar_t>();
return ffi_generate_program(*program, parameters_ptr, nof_parameters);
}
}

Expand All @@ -77,20 +81,7 @@ eIcicleError CONCAT_EXPAND(FIELD, extension_generate_program)(
ExtensionSymbolHandle* parameters_ptr, int nof_parameters, ExtensionProgramHandle* program)
{
*program = create_empty_program<extension_t>();
std::vector<Symbol<extension_t>> parameters_vec;
parameters_vec.reserve(nof_parameters);

for (int i = 0; i < nof_parameters; i++) {
if (parameters_ptr[i] == nullptr) { return eIcicleError::INVALID_ARGUMENT; }
parameters_vec.push_back(*parameters_ptr[i]);
}
(*program)->m_nof_parameters = nof_parameters;
(*program)->generate_program(parameters_vec);

ReleasePool<Symbol<extension_t>>& pool = ReleasePool<Symbol<extension_t>>::instance();
pool.clear();

return eIcicleError::SUCCESS;
return ffi_generate_program(*program, parameters_ptr, nof_parameters);
}

ExtensionReturningValueProgramHandle
Expand All @@ -100,10 +91,10 @@ CONCAT_EXPAND(FIELD, extension_create_predefined_returning_value_program)(PreDef
}

eIcicleError CONCAT_EXPAND(FIELD, extension_generate_returning_value_program)(
ExtensionSymbolHandle* parameters_ptr, int nof_parameters, ExtensionReturningValueProgramHandle* returning_program)
ExtensionSymbolHandle* parameters_ptr, int nof_parameters, ExtensionReturningValueProgramHandle* program)
{
ExtensionProgramHandle program = *returning_program;
return CONCAT_EXPAND(FIELD, extension_generate_program)(parameters_ptr, nof_parameters, &program);
*program = create_empty_returning_value_program<extension_t>();
return ffi_generate_program(*program, parameters_ptr, nof_parameters);
}
}
#endif // EXT_FIELD
42 changes: 14 additions & 28 deletions icicle/src/symbol/symbol_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,58 +15,51 @@ SymbolHandle CONCAT_EXPAND(FIELD, create_input_symbol)(int in_idx)
{
auto symbol_ptr = new Symbol<scalar_t>();
symbol_ptr->set_as_input(in_idx);
ReleasePool<Symbol<scalar_t>>& pool = ReleasePool<Symbol<scalar_t>>::instance();
pool.insert(symbol_ptr);
ReleasePool<Symbol<scalar_t>>::instance().insert(symbol_ptr);
return symbol_ptr;
}
SymbolHandle CONCAT_EXPAND(FIELD, create_scalar_symbol)(const scalar_t* constant)
{
auto symbol_ptr = new Symbol<scalar_t>(*constant);
ReleasePool<Symbol<scalar_t>>& pool = ReleasePool<Symbol<scalar_t>>::instance();
pool.insert(symbol_ptr);
ReleasePool<Symbol<scalar_t>>::instance().insert(symbol_ptr);
return symbol_ptr;
}
SymbolHandle CONCAT_EXPAND(FIELD, copy_symbol)(const SymbolHandle other)
{
auto symbol_ptr = new Symbol<scalar_t>(*other);
ReleasePool<Symbol<scalar_t>>& pool = ReleasePool<Symbol<scalar_t>>::instance();
pool.insert(symbol_ptr);
ReleasePool<Symbol<scalar_t>>::instance().insert(symbol_ptr);
return symbol_ptr;
}

eIcicleError CONCAT_EXPAND(FIELD, inverse_symbol)(const SymbolHandle input, SymbolHandle* output)
{
if (!input) { return eIcicleError::INVALID_POINTER; }
*output = new Symbol<scalar_t>(input->inverse());
ReleasePool<Symbol<scalar_t>>& pool = ReleasePool<Symbol<scalar_t>>::instance();
pool.insert(*output);
ReleasePool<Symbol<scalar_t>>::instance().insert(*output);
return eIcicleError::SUCCESS;
}

eIcicleError CONCAT_EXPAND(FIELD, add_symbols)(const SymbolHandle op_a, const SymbolHandle op_b, SymbolHandle* res)
{
if (!op_a || !op_b) { return eIcicleError::INVALID_ARGUMENT; }
*res = new Symbol<scalar_t>(op_a->add(*op_b));
ReleasePool<Symbol<scalar_t>>& pool = ReleasePool<Symbol<scalar_t>>::instance();
pool.insert(*res);
ReleasePool<Symbol<scalar_t>>::instance().insert(*res);
return eIcicleError::SUCCESS;
}

eIcicleError CONCAT_EXPAND(FIELD, multiply_symbols)(const SymbolHandle op_a, const SymbolHandle op_b, SymbolHandle* res)
{
if (!op_a || !op_b) { return eIcicleError::INVALID_ARGUMENT; }
*res = new Symbol<scalar_t>(op_a->multiply(*op_b));
ReleasePool<Symbol<scalar_t>>& pool = ReleasePool<Symbol<scalar_t>>::instance();
pool.insert(*res);
ReleasePool<Symbol<scalar_t>>::instance().insert(*res);
return eIcicleError::SUCCESS;
}

eIcicleError CONCAT_EXPAND(FIELD, sub_symbols)(const SymbolHandle op_a, const SymbolHandle op_b, SymbolHandle* res)
{
if (!op_a || !op_b) { return eIcicleError::INVALID_ARGUMENT; }
*res = new Symbol<scalar_t>(op_a->sub(*op_b));
ReleasePool<Symbol<scalar_t>>& pool = ReleasePool<Symbol<scalar_t>>::instance();
pool.insert(*res);
ReleasePool<Symbol<scalar_t>>::instance().insert(*res);
return eIcicleError::SUCCESS;
}
}
Expand All @@ -81,22 +74,19 @@ ExtensionSymbolHandle CONCAT_EXPAND(FIELD, extension_create_input_symbol)(int in
{
auto symbol_ptr = new Symbol<extension_t>();
symbol_ptr->set_as_input(in_idx);
ReleasePool<Symbol<extension_t>>& pool = ReleasePool<Symbol<extension_t>>::instance();
pool.insert(symbol_ptr);
ReleasePool<Symbol<extension_t>>::instance().insert(symbol_ptr);
return symbol_ptr;
}
ExtensionSymbolHandle CONCAT_EXPAND(FIELD, extension_create_scalar_symbol)(const extension_t* constant)
{
auto symbol_ptr = new Symbol<extension_t>(*constant);
ReleasePool<Symbol<extension_t>>& pool = ReleasePool<Symbol<extension_t>>::instance();
pool.insert(symbol_ptr);
ReleasePool<Symbol<extension_t>>::instance().insert(symbol_ptr);
return symbol_ptr;
}
ExtensionSymbolHandle CONCAT_EXPAND(FIELD, extension_copy_symbol)(const ExtensionSymbolHandle other)
{
auto symbol_ptr = new Symbol<extension_t>(*other);
ReleasePool<Symbol<extension_t>>& pool = ReleasePool<Symbol<extension_t>>::instance();
pool.insert(symbol_ptr);
ReleasePool<Symbol<extension_t>>::instance().insert(symbol_ptr);
return symbol_ptr;
}

Expand All @@ -105,8 +95,7 @@ CONCAT_EXPAND(FIELD, extension_inverse_symbol)(const ExtensionSymbolHandle input
{
if (!input) { return eIcicleError::INVALID_POINTER; }
*output = new Symbol<extension_t>(input->inverse());
ReleasePool<Symbol<extension_t>>& pool = ReleasePool<Symbol<extension_t>>::instance();
pool.insert(*output);
ReleasePool<Symbol<extension_t>>::instance().insert(*output);
return eIcicleError::SUCCESS;
}

Expand All @@ -115,8 +104,7 @@ eIcicleError CONCAT_EXPAND(FIELD, extension_add_symbols)(
{
if (!op_a || !op_b) { return eIcicleError::INVALID_ARGUMENT; }
*res = new Symbol<extension_t>(op_a->add(*op_b));
ReleasePool<Symbol<extension_t>>& pool = ReleasePool<Symbol<extension_t>>::instance();
pool.insert(*res);
ReleasePool<Symbol<extension_t>>::instance().insert(*res);
return eIcicleError::SUCCESS;
}

Expand All @@ -125,8 +113,7 @@ eIcicleError CONCAT_EXPAND(FIELD, extension_multiply_symbols)(
{
if (!op_a || !op_b) { return eIcicleError::INVALID_ARGUMENT; }
*res = new Symbol<extension_t>(op_a->multiply(*op_b));
ReleasePool<Symbol<extension_t>>& pool = ReleasePool<Symbol<extension_t>>::instance();
pool.insert(*res);
ReleasePool<Symbol<extension_t>>::instance().insert(*res);
return eIcicleError::SUCCESS;
}

Expand All @@ -135,8 +122,7 @@ eIcicleError CONCAT_EXPAND(FIELD, extension_sub_symbols)(
{
if (!op_a || !op_b) { return eIcicleError::INVALID_ARGUMENT; }
*res = new Symbol<extension_t>(op_a->sub(*op_b));
ReleasePool<Symbol<extension_t>>& pool = ReleasePool<Symbol<extension_t>>::instance();
pool.insert(*res);
ReleasePool<Symbol<extension_t>>::instance().insert(*res);
return eIcicleError::SUCCESS;
}
}
Expand Down
5 changes: 4 additions & 1 deletion wrappers/rust/icicle-core/src/program/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ macro_rules! impl_program_field {
#[link_name = concat!($field_prefix, "_generate_program")]
pub(crate) fn ffi_generate_program(parameters_ptr: *const SymbolHandle, nof_parameter: u32, program: *mut ProgramHandle) -> eIcicleError;

#[link_name = concat!($field_prefix, "_generate_returning_value_program")]
pub(crate) fn ffi_generate_returning_value_program(parameters_ptr: *const SymbolHandle, nof_parameter: u32, program: *mut ProgramHandle) -> eIcicleError;

#[link_name = "delete_program"]
pub(crate) fn ffi_delete_program(program: ProgramHandle) -> eIcicleError;
}
Expand Down Expand Up @@ -155,7 +158,7 @@ macro_rules! impl_program_field {
let mut prog_handle = std::ptr::null();
let ffi_status;
unsafe {
ffi_status = ffi_generate_program(handles.as_ptr(), program_parameters.len() as u32, &mut prog_handle);
ffi_status = ffi_generate_returning_value_program(handles.as_ptr(), program_parameters.len() as u32, &mut prog_handle);
}
if ffi_status != eIcicleError::Success {
Err(ffi_status)
Expand Down
12 changes: 12 additions & 0 deletions wrappers/rust/icicle-core/src/sumcheck/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,5 +462,17 @@ macro_rules! impl_sumcheck_tests {
let device_hash = Keccak256::new(0).unwrap();
check_sumcheck_simple_device::<SumcheckWrapper, Program>(&device_hash);
}

#[test]
fn test_sumcheck_user_defined_combine() {
initialize();
test_utilities::test_set_ref_device();
let hash = Keccak256::new(0).unwrap();
check_sumcheck_user_defined_combine::<SumcheckWrapper, Program>(&hash);

test_utilities::test_set_main_device();
let device_hash = Keccak256::new(0).unwrap();
check_sumcheck_user_defined_combine::<SumcheckWrapper, Program>(&device_hash);
}
};
}
79 changes: 79 additions & 0 deletions wrappers/rust/icicle-core/src/sumcheck/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,82 @@ where

assert!(valid);
}

pub fn check_sumcheck_user_defined_combine<SW, P>(hash: &Hasher)
where
SW: Sumcheck,
P: ReturningValueProgram,
{
let log_mle_poly_size = 13u64;
let mle_poly_size = 1 << log_mle_poly_size;
let nof_mle_poly = 4;
// Generate a random seed for the test.
let seed_rng = <<SW as Sumcheck>::FieldConfig>::generate_random(1)[0];

// Create a transcript configuration.
let config = SumcheckTranscriptConfig::new(
hash,
b"DomainLabel".to_vec(),
b"PolyLabel".to_vec(),
b"ChallengeLabel".to_vec(),
true, // little endian
seed_rng,
);

let mut mle_polys = Vec::with_capacity(nof_mle_poly);
for _ in 0..nof_mle_poly {
let mle_poly_random = <<SW as Sumcheck>::FieldConfig>::generate_random(mle_poly_size);
mle_polys.push(mle_poly_random);
}

let mut claimed_sum = <<SW as Sumcheck>::Field as FieldImpl>::zero();
for i in 0..mle_poly_size {
let a = mle_polys[0][i];
let b = mle_polys[1][i];
let c = mle_polys[2][i];
let d = mle_polys[3][i];
claimed_sum = claimed_sum + a * b - c * SW::Field::from_u32(2) + d;
}

let user_combine = |vars: &mut Vec<P::ProgSymbol>| -> P::ProgSymbol {
let a = vars[0]; // Shallow copies pointing to the same memory in the backend
let b = vars[1];
let c = vars[2];
let d = vars[3];
return a * b + d - c * P::Field::from_u32(2);
};

/****** Begin CPU Proof ******/
let mle_poly_hosts = mle_polys
.iter()
.map(|poly| HostSlice::from_slice(poly))
.collect::<Vec<&HostSlice<<SW as Sumcheck>::Field>>>();

let sumcheck = SW::new().unwrap();
let combine_func = P::new(user_combine, /* nof_parameters = */ 4).unwrap();
let sumcheck_config = SumcheckConfig::default();
// Generate a proof using the `prove` method.
let proof = sumcheck.prove(
mle_poly_hosts.as_slice(),
mle_poly_size as u64,
claimed_sum,
combine_func,
&config,
&sumcheck_config,
);
/****** End CPU Proof ******/

/****** Obtain Proof Round Polys ******/
let proof_round_polys =
<<SW as Sumcheck>::Proof as SumcheckProofOps<<SW as Sumcheck>::Field>>::get_round_polys(&proof).unwrap();

/********** Verifier deserializes proof data *********/
let proof_as_sumcheck_proof: <SW as Sumcheck>::Proof = <SW as Sumcheck>::Proof::from(proof_round_polys);

// Verify the proof.
let valid = sumcheck
.verify(&proof_as_sumcheck_proof, claimed_sum, &config)
.unwrap();

assert!(valid);
}
Loading