@@ -142,7 +142,7 @@ pub trait MixedVecOps<F, T> {
     ) -> Result<(), eIcicleError>;
 }
 
-fn check_vec_ops_args<'a, F, T>(
+fn check_vec_ops_args<F, T>(
     a: &(impl HostOrDeviceSlice<F> + ?Sized),
     b: &(impl HostOrDeviceSlice<T> + ?Sized),
     result: &(impl HostOrDeviceSlice<F> + ?Sized),
@@ -156,7 +156,119 @@ fn check_vec_ops_args<'a, F, T>(
             result.len()
         );
     }
+    setup_config(a, b, result, cfg, 1 /* placeholder: batch_size is not needed for this operation */)
+}
 
+fn check_vec_ops_args_scalar_ops<F, T>(
+    a: &(impl HostOrDeviceSlice<F> + ?Sized),
+    b: &(impl HostOrDeviceSlice<T> + ?Sized),
+    result: &(impl HostOrDeviceSlice<F> + ?Sized),
+    cfg: &VecOpsConfig,
+) -> VecOpsConfig {
+    if b.len() != result.len() {
+        panic!(
+            "b.len() and result.len() do not match: {} != {}",
+            b.len(),
+            result.len()
+        );
+    }
+    if b.len() % a.len() != 0 {
+        panic!(
+            "b.len() is not a multiple of a.len(): {} % {} != 0",
+            b.len(),
+            a.len(),
+        );
+    }
+    let batch_size = a.len();
+    setup_config(a, b, result, cfg, batch_size)
+}
+
+fn check_vec_ops_args_reduction_ops<F>(
+    input: &(impl HostOrDeviceSlice<F> + ?Sized),
+    result: &(impl HostOrDeviceSlice<F> + ?Sized),
+    cfg: &VecOpsConfig,
+) -> VecOpsConfig {
+    if input.len() % result.len() != 0 {
+        panic!(
+            "input length is not a multiple of result length: {} % {} != 0",
+            input.len(),
+            result.len(),
+        );
+    }
+    let batch_size = result.len();
+    setup_config(input, input, result, cfg, batch_size)
+}
+
+fn check_vec_ops_args_transpose<F>(
+    input: &(impl HostOrDeviceSlice<F> + ?Sized),
+    nof_rows: u32,
+    nof_cols: u32,
+    output: &(impl HostOrDeviceSlice<F> + ?Sized),
+    cfg: &VecOpsConfig,
+) -> VecOpsConfig {
+    if input.len() != output.len() {
+        panic!(
+            "Input size and output size do not match: {} != {}",
+            input.len(),
+            output.len()
+        );
+    }
+    if input.len() as u32 % (nof_rows * nof_cols) != 0 {
+        panic!(
+            "Input size is not a whole multiple of matrix size (#rows * #cols): {} % ({} * {}) != 0",
+            input.len(),
+            nof_rows,
+            nof_cols,
+        );
+    }
+    let batch_size = input.len() / (nof_rows * nof_cols) as usize;
+    setup_config(input, input, output, cfg, batch_size)
+}
+
+fn check_vec_ops_args_slice<F>(
+    input: &(impl HostOrDeviceSlice<F> + ?Sized),
+    offset: u64,
+    stride: u64,
+    size_in: u64,
+    size_out: u64,
+    output: &(impl HostOrDeviceSlice<F> + ?Sized),
+    cfg: &VecOpsConfig,
+) -> VecOpsConfig {
+    if input.len() as u64 % size_in != 0 {
+        panic!(
+            "size_in does not divide input size: {} % {} != 0",
+            input.len(),
+            size_in,
+        );
+    }
+    if output.len() as u64 % size_out != 0 {
+        panic!(
+            "size_out does not divide output size: {} % {} != 0",
+            output.len(),
+            size_out,
+        );
+    }
+    if offset + (size_out - 1) * stride >= size_in {
+        panic!(
+            "Slice exceeds input size: offset + (size_out - 1) * stride >= size_in, where offset={}, size_out={}, stride={}, size_in={}",
+            offset,
+            size_out,
+            stride,
+            size_in,
+        );
+    }
+    let batch_size = output.len() / size_out as usize;
+    setup_config(input, input, output, cfg, batch_size)
+}
+
+/// Modify VecOpsConfig according to the given vectors and batch size
+fn setup_config<F, T>(
+    a: &(impl HostOrDeviceSlice<F> + ?Sized),
+    b: &(impl HostOrDeviceSlice<T> + ?Sized),
+    result: &(impl HostOrDeviceSlice<F> + ?Sized),
+    cfg: &VecOpsConfig,
+    batch_size: usize,
+) -> VecOpsConfig {
     // check device slices are on active device
     if a.is_on_device() && !a.is_on_active_device() {
         panic!("input a is allocated on an inactive device");
@@ -169,6 +281,7 @@ fn check_vec_ops_args<'a, F, T>(
     }
 
     let mut res_cfg = cfg.clone();
+    res_cfg.batch_size = batch_size as i32;
     res_cfg.is_a_on_device = a.is_on_device();
     res_cfg.is_b_on_device = b.is_on_device();
     res_cfg.is_result_on_device = result.is_on_device();
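
Every check helper funnels into `setup_config`, which clones the caller's `VecOpsConfig` and overwrites `batch_size` plus the three device-placement flags, so callers never set `batch_size` by hand. A minimal standalone sketch of that shape (the struct is trimmed to the fields this diff touches; the helper name and array argument are illustrative, not the library API):

```rust
// Trimmed sketch, not the real icicle type: only the fields this diff touches.
#[derive(Clone, Default)]
struct VecOpsConfig {
    batch_size: i32,
    is_a_on_device: bool,
    is_b_on_device: bool,
    is_result_on_device: bool,
}

// Mirrors the tail of setup_config: clone the caller's config, then stamp in
// the derived batch size and the device placement of the three slices.
fn setup_config_sketch(cfg: &VecOpsConfig, batch_size: usize, on_device: [bool; 3]) -> VecOpsConfig {
    let mut res_cfg = cfg.clone();
    res_cfg.batch_size = batch_size as i32;
    res_cfg.is_a_on_device = on_device[0];
    res_cfg.is_b_on_device = on_device[1];
    res_cfg.is_result_on_device = on_device[2];
    res_cfg
}

fn main() {
    let cfg = VecOpsConfig::default();
    let out = setup_config_sketch(&cfg, 4, [false, false, false]);
    assert_eq!(out.batch_size, 4); // batch size is derived, never caller-supplied
}
```
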
@@ -267,7 +380,7 @@ where
     F: FieldImpl,
     <F as FieldImpl>::Config: VecOps<F>,
 {
-    let cfg = check_vec_ops_args(a, a, result, cfg);
+    let cfg = check_vec_ops_args_reduction_ops(a, result, cfg);
     <<F as FieldImpl>::Config as VecOps<F>>::sum(a, result, &cfg)
 }
 
@@ -280,7 +393,7 @@ where
     F: FieldImpl,
     <F as FieldImpl>::Config: VecOps<F>,
 {
-    let cfg = check_vec_ops_args(a, a, result, cfg);
+    let cfg = check_vec_ops_args_reduction_ops(a, result, cfg);
     <<F as FieldImpl>::Config as VecOps<F>>::product(a, result, &cfg)
 }
 
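
For the reductions, `result.len()` itself becomes the batch size: an input of length 32 reduced into a result of length 4 yields 4 independent sums (or products) of 8 elements each. A plain-Rust reference of those semantics, assuming each batch entry occupies a contiguous chunk (the actual in-memory batch layout is a backend detail not visible in this diff):

```rust
// Reference semantics only; assumes contiguous per-batch chunks.
fn reference_sum(input: &[u64], batch_size: usize) -> Vec<u64> {
    assert_eq!(input.len() % batch_size, 0); // same check as the helper above
    let chunk = input.len() / batch_size;
    input.chunks(chunk).map(|c| c.iter().sum()).collect()
}

fn main() {
    let input: Vec<u64> = (0..32).collect();
    let sums = reference_sum(&input, 4);
    assert_eq!(sums.len(), 4);
    assert_eq!(sums[0], (0..8).sum::<u64>()); // first entry sums inputs 0..8
}
```
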
@@ -294,7 +407,7 @@ where
     F: FieldImpl,
     <F as FieldImpl>::Config: VecOps<F>,
 {
-    let cfg = check_vec_ops_args(b, b, result, cfg);
+    let cfg = check_vec_ops_args_scalar_ops(a, b, result, cfg);
     <<F as FieldImpl>::Config as VecOps<F>>::scalar_add(a, b, result, &cfg)
 }
 
@@ -308,7 +421,7 @@ where
     F: FieldImpl,
     <F as FieldImpl>::Config: VecOps<F>,
 {
-    let cfg = check_vec_ops_args(b, b, result, cfg);
+    let cfg = check_vec_ops_args_scalar_ops(a, b, result, cfg);
     <<F as FieldImpl>::Config as VecOps<F>>::scalar_sub(a, b, result, &cfg)
 }
 
@@ -322,7 +435,7 @@ where
     F: FieldImpl,
     <F as FieldImpl>::Config: VecOps<F>,
 {
-    let cfg = check_vec_ops_args(b, b, result, cfg);
+    let cfg = check_vec_ops_args_scalar_ops(a, b, result, cfg);
     <<F as FieldImpl>::Config as VecOps<F>>::scalar_mul(a, b, result, &cfg)
 }
 
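
For the scalar ops, `a` carries one scalar per batch entry (`batch_size = a.len()`) and `b` carries `a.len()` equally sized chunks. A plain-Rust reference of batched `scalar_add`, again assuming contiguous per-batch chunks:

```rust
// Reference semantics only; assumes contiguous per-batch chunks.
fn reference_scalar_add(a: &[u64], b: &[u64]) -> Vec<u64> {
    assert_eq!(b.len() % a.len(), 0); // same check as the helper above
    let chunk = b.len() / a.len();
    b.iter().enumerate().map(|(i, x)| a[i / chunk] + x).collect()
}

fn main() {
    let a = [10, 20];     // batch_size = 2 scalars
    let b = [1, 2, 3, 4]; // two chunks of length 2
    assert_eq!(reference_scalar_add(&a, &b), vec![11, 12, 23, 24]);
}
```
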
@@ -337,6 +450,7 @@ where
     F: FieldImpl,
     <F as FieldImpl>::Config: VecOps<F>,
 {
+    let cfg = check_vec_ops_args_transpose(input, nof_rows, nof_cols, output, cfg);
     <<F as FieldImpl>::Config as VecOps<F>>::transpose(input, nof_rows, nof_cols, output, &cfg)
 }
 
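
The transpose wrapper now infers the batch size as `input.len() / (nof_rows * nof_cols)`, i.e. one call transposes that many matrices at once. The arithmetic, checked standalone (function name illustrative):

```rust
// Illustrative only: mirrors the batch derivation in check_vec_ops_args_transpose.
fn transpose_batch(input_len: usize, nof_rows: u32, nof_cols: u32) -> usize {
    let matrix = (nof_rows * nof_cols) as usize;
    assert_eq!(input_len % matrix, 0, "input must hold whole matrices");
    input_len / matrix
}

fn main() {
    assert_eq!(transpose_batch(24, 2, 3), 4); // four 2x3 matrices per call
}
```
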
@@ -378,6 +492,7 @@ where
     F: FieldImpl,
     <F as FieldImpl>::Config: VecOps<F>,
 {
+    let cfg = check_vec_ops_args_slice(input, offset, stride, size_in, size_out, output, cfg);
     <<F as FieldImpl>::Config as VecOps<F>>::slice(input, offset, stride, size_in, size_out, &cfg, output)
 }
 
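
The slice wrapper validates that the last sampled index stays inside one input batch entry, i.e. `offset + (size_out - 1) * stride < size_in`, with `batch_size = output.len() / size_out`. The bound, checked standalone (function name illustrative):

```rust
// Illustrative only: the in-range condition from check_vec_ops_args_slice.
fn slice_in_bounds(offset: u64, stride: u64, size_in: u64, size_out: u64) -> bool {
    offset + (size_out - 1) * stride < size_in
}

fn main() {
    assert!(slice_in_bounds(1, 2, 16, 8));  // samples indices 1, 3, ..., 15
    assert!(!slice_in_bounds(2, 2, 16, 8)); // index 16 would fall out of range
}
```
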
@@ -610,7 +725,7 @@ macro_rules! impl_vec_ops_field {
             unsafe {
                 $field_prefix_ident::vector_sum_ffi(
                     a.as_ptr(),
-                    a.len() as u32,
+                    a.len() as u32 / cfg.batch_size as u32,
                     cfg as *const VecOpsConfig,
                     result.as_mut_ptr(),
                 )
@@ -626,7 +741,7 @@ macro_rules! impl_vec_ops_field {
             unsafe {
                 $field_prefix_ident::vector_sum_ffi(
                     a.as_ptr(),
-                    a.len() as u32,
+                    a.len() as u32 / cfg.batch_size as u32,
                     cfg as *const VecOpsConfig,
                     result.as_mut_ptr(),
                 )
@@ -644,7 +759,7 @@ macro_rules! impl_vec_ops_field {
                 $field_prefix_ident::scalar_add_ffi(
                     a.as_ptr(),
                     b.as_ptr(),
-                    b.len() as u32,
+                    b.len() as u32 / cfg.batch_size as u32,
                     cfg as *const VecOpsConfig,
                     result.as_mut_ptr(),
                 )
@@ -662,7 +777,7 @@ macro_rules! impl_vec_ops_field {
                 $field_prefix_ident::scalar_sub_ffi(
                     a.as_ptr(),
                     b.as_ptr(),
-                    b.len() as u32,
+                    b.len() as u32 / cfg.batch_size as u32,
                     cfg as *const VecOpsConfig,
                     result.as_mut_ptr(),
                 )
@@ -680,7 +795,7 @@ macro_rules! impl_vec_ops_field {
                 $field_prefix_ident::scalar_mul_ffi(
                     a.as_ptr(),
                     b.as_ptr(),
-                    b.len() as u32,
+                    b.len() as u32 / cfg.batch_size as u32,
                     cfg as *const VecOpsConfig,
                     result.as_mut_ptr(),
                 )
@@ -715,7 +830,7 @@ macro_rules! impl_vec_ops_field {
             unsafe {
                 $field_prefix_ident::bit_reverse_ffi(
                     input.as_ptr(),
-                    input.len() as u64,
+                    input.len() as u64 / cfg.batch_size as u64,
                     cfg as *const VecOpsConfig,
                     output.as_mut_ptr(),
                 )
@@ -730,7 +845,7 @@ macro_rules! impl_vec_ops_field {
             unsafe {
                 $field_prefix_ident::bit_reverse_ffi(
                     input.as_ptr(),
-                    input.len() as u64,
+                    input.len() as u64 / cfg.batch_size as u64,
                     cfg as *const VecOpsConfig,
                     input.as_mut_ptr(),
                 )
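
With batching, every FFI call now passes the per-batch length (`len / batch_size`) rather than the full slice length; the backend can recover the total from `batch_size` in the config. Since the check helpers guarantee divisibility for the ops shown here, the round trip is exact:

```rust
// Sanity check of the per-batch length arithmetic used in the FFI calls above.
fn per_batch_len(total: u32, batch_size: u32) -> u32 {
    assert_eq!(total % batch_size, 0); // guaranteed by the check helpers
    total / batch_size
}

fn main() {
    let n = per_batch_len(32, 4);
    assert_eq!(n, 8);
    assert_eq!(n * 4, 32); // per-batch length times batch_size round-trips
}
```
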