@@ -368,33 +368,39 @@ public AUUCImpl(double[] thresholds, int nbins, double[] probs) {
    */
   public static class AUUCBuilder extends Iced {
     final int _nbins;
-    final double[]_thresholds;   // thresholds
+    final double[] _thresholds;  // thresholds
     final long[] _treatment;     // number of data from treatment group
     final long[] _control;       // number of data from control group
     final long[] _yTreatment;    // number of data from treatment group with prediction = 1
     final long[] _yControl;      // number of data from control group with prediction = 1
     final long[] _frequency;     // frequency of data in each bin
     double[] _probs;
     int _n;                      // number of data
-    int _nUsed;                  // number of used bins
+    int _nbinsUsed;              // number of used bins
     int _ssx;
 
     public AUUCBuilder(int nbins, double[] thresholds, double[] probs) {
-      int tlen = thresholds != null ? thresholds.length : 1;
       _probs = probs;
       _nbins = nbins;
-      _nUsed = tlen;
-      _thresholds = thresholds == null ? new double[]{0} : thresholds;
-      _treatment = new long[tlen];
-      _control = new long[tlen];
-      _yTreatment = new long[tlen];
-      _yControl = new long[tlen];
-      _frequency = new long[tlen];
+      _nbinsUsed = thresholds != null ? thresholds.length : 0;
+      int l = nbins * 2; // maximal possible builder arrays length
+      _thresholds = new double[l];
+      if (thresholds != null) {
+        System.arraycopy(thresholds, 0, _thresholds, 0, thresholds.length);
+      }
+      _probs = new double[l];
+      System.arraycopy(probs, 0, _probs, 0, probs.length);
+      System.arraycopy(probs, 0, _probs, probs.length-1, probs.length);
+      _treatment = new long[l];
+      _control = new long[l];
+      _yTreatment = new long[l];
+      _yControl = new long[l];
+      _frequency = new long[l];
       _ssx = -1;
     }
 
     public void perRow(double pred, double w, double y, float treatment) {
-      if (w == 0) { return; }
+      if (w == 0 || _thresholds == null) { return; }
       for (int t = 0; t < _thresholds.length; t++) {
         if (pred >= _thresholds[t] && (t == 0 || pred < _thresholds[t-1])) {
           _n++;
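The bin test in perRow relies on the thresholds being ordered from largest to smallest: a prediction lands in bin t when it reaches _thresholds[t] but still stays below _thresholds[t-1]. A minimal runnable sketch of that search (class name and values hypothetical, not part of the patch):

    class BinSearchSketch {
      // Same bin test as perRow: thresholds sorted descending.
      static int findBin(double pred, double[] ths) {
        for (int t = 0; t < ths.length; t++)
          if (pred >= ths[t] && (t == 0 || pred < ths[t - 1]))
            return t;    // largest threshold that pred reaches
        return -1;       // below every threshold: the row lands in no bin
      }

      public static void main(String[] args) {
        double[] ths = {0.9, 0.6, 0.3};
        System.out.println(findBin(0.95, ths)); // 0
        System.out.println(findBin(0.50, ths)); // 2 (0.3 <= 0.5 < 0.6)
        System.out.println(findBin(0.10, ths)); // -1
      }
    }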
@@ -416,20 +422,23 @@ public void perRow(double pred, double w, double y, float treatment) {
     }
 
     public void reduce(AUUCBuilder bldr) {
-      _n += bldr._n;
-      ArrayUtils.add(_treatment, bldr._treatment);
-      ArrayUtils.add(_control, bldr._control);
-      ArrayUtils.add(_yTreatment, bldr._yTreatment);
-      ArrayUtils.add(_yControl, bldr._yControl);
-      ArrayUtils.add(_frequency, bldr._frequency);
+      if (bldr._nbinsUsed == 0) { return; }
+      if (_nbinsUsed == 0 || _thresholds == bldr._thresholds) {
+        reduceSameOrNullThresholds(bldr);
+      } else {
+        reduceDifferentThresholds(bldr);
+      }
     }
 
-    public void reduce2(AUUCBuilder bldr) {
-      // Merge sort the 2 sorted lists into the double-sized arrays.  The tail
-      // half of the double-sized array is unused, but the front half is
-      // probably a source.  Merge into the back.
-      int x = _n-1;
-      int y = bldr._n-1;
+    /**
+     * Merge sort the 2 sorted lists into the double-sized arrays. The tail
+     * half of the double-sized array is unused, but the front half is
+     * probably a source. Merge into the back.
+     * @param bldr AUUC builder to reduce
+     */
+    public void reduceDifferentThresholds(AUUCBuilder bldr) {
+      int x = _nbinsUsed-1;
+      int y = bldr._nbinsUsed-1;
       while (x+y+1 >= 0) {
         boolean self_is_larger = y < 0 || (x >= 0 && _thresholds[x] >= bldr._thresholds[y]);
         AUUCBuilder b = self_is_larger ? this : bldr;
@@ -440,16 +449,31 @@ public void reduce2(AUUCBuilder bldr) {
         _yTreatment[x+y+1] = b._yTreatment[idx];
         _yControl[x+y+1] = b._yControl[idx];
         _frequency[x+y+1] = b._frequency[idx];
+        _probs[x+y+1] = b._probs[idx];
         if (self_is_larger) x--; else y--;
       }
       _n += bldr._n;
+      _nbinsUsed += bldr._nbinsUsed;
       _ssx = -1;
 
       // Merge elements with least squared-error increase until we get fewer
       // than _nbins and no duplicates.  May require many merges.
-      while (_n > _nbins || dups())
+      while (_nbinsUsed > _nbins || dups())
         mergeOneBin();
     }
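The loop above is a classic back-to-front merge: because every builder array is sized nbins * 2, the merged contents of two half-full builders always fit, and filling from the highest index means the front half is never overwritten before it is read. A compact standalone sketch of the same loop (names hypothetical, not part of the patch), shown for ascending inputs, which is the order the >= comparison plus back placement assumes:

    class MergeBackSketch {
      // Merge sorted b into double-sized a, whose first na slots are used.
      static void mergeBack(double[] a, int na, double[] b, int nb) {
        int x = na - 1, y = nb - 1;
        while (x + y + 1 >= 0) {
          boolean selfIsLarger = y < 0 || (x >= 0 && a[x] >= b[y]);
          a[x + y + 1] = selfIsLarger ? a[x] : b[y]; // largest remaining value goes last
          if (selfIsLarger) x--; else y--;
        }
      }

      public static void main(String[] args) {
        double[] a = {0.1, 0.3, 0, 0};  // 2 of 4 slots used
        mergeBack(a, 2, new double[]{0.2, 0.4}, 2);
        System.out.println(java.util.Arrays.toString(a)); // [0.1, 0.2, 0.3, 0.4]
      }
    }

After the merge, _nbinsUsed temporarily exceeds _nbins, and the mergeOneBin() loop trims the builder back down.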
464
+
465
+ public void reduceSameOrNullThresholds (AUUCBuilder bldr ){
466
+ _n += bldr ._n ;
467
+ if (_nbinsUsed == 0 ) {
468
+ ArrayUtils .add (_thresholds , bldr ._thresholds );
469
+ _nbinsUsed = bldr ._nbinsUsed ;
470
+ }
471
+ ArrayUtils .add (_treatment , bldr ._treatment );
472
+ ArrayUtils .add (_control , bldr ._control );
473
+ ArrayUtils .add (_yTreatment , bldr ._yTreatment );
474
+ ArrayUtils .add (_yControl , bldr ._yControl );
475
+ ArrayUtils .add (_frequency , bldr ._frequency );
476
+ }
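ArrayUtils.add accumulates elementwise, so when this builder is still empty (_nbinsUsed == 0) adding into its zero-filled _thresholds amounts to copying the other builder's thresholds; after that the bins line up and every counter can be summed slot by slot. A sketch of the helper's behavior (not H2O's actual implementation):

    class AddSketch {
      // Elementwise accumulation in the style of ArrayUtils.add: dst[i] += src[i].
      static long[] addInPlace(long[] dst, long[] src) {
        for (int i = 0; i < src.length; i++) dst[i] += src[i];
        return dst;
      }

      public static void main(String[] args) {
        long[] dst = {0, 0, 0};               // empty builder: zeros absorb a copy
        addInPlace(dst, new long[]{5, 3, 2}); // dst is now {5, 3, 2}
        System.out.println(java.util.Arrays.toString(dst));
      }
    }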
 
     static double combineCenters(double ths1, double ths0, double probs, long nrows) {
       //double center = (ths0 * n0 + ths1 * n1) / (n0 + n1);
@@ -474,26 +498,22 @@ private void mergeOneBin( ) {
       _yTreatment[ssx] += _yTreatment[ssx+1];
       _yControl[ssx] += _yControl[ssx+1];
       _frequency[ssx] += _frequency[ssx+1];
-      int n = _n;
+      int n = _nbinsUsed == 2 ? _nbinsUsed - ssx - 1 : _nbinsUsed - ssx - 2;
       // Slide over to crush the removed bin at index (ssx+1)
-      System.arraycopy(_thresholds, ssx+2, _thresholds, ssx+1, n-ssx-2);
-      System.arraycopy(_treatment, ssx+2, _treatment, ssx+1, n-ssx-2);
-      System.arraycopy(_control, ssx+2, _control, ssx+1, n-ssx-2);
-      System.arraycopy(_yTreatment, ssx+2, _yTreatment, ssx+1, n-ssx-2);
-      System.arraycopy(_yControl, ssx+2, _yControl, ssx+1, n-ssx-2);
-      System.arraycopy(_frequency, ssx+2, _frequency, ssx+1, n-ssx-2);
-      _n--;
-      _thresholds[_n] = _treatment[_n] = _control[_n] = _yTreatment[_n] = _yControl[_n] = _frequency[_n] = 0;
+      System.arraycopy(_thresholds, ssx+2, _thresholds, ssx+1, n);
+      System.arraycopy(_treatment, ssx+2, _treatment, ssx+1, n);
+      System.arraycopy(_control, ssx+2, _control, ssx+1, n);
+      System.arraycopy(_yTreatment, ssx+2, _yTreatment, ssx+1, n);
+      System.arraycopy(_yControl, ssx+2, _yControl, ssx+1, n);
+      System.arraycopy(_frequency, ssx+2, _frequency, ssx+1, n);
+      _nbinsUsed--;
       _ssx = -1;
     }
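The rewritten copy length n is the number of live bins to the right of the removed slot, so the slide no longer depends on the row counter _n. A worked sketch for the general branch (_nbinsUsed > 2), with made-up counts:

    class MergeOneBinSketch {
      public static void main(String[] args) {
        long[] frequency = {5, 3, 2, 7, 0, 0, 0, 0}; // double-sized, 4 bins used
        int nbinsUsed = 4, ssx = 1;                  // merge bins 1 and 2
        frequency[ssx] += frequency[ssx + 1];        // {5, 5, 2, 7, ...}
        int n = nbinsUsed - ssx - 2;                 // one bin (index 3) slides left
        System.arraycopy(frequency, ssx + 2, frequency, ssx + 1, n);
        nbinsUsed--;                                 // 3 bins used: {5, 5, 7}
        System.out.println(java.util.Arrays.toString(frequency)); // [5, 5, 7, 7, 0, 0, 0, 0]
      }
    }

Note the patch also drops the old tail-zeroing line: stale entries past _nbinsUsed are simply left in place and ignored.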
 
-    // Find the pair of bins that when combined give the smallest increase in
-    // squared error.  Dups never increase squared error.
-    //
-    // I tried code for merging bins with keeping the bins balanced in size,
-    // but this leads to bad errors if the probabilities are sorted.  Also
-    // tried the original: merge bins with the least distance between bin
-    // centers.  Same problem for sorted data.
+    /**
+     * Find the pair of bins that when combined give the smallest difference in thresholds.
+     * @return index of the bin where the threshold difference is the smallest
+     */
     private int findSmallest() {
       if (_ssx == -1) {
         _ssx = findSmallestImpl();
@@ -503,12 +523,10 @@ private int findSmallest() {
     }
 
     private int findSmallestImpl() {
-      if (_n == 1)
+      if (_nbinsUsed == 1)
         return 0;
-      // we couldn't find any bins to merge based on SE (the math can be producing Double.Infinity or Double.NaN)
-      // revert to using a simple distance of the bin centers
       int minI = 0;
-      long n = _n;
+      long n = _nbinsUsed;
       double minDist = _thresholds[1] - _thresholds[0];
       for (int i = 1; i < n - 1; i++) {
         double dist = _thresholds[i + 1] - _thresholds[i];
@@ -521,25 +539,27 @@ private int findSmallestImpl() {
     }
 
     private boolean dups() {
-      long n = _n;
+      long n = _nbinsUsed;
       for (int i = 0; i < n-1; i++) {
-        double derr = computeDeltaError(_thresholds[i+1], _frequency[i+1], _thresholds[i], _frequency[i]);
-        if (derr == 0) { _ssx = i; return true; }
+        double derr = computeDeltaError(_thresholds[i + 1], _frequency[i + 1], _thresholds[i], _frequency[i]);
+        if (derr == 0) {
+          _ssx = i;
+          return true;
+        }
       }
       return false;
     }
 
-
+    /**
+     * If thresholds vary by less than a float ULP, treat them as the same.
+     * Parallel equation drawn from:
+     * http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
+     * @return delta error from two thresholds
+     */
     private double computeDeltaError(double ths1, double n1, double ths0, double n0) {
-      // If thresholds vary by less than a float ULP, treat them as the same.
-      // Some models only output predictions to within float accuracy (so a
-      // variance here is junk), and also it's not statistically sane to have
-      // a model which varies predictions by such a tiny change in thresholds.
       double delta = (float)ths1 - (float)ths0;
       if (delta == 0)
         return 0;
-      // Parallel equation drawn from:
-      // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
       return delta*delta*n0*n1 / (n0+n1);
     }
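computeDeltaError is the increase in pooled squared error from merging two bins treated as point masses at their thresholds, per the parallel variance equation cited in the javadoc: delta^2 * n0 * n1 / (n0 + n1). Worked numbers (chosen arbitrarily, not from the patch):

    class DeltaErrorSketch {
      public static void main(String[] args) {
        double ths0 = 0.30, ths1 = 0.32;            // adjacent bin thresholds
        double n0 = 10, n1 = 40;                    // rows in each bin
        double delta = (float) ths1 - (float) ths0; // float cast: sub-ULP gaps collapse to 0
        System.out.println(delta * delta * n0 * n1 / (n0 + n1)); // ~0.0032
      }
    }

Bins whose thresholds agree to float precision give delta == 0, which is exactly the duplicate test that dups() relies on.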
@@ -555,7 +575,8 @@ private static double computeLinearInterpolation(double ths1, double ths0, doubl
 
     private String toDebugString() {
       return "n =" + _n +
-             "; nBins = " + _nbins +
+             "; nbins = " + _nbins +
+             "; nbinsUsed = " + _nbinsUsed +
              "; ths = " + Arrays.toString(_thresholds) +
              "; treatment = " + Arrays.toString(_treatment) +
              "; contribution = " + Arrays.toString(_control) +