Commit 3c815c4
Fix vecops's Rust wrapper for batch>1 (#741)
This PR fixes the vecops Rust wrapper to work with batch > 1 (which affected most vector operations besides vector-vector operations). Assertions were added when calling vecops functions in Rust to prevent invalid configurations, and a minor change was made to the CLI workflow to avoid failures when loading the CUDA backend. Signed-off-by: Koren-Brand <[email protected]>
1 parent 2c45c32 commit 3c815c4

File tree

4 files changed: +173 -49 lines changed

docs/docs/icicle/rust-bindings/vec-ops.md (+1 -1)

@@ -30,7 +30,7 @@ pub struct VecOpsConfig {
 - **`is_b_on_device: bool`**: Indicates whether the input b data has been preloaded on the device memory. If `false` inputs will be copied from host to device.
 - **`is_result_on_device: bool`**: Indicates whether the output data is preloaded in device memory. If `false` outputs will be copied from host to device.
 - **`is_async: bool`**: Specifies whether the NTT operation should be performed asynchronously.
-- **`batch_size: usize`**: Number of vector operations to process in a single batch. Each operation will be performed independently on each batch element.
+- **`batch_size: usize`**: Number of vector operations to process in a single batch. Each operation will be performed independently on each batch element. It is implicitly determined given the inputs and outputs to the vector operation.
 - **`columns_batch: bool`**: true if the batched vectors are stored as columns in a 2D array (i.e., the vectors are strided in memory as columns of a matrix). If false, the batched vectors are stored contiguously in memory (e.g., as rows or in a flat array).

 - **`ext: ConfigExtension`**: extended configuration for backend.
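
To make the two batching layouts described above concrete, here is a minimal, self-contained Rust sketch (values and helpers are illustrative, not part of the ICICLE API): with `columns_batch = false` the batched vectors sit one after another in a flat buffer, and with `columns_batch = true` element `i` of batch `j` lives at index `i * batch_size + j`.

```rust
// Sketch of the two batched-vector memory layouts (illustrative only).
fn main() {
    let batch_size = 3; // three vectors
    let len = 4;        // each of length four

    // columns_batch = false: vectors are contiguous; element i of batch j is at j * len + i.
    let rows: Vec<u32> = (0..batch_size)
        .flat_map(|j| (0..len).map(move |i| (100 * j + i) as u32))
        .collect();

    // columns_batch = true: vectors are strided as matrix columns;
    // element i of batch j is at i * batch_size + j.
    let mut cols = vec![0u32; batch_size * len];
    for j in 0..batch_size {
        for i in 0..len {
            cols[i * batch_size + j] = rows[j * len + i];
        }
    }

    // Same logical vectors, different strides: batch 1, element 2 in both layouts.
    assert_eq!(rows[1 * len + 2], cols[2 * batch_size + 1]);
}
```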

icicle/tests/test_field_api.cpp (+1 -1)

@@ -596,7 +596,7 @@ TYPED_TEST(FieldApiTest, Slice)
   const uint64_t size_in = 1 << rand_uint_32b(4, 17);
   const uint64_t offset = rand_uint_32b(0, 14);
   const uint64_t stride = rand_uint_32b(1, 4);
-  const uint64_t size_out = rand_uint_32b(0, (size_in - offset) / stride);
+  const uint64_t size_out = rand_uint_32b(1, std::max<uint64_t>((size_in - offset) / stride, 1));
   const int batch_size = 1 << rand_uint_32b(0, 4);
   const bool columns_batch = rand_uint_32b(0, 1);
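
The test now draws `size_out` from a range starting at 1, which matches the slice validation added to the Rust wrapper below: it rejects out-of-range slices via `offset + (size_out - 1) * stride >= size_in`, and with unsigned arithmetic `size_out - 1` underflows when `size_out` is 0. A small sketch of that bound (values are arbitrary):

```rust
// Sketch of the slice bounds check used by the Rust wrapper below; shows why
// size_out must be at least 1 (size_out - 1 would underflow for size_out == 0).
fn slice_in_bounds(offset: u64, stride: u64, size_in: u64, size_out: u64) -> bool {
    assert!(size_out >= 1, "size_out must be >= 1");
    offset + (size_out - 1) * stride < size_in
}

fn main() {
    assert!(slice_in_bounds(2, 3, 16, 4));  // last read index: 2 + 3*3 = 11 < 16
    assert!(!slice_in_bounds(2, 3, 10, 4)); // last read index 11 >= 10: out of range
}
```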

wrappers/rust/icicle-core/src/vec_ops/mod.rs (+128 -13)

@@ -142,7 +142,7 @@ pub trait MixedVecOps<F, T> {
     ) -> Result<(), eIcicleError>;
 }
 
-fn check_vec_ops_args<'a, F, T>(
+fn check_vec_ops_args<F, T>(
     a: &(impl HostOrDeviceSlice<F> + ?Sized),
     b: &(impl HostOrDeviceSlice<T> + ?Sized),
     result: &(impl HostOrDeviceSlice<F> + ?Sized),
@@ -156,7 +156,119 @@ fn check_vec_ops_args<'a, F, T>(
             result.len()
         );
     }
+    setup_config(a, b, result, cfg, 1 /* Placeholder, no need for batch_size in this operation */)
+}
 
+fn check_vec_ops_args_scalar_ops<F, T>(
+    a: &(impl HostOrDeviceSlice<F> + ?Sized),
+    b: &(impl HostOrDeviceSlice<T> + ?Sized),
+    result: &(impl HostOrDeviceSlice<F> + ?Sized),
+    cfg: &VecOpsConfig,
+) -> VecOpsConfig {
+    if b.len() != result.len() {
+        panic!(
+            "b.len() and result.len() do not match {} != {}",
+            b.len(),
+            result.len()
+        );
+    }
+    if b.len() % a.len() != 0 {
+        panic!(
+            "b.len() and a.len() do not match {} % {} != 0",
+            b.len(),
+            a.len(),
+        );
+    }
+    let batch_size = a.len();
+    setup_config(a, b, result, cfg, batch_size)
+}
+
+fn check_vec_ops_args_reduction_ops<F>(
+    input: &(impl HostOrDeviceSlice<F> + ?Sized),
+    result: &(impl HostOrDeviceSlice<F> + ?Sized),
+    cfg: &VecOpsConfig,
+) -> VecOpsConfig {
+    if input.len() % result.len() != 0 {
+        panic!(
+            "input length and result length do not match {} % {} != 0",
+            input.len(),
+            result.len(),
+        );
+    }
+    let batch_size = result.len();
+    setup_config(input, input, result, cfg, batch_size)
+}
+
+fn check_vec_ops_args_transpose<F>(
+    input: &(impl HostOrDeviceSlice<F> + ?Sized),
+    nof_rows: u32,
+    nof_cols: u32,
+    output: &(impl HostOrDeviceSlice<F> + ?Sized),
+    cfg: &VecOpsConfig,
+) -> VecOpsConfig {
+    if input.len() != output.len() {
+        panic!(
+            "Input size and output size do not match {} != {}",
+            input.len(),
+            output.len()
+        );
+    }
+    if input.len() as u32 % (nof_rows * nof_cols) != 0 {
+        panic!(
+            "Input size is not a whole multiple of matrix size (#rows * #cols), {} % ({} * {}) != 0",
+            input.len(),
+            nof_rows,
+            nof_cols,
+        );
+    }
+    let batch_size = input.len() / (nof_rows * nof_cols) as usize;
+    setup_config(input, input, output, cfg, batch_size)
+}
+
+fn check_vec_ops_args_slice<F>(
+    input: &(impl HostOrDeviceSlice<F> + ?Sized),
+    offset: u64,
+    stride: u64,
+    size_in: u64,
+    size_out: u64,
+    output: &(impl HostOrDeviceSlice<F> + ?Sized),
+    cfg: &VecOpsConfig,
+) -> VecOpsConfig {
+    if input.len() as u64 % size_in != 0 {
+        panic!(
+            "size_in does not divide input size {} % {} != 0",
+            input.len(),
+            size_in,
+        );
+    }
+    if output.len() as u64 % size_out != 0 {
+        panic!(
+            "size_out does not divide output size {} % {} != 0",
+            output.len(),
+            size_out,
+        );
+    }
+    if offset + (size_out - 1) * stride >= size_in {
+        panic!(
+            "Slice exceeds input size: offset + (size_out - 1) * stride >= size_in where offset={}, size_out={}, stride={}, size_in={}",
+            offset,
+            size_out,
+            stride,
+            size_in,
+        );
+    }
+    let batch_size = output.len() / size_out as usize;
+    setup_config(input, input, output, cfg, batch_size)
+}
+
+/// Modify VecOpsConfig according to the given vectors
+fn setup_config<F, T>(
+    a: &(impl HostOrDeviceSlice<F> + ?Sized),
+    b: &(impl HostOrDeviceSlice<T> + ?Sized),
+    result: &(impl HostOrDeviceSlice<F> + ?Sized),
+    cfg: &VecOpsConfig,
+    batch_size: usize,
+) -> VecOpsConfig {
     // check device slices are on active device
     if a.is_on_device() && !a.is_on_active_device() {
         panic!("input a is allocated on an inactive device");
@@ -169,6 +281,7 @@ fn check_vec_ops_args<'a, F, T>(
     }
 
     let mut res_cfg = cfg.clone();
+    res_cfg.batch_size = batch_size as i32;
     res_cfg.is_a_on_device = a.is_on_device();
     res_cfg.is_b_on_device = b.is_on_device();
     res_cfg.is_result_on_device = result.is_on_device();
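
In short, the batch size is now derived from the slice lengths rather than taken from the caller's `VecOpsConfig`: scalar ops use `a.len()` (one scalar per batch element), reductions use `result.len()`, transpose uses `input.len() / (nof_rows * nof_cols)`, and slice uses `output.len() / size_out`. A standalone sketch of these inference rules (plain lengths stand in for `HostOrDeviceSlice`; this is illustrative, not the crate's API):

```rust
// Illustrative batch-size inference, mirroring the helpers above.
fn scalar_ops_batch(a_len: usize, b_len: usize) -> usize {
    assert_eq!(b_len % a_len, 0, "each scalar must map to a whole vector");
    a_len // one scalar per batch element
}

fn reduction_batch(input_len: usize, result_len: usize) -> usize {
    assert_eq!(input_len % result_len, 0, "input must split evenly into batches");
    result_len // one reduced value per batch element
}

fn transpose_batch(input_len: usize, nof_rows: usize, nof_cols: usize) -> usize {
    assert_eq!(input_len % (nof_rows * nof_cols), 0, "input must hold whole matrices");
    input_len / (nof_rows * nof_cols)
}

fn main() {
    assert_eq!(scalar_ops_batch(3, 3 * 8), 3);       // 3 scalars, 3 vectors of length 8
    assert_eq!(reduction_batch(4 * 16, 4), 4);       // 4 vectors of length 16, 4 sums
    assert_eq!(transpose_batch(2 * 3 * 5, 3, 5), 2); // 2 matrices of size 3x5
}
```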
@@ -267,7 +380,7 @@ where
     F: FieldImpl,
     <F as FieldImpl>::Config: VecOps<F>,
 {
-    let cfg = check_vec_ops_args(a, a, result, cfg);
+    let cfg = check_vec_ops_args_reduction_ops(a, result, cfg);
     <<F as FieldImpl>::Config as VecOps<F>>::sum(a, result, &cfg)
 }
 
@@ -280,7 +393,7 @@ where
     F: FieldImpl,
     <F as FieldImpl>::Config: VecOps<F>,
 {
-    let cfg = check_vec_ops_args(a, a, result, cfg);
+    let cfg = check_vec_ops_args_reduction_ops(a, result, cfg);
     <<F as FieldImpl>::Config as VecOps<F>>::product(a, result, &cfg)
 }
 
@@ -294,7 +407,7 @@ where
     F: FieldImpl,
     <F as FieldImpl>::Config: VecOps<F>,
 {
-    let cfg = check_vec_ops_args(b, b, result, cfg);
+    let cfg = check_vec_ops_args_scalar_ops(a, b, result, cfg);
     <<F as FieldImpl>::Config as VecOps<F>>::scalar_add(a, b, result, &cfg)
 }
 
@@ -308,7 +421,7 @@ where
     F: FieldImpl,
     <F as FieldImpl>::Config: VecOps<F>,
 {
-    let cfg = check_vec_ops_args(b, b, result, cfg);
+    let cfg = check_vec_ops_args_scalar_ops(a, b, result, cfg);
     <<F as FieldImpl>::Config as VecOps<F>>::scalar_sub(a, b, result, &cfg)
 }
 
@@ -322,7 +435,7 @@ where
     F: FieldImpl,
     <F as FieldImpl>::Config: VecOps<F>,
 {
-    let cfg = check_vec_ops_args(b, b, result, cfg);
+    let cfg = check_vec_ops_args_scalar_ops(a, b, result, cfg);
     <<F as FieldImpl>::Config as VecOps<F>>::scalar_mul(a, b, result, &cfg)
 }
 
@@ -337,6 +450,7 @@ where
     F: FieldImpl,
     <F as FieldImpl>::Config: VecOps<F>,
 {
+    let cfg = check_vec_ops_args_transpose(input, nof_rows, nof_cols, output, cfg);
     <<F as FieldImpl>::Config as VecOps<F>>::transpose(input, nof_rows, nof_cols, output, &cfg)
 }
 
@@ -378,6 +492,7 @@ where
     F: FieldImpl,
     <F as FieldImpl>::Config: VecOps<F>,
 {
+    let cfg = check_vec_ops_args_slice(input, offset, stride, size_in, size_out, output, cfg);
     <<F as FieldImpl>::Config as VecOps<F>>::slice(input, offset, stride, size_in, size_out, &cfg, output)
 }
 
@@ -610,7 +725,7 @@ macro_rules! impl_vec_ops_field {
             unsafe {
                 $field_prefix_ident::vector_sum_ffi(
                     a.as_ptr(),
-                    a.len() as u32,
+                    a.len() as u32 / cfg.batch_size as u32,
                     cfg as *const VecOpsConfig,
                     result.as_mut_ptr(),
                 )
@@ -626,7 +741,7 @@ macro_rules! impl_vec_ops_field {
             unsafe {
                 $field_prefix_ident::vector_sum_ffi(
                     a.as_ptr(),
-                    a.len() as u32,
+                    a.len() as u32 / cfg.batch_size as u32,
                     cfg as *const VecOpsConfig,
                     result.as_mut_ptr(),
                 )
@@ -644,7 +759,7 @@ macro_rules! impl_vec_ops_field {
                 $field_prefix_ident::scalar_add_ffi(
                     a.as_ptr(),
                     b.as_ptr(),
-                    b.len() as u32,
+                    b.len() as u32 / cfg.batch_size as u32,
                     cfg as *const VecOpsConfig,
                     result.as_mut_ptr(),
                 )
@@ -662,7 +777,7 @@ macro_rules! impl_vec_ops_field {
                 $field_prefix_ident::scalar_sub_ffi(
                     a.as_ptr(),
                     b.as_ptr(),
-                    b.len() as u32,
+                    b.len() as u32 / cfg.batch_size as u32,
                     cfg as *const VecOpsConfig,
                     result.as_mut_ptr(),
                 )
@@ -680,7 +795,7 @@ macro_rules! impl_vec_ops_field {
                 $field_prefix_ident::scalar_mul_ffi(
                     a.as_ptr(),
                     b.as_ptr(),
-                    b.len() as u32,
+                    b.len() as u32 / cfg.batch_size as u32,
                     cfg as *const VecOpsConfig,
                     result.as_mut_ptr(),
                 )
@@ -715,7 +830,7 @@ macro_rules! impl_vec_ops_field {
             unsafe {
                 $field_prefix_ident::bit_reverse_ffi(
                     input.as_ptr(),
-                    input.len() as u64,
+                    input.len() as u64 / cfg.batch_size as u64,
                     cfg as *const VecOpsConfig,
                     output.as_mut_ptr(),
                 )
@@ -730,7 +845,7 @@ macro_rules! impl_vec_ops_field {
             unsafe {
                 $field_prefix_ident::bit_reverse_ffi(
                     input.as_ptr(),
-                    input.len() as u64,
+                    input.len() as u64 / cfg.batch_size as u64,
                     cfg as *const VecOpsConfig,
                     input.as_mut_ptr(),
                 )
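
The FFI call sites above now pass the per-batch vector length (`len / cfg.batch_size`) instead of the full buffer length, which is what broke most batched operations before this fix. A hedged sketch of that adjustment (the function names are illustrative, not ICICLE's actual FFI symbols):

```rust
// Illustrative per-batch length adjustment at an FFI-style boundary.
// `backend_vector_sum` stands in for the real extern "C" symbol.
fn backend_vector_sum(input: &[u64], per_batch_len: usize, batch_size: usize, out: &mut [u64]) {
    // The backend interprets `input` as `batch_size` vectors of `per_batch_len` elements each.
    for (j, chunk) in input.chunks(per_batch_len).take(batch_size).enumerate() {
        out[j] = chunk.iter().sum();
    }
}

fn vector_sum(input: &[u64], batch_size: usize, out: &mut [u64]) {
    // Before the fix the wrapper passed input.len(); now it passes the
    // per-batch length, input.len() / batch_size.
    let per_batch_len = input.len() / batch_size;
    backend_vector_sum(input, per_batch_len, batch_size, out);
}

fn main() {
    let input: Vec<u64> = (1..=8).collect(); // two batched vectors: [1..4] and [5..8]
    let mut out = vec![0u64; 2];
    vector_sum(&input, 2, &mut out);
    assert_eq!(out, vec![10, 26]);
}
```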
