-
Notifications
You must be signed in to change notification settings - Fork 129
/
Copy pathtests.rs
284 lines (245 loc) · 9.29 KB
/
tests.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
use crate::hash::Hasher;
use crate::program::{PreDefinedProgram, ReturningValueProgram};
use crate::sumcheck::{Sumcheck, SumcheckConfig, SumcheckProofOps, SumcheckTranscriptConfig};
use crate::traits::{FieldImpl, GenerateRandom};
use icicle_runtime::memory::{DeviceSlice, DeviceVec, HostSlice};
/// Checks that both `SumcheckTranscriptConfig` constructors store their inputs verbatim.
///
/// `new` is exercised with byte-vector labels and little-endian encoding, while
/// `from_string_labels` is exercised with `&str` labels and big-endian encoding.
/// Both share the same hasher and the same randomly generated seed.
pub fn check_sumcheck_transcript_config<F: FieldImpl>(hash: &Hasher)
where
    <F as FieldImpl>::Config: GenerateRandom<F>,
{
    // One random field element is used as the seed for both configurations.
    let seed = F::Config::generate_random(1)[0];

    // Byte-vector label constructor, little-endian.
    let from_bytes = SumcheckTranscriptConfig::new(
        hash,
        b"DomainLabel".to_vec(),
        b"PolyLabel".to_vec(),
        b"ChallengeLabel".to_vec(),
        true, // little endian
        seed,
    );

    // String label constructor, big-endian.
    let from_strings = SumcheckTranscriptConfig::from_string_labels(
        hash,
        "DomainLabel",
        "PolyLabel",
        "ChallengeLabel",
        false, // big endian
        seed,
    );

    // Every constructor must keep the labels as raw bytes and preserve the
    // endianness flag and seed it was handed.
    for (cfg, expect_little_endian) in [(&from_bytes, true), (&from_strings, false)] {
        assert_eq!(cfg.domain_separator_label, b"DomainLabel");
        assert_eq!(cfg.round_poly_label, b"PolyLabel");
        assert_eq!(cfg.round_challenge_label, b"ChallengeLabel");
        assert_eq!(cfg.little_endian, expect_little_endian);
        assert_eq!(cfg.seed_rng, seed);
    }
}
/// Proves and then verifies a sumcheck over host-resident inputs, using the
/// predefined `EQtimesABminusC` combine function.
///
/// Four random MLE polynomials `(a, b, c, eq)` of size `2^13` are generated.
/// The claimed sum `Σ (a·b − c)·eq` is computed on the host. The resulting proof
/// is reduced to its round polynomials and rebuilt before verification, which
/// mimics a verifier receiving serialized proof data.
pub fn check_sumcheck_simple<SW, P>(hash: &Hasher)
where
    SW: Sumcheck,
    P: ReturningValueProgram,
{
    let log_mle_poly_size = 13u64;
    let mle_poly_size: usize = 1 << log_mle_poly_size;
    let nof_mle_poly: usize = 4;

    // Random seed that initializes the transcript.
    let seed_rng = <<SW as Sumcheck>::FieldConfig>::generate_random(1)[0];
    let transcript_config = SumcheckTranscriptConfig::new(
        hash,
        b"DomainLabel".to_vec(),
        b"PolyLabel".to_vec(),
        b"ChallengeLabel".to_vec(),
        true, // little endian
        seed_rng,
    );

    // Polynomials in the order the combine function consumes them: a, b, c, eq.
    let mle_polys: Vec<_> = (0..nof_mle_poly)
        .map(|_| <<SW as Sumcheck>::FieldConfig>::generate_random(mle_poly_size))
        .collect();

    // Host-side reference value: sum over the hypercube of (a*b - c) * eq.
    let claimed_sum = mle_polys[0]
        .iter()
        .zip(&mle_polys[1])
        .zip(&mle_polys[2])
        .zip(&mle_polys[3])
        .fold(
            <<SW as Sumcheck>::Field as FieldImpl>::zero(),
            |acc, (((&a, &b), &c), &eq)| acc + (a * b - c) * eq,
        );

    /****** Host-input proof ******/
    let host_polys: Vec<&HostSlice<<SW as Sumcheck>::Field>> = mle_polys
        .iter()
        .map(|poly| HostSlice::from_slice(poly))
        .collect();
    let sumcheck = SW::new().unwrap();
    let combine_func = P::new_predefined(PreDefinedProgram::EQtimesABminusC).unwrap();
    let sumcheck_config = SumcheckConfig::default();
    let proof = sumcheck.prove(
        host_polys.as_slice(),
        mle_poly_size as u64,
        claimed_sum,
        combine_func,
        &transcript_config,
        &sumcheck_config,
    );

    /****** Verifier side: rebuild the proof from its round polynomials ******/
    let round_polys =
        <<SW as Sumcheck>::Proof as SumcheckProofOps<<SW as Sumcheck>::Field>>::get_round_polys(&proof).unwrap();
    let received_proof: <SW as Sumcheck>::Proof = <SW as Sumcheck>::Proof::from(round_polys);

    assert!(sumcheck
        .verify(&received_proof, claimed_sum, &transcript_config)
        .unwrap());
}
pub fn check_sumcheck_simple_device<SW, P>(hash: &Hasher)
where
SW: Sumcheck,
P: ReturningValueProgram,
{
let log_mle_poly_size = 13u64;
let mle_poly_size = 1 << log_mle_poly_size;
let nof_mle_poly = 4;
let seed_rng = <<SW as Sumcheck>::FieldConfig>::generate_random(1)[0];
let mut mle_polys = Vec::with_capacity(nof_mle_poly);
for _ in 0..nof_mle_poly {
let mle_poly_random = <<SW as Sumcheck>::FieldConfig>::generate_random(mle_poly_size);
mle_polys.push(mle_poly_random);
}
let mut claimed_sum = <<SW as Sumcheck>::Field as FieldImpl>::zero();
for i in 0..mle_poly_size {
let a = mle_polys[0][i];
let b = mle_polys[1][i];
let c = mle_polys[2][i];
let eq = mle_polys[3][i];
claimed_sum = claimed_sum + (a * b - c) * eq;
}
/****** Begin Device Proof ******/
let config = SumcheckTranscriptConfig::new(
hash,
b"DomainLabel".to_vec(),
b"PolyLabel".to_vec(),
b"ChallengeLabel".to_vec(),
true, // little endian
seed_rng,
);
let mle_poly_hosts = mle_polys
.iter()
.map(|poly| HostSlice::from_slice(poly))
.collect::<Vec<&HostSlice<<SW as Sumcheck>::Field>>>();
let mut device_mle_polys = Vec::with_capacity(nof_mle_poly);
for i in 0..nof_mle_poly {
let mut device_slice = DeviceVec::device_malloc(mle_poly_size).unwrap();
device_slice
.copy_from_host(mle_poly_hosts[i])
.unwrap();
device_mle_polys.push(device_slice);
}
let mle_polys_device: Vec<&DeviceSlice<<SW as Sumcheck>::Field>> = device_mle_polys
.iter()
.map(|s| &s[..])
.collect();
let device_mle_polys_slice = mle_polys_device.as_slice();
let sumcheck = SW::new().unwrap();
let combine_func = P::new_predefined(PreDefinedProgram::EQtimesABminusC).unwrap();
let sumcheck_config = SumcheckConfig::default();
// Generate a proof using the `prove` method.
let proof = sumcheck.prove(
device_mle_polys_slice,
mle_poly_size as u64,
claimed_sum,
combine_func,
&config,
&sumcheck_config,
);
/****** End Device Proof ******/
/****** Obtain Proof Round Polys ******/
let proof_round_polys =
<<SW as Sumcheck>::Proof as SumcheckProofOps<<SW as Sumcheck>::Field>>::get_round_polys(&proof).unwrap();
/********** Verifier deserializes proof data *********/
let proof_as_sumcheck_proof: <SW as Sumcheck>::Proof = <SW as Sumcheck>::Proof::from(proof_round_polys);
// Verify the proof.
let valid = sumcheck
.verify(&proof_as_sumcheck_proof, claimed_sum, &config)
.unwrap();
assert!(valid);
}
/// Proves and then verifies a sumcheck driven by a user-defined combine function.
///
/// The combine program is `a·b + d − c·2` over four random MLE polynomials
/// `(a, b, c, d)` of size `2^13`. The claimed sum of that same expression is
/// computed on the host with matching operand order. The proof is reduced to its
/// round polynomials and rebuilt before verification.
pub fn check_sumcheck_user_defined_combine<SW, P>(hash: &Hasher)
where
    SW: Sumcheck,
    P: ReturningValueProgram,
{
    let log_mle_poly_size = 13u64;
    let mle_poly_size: usize = 1 << log_mle_poly_size;
    let nof_mle_poly: usize = 4;

    // Random seed that initializes the transcript.
    let seed_rng = <<SW as Sumcheck>::FieldConfig>::generate_random(1)[0];
    let config = SumcheckTranscriptConfig::new(
        hash,
        b"DomainLabel".to_vec(),
        b"PolyLabel".to_vec(),
        b"ChallengeLabel".to_vec(),
        true, // little endian
        seed_rng,
    );

    // Polynomials in the order the combine function consumes them: a, b, c, d.
    let mle_polys: Vec<_> = (0..nof_mle_poly)
        .map(|_| <<SW as Sumcheck>::FieldConfig>::generate_random(mle_poly_size))
        .collect();

    // Build the constant once instead of once per hypercube point.
    let two = SW::Field::from_u32(2);
    // Host-side reference value, written in the same operand order as `user_combine`.
    let claimed_sum = mle_polys[0]
        .iter()
        .zip(&mle_polys[1])
        .zip(&mle_polys[2])
        .zip(&mle_polys[3])
        .fold(
            <<SW as Sumcheck>::Field as FieldImpl>::zero(),
            |acc, (((&a, &b), &c), &d)| acc + (a * b + d - c * two),
        );

    // The closure captures nothing and builds its constant inline.
    // NOTE(review): kept non-capturing in case `P::new` only accepts a plain
    // function pointer — confirm against `ReturningValueProgram::new`.
    let user_combine = |vars: &mut Vec<P::ProgSymbol>| -> P::ProgSymbol {
        // Shallow copies pointing to the same memory in the backend
        let a = vars[0];
        let b = vars[1];
        let c = vars[2];
        let d = vars[3];
        a * b + d - c * P::Field::from_u32(2)
    };

    /****** Begin CPU Proof ******/
    let mle_poly_hosts = mle_polys
        .iter()
        .map(|poly| HostSlice::from_slice(poly))
        .collect::<Vec<&HostSlice<<SW as Sumcheck>::Field>>>();
    let sumcheck = SW::new().unwrap();
    let combine_func = P::new(user_combine, /* nof_parameters = */ 4).unwrap();
    let sumcheck_config = SumcheckConfig::default();
    let proof = sumcheck.prove(
        mle_poly_hosts.as_slice(),
        mle_poly_size as u64,
        claimed_sum,
        combine_func,
        &config,
        &sumcheck_config,
    );
    /****** End CPU Proof ******/

    /****** Verifier side: rebuild the proof from its round polynomials ******/
    let proof_round_polys =
        <<SW as Sumcheck>::Proof as SumcheckProofOps<<SW as Sumcheck>::Field>>::get_round_polys(&proof).unwrap();
    let proof_as_sumcheck_proof: <SW as Sumcheck>::Proof = <SW as Sumcheck>::Proof::from(proof_round_polys);

    let valid = sumcheck
        .verify(&proof_as_sumcheck_proof, claimed_sum, &config)
        .unwrap();
    assert!(valid);
}