@@ -44,7 +44,7 @@ pub enum NTTDir {
44
44
/// - kMN: inputs are digit-reversed-order (=mixed) and outputs are natural-order.
45
45
#[ allow( non_camel_case_types) ]
46
46
#[ repr( C ) ]
47
- #[ derive( Clone , Copy , Debug , PartialEq , Eq ) ]
47
+ #[ derive( Clone , Copy , Debug , PartialEq , Eq , PartialOrd ) ]
48
48
pub enum Ordering {
49
49
kNN,
50
50
kNR,
@@ -422,20 +422,41 @@ macro_rules! impl_ntt_bench {
422
422
$field: ident
423
423
) => {
424
424
use icicle_core:: ntt:: ntt;
425
+ use icicle_core:: ntt:: get_root_of_unity;
426
+ use icicle_core:: ntt:: initialize_domain;
425
427
use icicle_core:: ntt:: NTTDomain ;
428
+
426
429
use icicle_cuda_runtime:: memory:: HostOrDeviceSlice ;
430
+ use icicle_cuda_runtime:: device_context:: DeviceContext ;
427
431
use std:: sync:: OnceLock ;
432
+ use std:: iter:: once;
428
433
429
434
use criterion:: { black_box, criterion_group, criterion_main, Criterion } ;
430
435
use icicle_core:: {
431
436
ntt:: { FieldImpl , NTTConfig , NTTDir , NttAlgorithm , Ordering } ,
432
- traits:: ArkConvertible ,
433
437
} ;
434
438
435
439
use icicle_core:: ntt:: NTT ;
436
440
use icicle_cuda_runtime:: memory:: HostSlice ;
437
441
use icicle_core:: traits:: GenerateRandom ;
438
442
use icicle_core:: vec_ops:: VecOps ;
443
+ use std:: env;
444
+
445
+ fn get_min_max_log_size( min_log2_default: u32 , max_log2_default: u32 ) -> ( u32 , u32 ) {
446
+
447
+ fn get_env_log2( key: & str , default : u32 ) -> u32 {
448
+ env:: var( key) . unwrap_or_else( |_| default . to_string( ) ) . parse( ) . unwrap_or( default )
449
+ }
450
+
451
+ let min_log2 = get_env_log2( "MIN_LOG2" , min_log2_default) ;
452
+ let max_log2 = get_env_log2( "MAX_LOG2" , max_log2_default) ;
453
+
454
+ assert!( min_log2 >= min_log2_default, "MIN_LOG2 must be >= {}" , min_log2_default) ;
455
+ assert!( min_log2 < max_log2, "MAX_LOG2 must be > MIN_LOG2" ) ;
456
+
457
+ ( min_log2, max_log2)
458
+ }
459
+
439
460
440
461
fn ntt_for_bench<T , F : FieldImpl >(
441
462
input: & ( impl HostOrDeviceSlice <F > + ?Sized ) ,
@@ -453,6 +474,15 @@ macro_rules! impl_ntt_bench {
453
474
ntt( input, is_inverse, config, batch_ntt_result) . unwrap( ) ;
454
475
}
455
476
477
+ fn init_domain<F : FieldImpl >( max_size: u64 , device_id: usize , fast_twiddles_mode: bool )
478
+ where
479
+ <F as FieldImpl >:: Config : NTTDomain <F >,
480
+ {
481
+ let ctx = DeviceContext :: default_for_device( device_id) ;
482
+ let rou: F = get_root_of_unity( max_size) ;
483
+ initialize_domain( rou, & ctx, fast_twiddles_mode) . unwrap( ) ;
484
+ }
485
+
456
486
static INIT : OnceLock <( ) > = OnceLock :: new( ) ;
457
487
458
488
fn benchmark_ntt<T , F : FieldImpl >( c: & mut Criterion )
@@ -462,32 +492,28 @@ macro_rules! impl_ntt_bench {
462
492
{
463
493
use criterion:: SamplingMode ;
464
494
use icicle_core:: ntt:: ntt;
465
- use icicle_core:: ntt:: tests:: init_domain;
466
495
use icicle_core:: ntt:: NTTDomain ;
467
496
use icicle_cuda_runtime:: device_context:: DEFAULT_DEVICE_ID ;
468
- use std :: env ;
497
+ use icicle_cuda_runtime :: memory :: DeviceVec ;
469
498
470
499
let group_id = format!( "{} NTT" , $field_prefix) ;
471
500
let mut group = c. benchmark_group( & group_id) ;
472
501
group. sampling_mode( SamplingMode :: Flat ) ;
473
502
group. sample_size( 10 ) ;
474
503
504
+ const MIN_LOG2 : u32 = 8 ; // min length = 2 ^ MIN_LOG2
475
505
const MAX_LOG2 : u32 = 25 ; // max length = 2 ^ MAX_LOG2
476
-
477
- let max_log2 = env:: var( "MAX_LOG2" )
478
- . unwrap_or_else( |_| MAX_LOG2 . to_string( ) )
479
- . parse:: <u32 >( )
480
- . unwrap_or( MAX_LOG2 ) ;
481
-
482
506
const FAST_TWIDDLES_MODE : bool = false ;
483
507
508
+ let ( min_log2, max_log2) = get_min_max_log_size( MIN_LOG2 , MAX_LOG2 ) ;
509
+
484
510
INIT . get_or_init( move || init_domain:: <$field>( 1 << max_log2, DEFAULT_DEVICE_ID , FAST_TWIDDLES_MODE ) ) ;
485
511
486
512
let coset_generators = [ F :: one( ) , F :: Config :: generate_random( 1 ) [ 0 ] ] ;
487
513
let mut config = NTTConfig :: <F >:: default ( ) ;
488
514
489
- for test_size_log2 in ( 13u32 .. max_log2 + 1 ) {
490
- for batch_size_log2 in ( 7u32 .. 17u32 ) {
515
+ for test_size_log2 in ( min_log2..= max_log2) {
516
+ for batch_size_log2 in [ 0 , 6 , 8 , 10 ] {
491
517
let test_size = 1 << test_size_log2;
492
518
let batch_size = 1 << batch_size_log2;
493
519
let full_size = batch_size * test_size;
@@ -501,39 +527,70 @@ macro_rules! impl_ntt_bench {
501
527
502
528
let mut batch_ntt_result = vec![ F :: zero( ) ; batch_size * test_size] ;
503
529
let batch_ntt_result = HostSlice :: from_mut_slice( & mut batch_ntt_result) ;
504
- let mut config = NTTConfig :: default ( ) ;
505
- for is_inverse in [ NTTDir :: kInverse, NTTDir :: kForward] {
506
- for ordering in [
507
- Ordering :: kNN,
508
- Ordering :: kNR, // times are ~ same as kNN
509
- Ordering :: kRN,
510
- Ordering :: kRR,
511
- Ordering :: kNM,
512
- Ordering :: kMN,
513
- ] {
514
- config. ordering = ordering;
515
- // for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
516
- config. batch_size = batch_size as i32 ;
517
- // config.ntt_algorithm = alg;
518
- let bench_descr = format!(
519
- "{:?} {:?} {} x {}" ,
520
- ordering, is_inverse, test_size, batch_size
521
- ) ;
522
- group. bench_function( & bench_descr, |b| {
523
- b. iter( || {
524
- ntt_for_bench:: <F , F >(
525
- input,
526
- batch_ntt_result,
527
- test_size,
528
- batch_size,
529
- is_inverse,
530
- ordering,
531
- & mut config,
532
- black_box( 1 ) ,
533
- )
534
- } )
535
- } ) ;
536
- // }
530
+
531
+ for is_on_device in [ true , false ] {
532
+
533
+ let mut config = NTTConfig :: default ( ) ;
534
+ for is_inverse in [ NTTDir :: kInverse, NTTDir :: kForward] {
535
+ for ordering in [
536
+ Ordering :: kNN,
537
+ Ordering :: kNR, // times are ~ same as kNN
538
+ Ordering :: kRN,
539
+ Ordering :: kRR,
540
+ Ordering :: kNM,
541
+ Ordering :: kMN,
542
+ ] {
543
+ config. ordering = ordering;
544
+ for alg in [ NttAlgorithm :: Radix2 , NttAlgorithm :: MixedRadix ] {
545
+
546
+ if alg == NttAlgorithm :: Radix2 && ordering as u32 > 3 {
547
+ continue ;
548
+ }
549
+
550
+ config. batch_size = batch_size as i32 ;
551
+ config. ntt_algorithm = alg;
552
+ let bench_descr = format!(
553
+ "{} {:?} {:?} {:?} 2^ {} x {}" ,
554
+ if is_on_device { "on device" } else { "on host" } , alg, ordering, is_inverse, test_size_log2, batch_size
555
+ ) ;
556
+ if is_on_device {
557
+ let mut d_input = DeviceVec :: <F >:: cuda_malloc( full_size) . unwrap( ) ;
558
+ d_input. copy_from_host( input) . unwrap( ) ;
559
+ let mut d_batch_ntt_result = DeviceVec :: <F >:: cuda_malloc( full_size) . unwrap( ) ;
560
+ d_batch_ntt_result. copy_from_host( batch_ntt_result) . unwrap( ) ;
561
+
562
+ group. bench_function( & bench_descr, |b| {
563
+ b. iter( || {
564
+ ntt_for_bench:: <F , F >(
565
+ & d_input[ ..] ,
566
+ & mut d_batch_ntt_result[ ..] ,
567
+ test_size,
568
+ batch_size,
569
+ is_inverse,
570
+ ordering,
571
+ & mut config,
572
+ black_box( 1 ) ,
573
+ )
574
+ } )
575
+ } ) ;
576
+ } else {
577
+ group. bench_function( & bench_descr, |b| {
578
+ b. iter( || {
579
+ ntt_for_bench:: <F , F >(
580
+ input,
581
+ batch_ntt_result,
582
+ test_size,
583
+ batch_size,
584
+ is_inverse,
585
+ ordering,
586
+ & mut config,
587
+ black_box( 1 ) ,
588
+ )
589
+ } )
590
+ } ) ;
591
+ }
592
+ }
593
+ }
537
594
}
538
595
}
539
596
}
0 commit comments