Commit aacb1db

Implement AUUC reduce correctly
1 parent ab932e8

File tree

2 files changed: +77 -56 lines


h2o-algos/src/test/java/hex/tree/uplift/UpliftDRFTest.java

+1 -1

@@ -353,7 +353,7 @@ public void testSupportCVCriteo() {
         p._treatment_column = "treatment";
         p._response_column = "conversion";
         p._seed = 0xDECAF;
-        p._ntrees = 10;
+        p._ntrees = 11;
         p._score_each_iteration = true;
         p._nfolds = 3;
         p._auuc_nbins = 50;

h2o-core/src/main/java/hex/AUUC.java

+76 -55
@@ -368,33 +368,39 @@ public AUUCImpl(double[] thresholds, int nbins, double[] probs) {
    */
   public static class AUUCBuilder extends Iced {
     final int _nbins;
-    final double[]_thresholds; // thresholds
+    final double[] _thresholds; // thresholds
     final long[] _treatment; // number of data from treatment group
     final long[] _control; // number of data from control group
     final long[] _yTreatment; // number of data from treatment group with prediction = 1
     final long[] _yControl; // number of data from control group with prediction = 1
     final long[] _frequency; // frequency of data in each bin
     double[] _probs;
     int _n; // number of data
-    int _nUsed; // number of used bins
+    int _nbinsUsed; // number of used bins
     int _ssx;
 
     public AUUCBuilder(int nbins, double[] thresholds, double[] probs) {
-      int tlen = thresholds != null ? thresholds.length : 1;
       _probs = probs;
       _nbins = nbins;
-      _nUsed = tlen;
-      _thresholds = thresholds == null ? new double[]{0} : thresholds;
-      _treatment = new long[tlen];
-      _control = new long[tlen];
-      _yTreatment = new long[tlen];
-      _yControl = new long[tlen];
-      _frequency = new long[tlen];
+      _nbinsUsed = thresholds != null ? thresholds.length : 0;
+      int l = nbins * 2; // maximal possible builder arrays length
+      _thresholds = new double[l];
+      if (thresholds != null) {
+        System.arraycopy(thresholds, 0, _thresholds, 0, thresholds.length);
+      }
+      _probs = new double[l];
+      System.arraycopy(probs, 0, _probs, 0, probs.length);
+      System.arraycopy(probs, 0, _probs, probs.length-1, probs.length);
+      _treatment = new long[l];
+      _control = new long[l];
+      _yTreatment = new long[l];
+      _yControl = new long[l];
+      _frequency = new long[l];
       _ssx = -1;
     }
 
     public void perRow(double pred, double w, double y, float treatment) {
-      if (w == 0) {return;}
+      if (w == 0 || _thresholds == null) {return;}
       for(int t = 0; t < _thresholds.length; t++) {
         if (pred >= _thresholds[t] && (t == 0 || pred <_thresholds[t-1])) {
           _n++;
@@ -416,20 +422,23 @@ public void perRow(double pred, double w, double y, float treatment) {
     }
 
     public void reduce(AUUCBuilder bldr) {
-      _n += bldr._n;
-      ArrayUtils.add(_treatment, bldr._treatment);
-      ArrayUtils.add(_control, bldr._control);
-      ArrayUtils.add(_yTreatment, bldr._yTreatment);
-      ArrayUtils.add(_yControl, bldr._yControl);
-      ArrayUtils.add(_frequency, bldr._frequency);
+      if(bldr._nbinsUsed == 0) {return;}
+      if(_nbinsUsed == 0 || _thresholds == bldr._thresholds){
+        reduceSameOrNullThresholds(bldr);
+      } else {
+        reduceDifferentThresholds(bldr);
+      }
     }
 
-    public void reduce2(AUUCBuilder bldr) {
-      // Merge sort the 2 sorted lists into the double-sized arrays. The tail
-      // half of the double-sized array is unused, but the front half is
-      // probably a source. Merge into the back.
-      int x = _n-1;
-      int y = bldr._n-1;
+    /**
+     * Merge sort the 2 sorted lists into the double-sized arrays. The tail
+     * half of the double-sized array is unused, but the front half is
+     * probably a source. Merge into the back.
+     * @param bldr AUUC builder to reduce
+     */
+    public void reduceDifferentThresholds(AUUCBuilder bldr){
+      int x = _nbinsUsed -1;
+      int y = bldr._nbinsUsed -1;
       while( x+y+1 >= 0 ) {
         boolean self_is_larger = y < 0 || (x >= 0 && _thresholds[x] >= bldr._thresholds[y]);
         AUUCBuilder b = self_is_larger ? this : bldr;
@@ -440,16 +449,31 @@ public void reduce2(AUUCBuilder bldr) {
         _yTreatment[x+y+1] = b._yTreatment[idx];
         _yControl[x+y+1] = b._yControl[idx];
         _frequency[x+y+1] = b._frequency[idx];
+        _probs[x+y+1] = b._probs[idx];
         if( self_is_larger ) x--; else y--;
       }
       _n += bldr._n;
+      _nbinsUsed += bldr._nbinsUsed;
       _ssx = -1;
 
       // Merge elements with least squared-error increase until we get fewer
       // than _nBins and no duplicates. May require many merges.
-      while( _n > _nbins || dups() )
+      while( _nbinsUsed > _nbins || dups() )
         mergeOneBin();
     }
+
+    public void reduceSameOrNullThresholds(AUUCBuilder bldr){
+      _n += bldr._n;
+      if(_nbinsUsed == 0) {
+        ArrayUtils.add(_thresholds, bldr._thresholds);
+        _nbinsUsed = bldr._nbinsUsed;
+      }
+      ArrayUtils.add(_treatment, bldr._treatment);
+      ArrayUtils.add(_control, bldr._control);
+      ArrayUtils.add(_yTreatment, bldr._yTreatment);
+      ArrayUtils.add(_yControl, bldr._yControl);
+      ArrayUtils.add(_frequency, bldr._frequency);
+    }
 
     static double combineCenters(double ths1, double ths0, double probs, long nrows) {
       //double center = (ths0 * n0 + ths1 * n1) / (n0 + n1);
@@ -474,26 +498,22 @@ private void mergeOneBin( ) {
       _yTreatment[ssx] += _yTreatment[ssx+1];
       _yControl[ssx] += _yControl[ssx+1];
       _frequency[ssx] += _frequency[ssx+1];
-      int n = _n;
+      int n = _nbinsUsed == 2 ? _nbinsUsed - ssx -1 : _nbinsUsed - ssx -2;
       // Slide over to crush the removed bin at index (ssx+1)
-      System.arraycopy(_thresholds,ssx+2,_thresholds,ssx+1,n-ssx-2);
-      System.arraycopy(_treatment,ssx+2,_treatment,ssx+1,n-ssx-2);
-      System.arraycopy(_control,ssx+2,_control,ssx+1,n-ssx-2);
-      System.arraycopy(_yTreatment,ssx+2,_yTreatment,ssx+1,n-ssx-2);
-      System.arraycopy(_yControl,ssx+2,_yControl,ssx+1,n-ssx-2);
-      System.arraycopy(_frequency,ssx+2,_frequency,ssx+1,n-ssx-2);
-      _n--;
-      _thresholds[_n] = _treatment[_n] = _control[_n] = _yTreatment[_n] = _yControl[_n] = _frequency[_n] = 0;
+      System.arraycopy(_thresholds,ssx+2,_thresholds,ssx+1,n);
+      System.arraycopy(_treatment,ssx+2,_treatment,ssx+1,n);
+      System.arraycopy(_control,ssx+2,_control,ssx+1,n);
+      System.arraycopy(_yTreatment,ssx+2,_yTreatment,ssx+1,n);
+      System.arraycopy(_yControl,ssx+2,_yControl,ssx+1,n);
+      System.arraycopy(_frequency,ssx+2,_frequency,ssx+1,n);
+      _nbinsUsed--;
       _ssx = -1;
     }
 
-    // Find the pair of bins that when combined give the smallest increase in
-    // squared error. Dups never increase squared error.
-    //
-    // I tried code for merging bins with keeping the bins balanced in size,
-    // but this leads to bad errors if the probabilities are sorted. Also
-    // tried the original: merge bins with the least distance between bin
-    // centers. Same problem for sorted data.
+    /**
+     * Find the pair of bins that when combined give the smallest difference in thresholds
+     * @return index of the bin where the threshold difference is the smallest
+     */
     private int findSmallest() {
       if( _ssx == -1 ) {
         _ssx = findSmallestImpl();
@@ -503,12 +523,10 @@ private int findSmallest() {
     }
 
     private int findSmallestImpl() {
-      if (_n == 1)
+      if (_nbinsUsed == 1)
        return 0;
-      // we couldn't find any bins to merge based on SE (the math can be producing Double.Infinity or Double.NaN)
-      // revert to using a simple distance of the bin centers
       int minI = 0;
-      long n = _n;
+      long n = _nbinsUsed;
       double minDist = _thresholds[1] - _thresholds[0];
       for (int i = 1; i < n - 1; i++) {
         double dist = _thresholds[i + 1] - _thresholds[i];
@@ -521,25 +539,27 @@ private int findSmallestImpl() {
     }
 
     private boolean dups() {
-      long n = _n;
+      long n = _nbinsUsed;
       for( int i=0; i<n-1; i++ ) {
-        double derr = computeDeltaError(_thresholds[i+1],_frequency[i+1],_thresholds[i],_frequency[i]);
-        if( derr == 0 ) { _ssx = i; return true; }
+        double derr = computeDeltaError(_thresholds[i + 1], _frequency[i + 1], _thresholds[i], _frequency[i]);
+        if (derr == 0) {
+          _ssx = i;
+          return true;
+        }
       }
       return false;
     }
 
-
+    /**
+     * If thresholds vary by less than a float ULP, treat them as the same.
+     * Parallel equation drawn from:
+     * http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
+     * @return delta error from two thresholds
+     */
     private double computeDeltaError(double ths1, double n1, double ths0, double n0 ) {
-      // If thresholds vary by less than a float ULP, treat them as the same.
-      // Some models only output predictions to within float accuracy (so a
-      // variance here is junk), and also it's not statistically sane to have
-      // a model which varies predictions by such a tiny change in thresholds.
       double delta = (float)ths1-(float)ths0;
       if (delta == 0)
         return 0;
-      // Parallel equation drawn from:
-      // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
       return delta*delta*n0*n1 / (n0+n1);
     }
 
@@ -555,7 +575,8 @@ private static double computeLinearInterpolation(double ths1, double ths0, doubl
 
     private String toDebugString() {
       return "n =" +_n +
-              "; nBins = " + _nbins +
+              "; nbins = " + _nbins +
+              "; nbinsUsed = " + _nbinsUsed +
               "; ths = " + Arrays.toString(_thresholds) +
               "; treatment = " + Arrays.toString(_treatment) +
               "; contribution = " + Arrays.toString(_control) +

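The dups()/computeDeltaError() hunk above keeps the existing duplicate criterion: the threshold gap is computed in float precision and its square is weighted by n0*n1/(n0+n1), so two bins whose thresholds differ by less than a float ULP yield a delta error of exactly 0 and get merged as duplicates. A small worked example follows; the wrapper class DeltaErrorDemo and the sample numbers are illustrative, only the method body mirrors the diff.

    // Illustrative check of the duplicate test used by dups(); assumed wrapper class.
    public class DeltaErrorDemo {
        static double computeDeltaError(double ths1, double n1, double ths0, double n0) {
            double delta = (float) ths1 - (float) ths0;  // float cast: sub-ULP gaps vanish
            if (delta == 0) return 0;
            return delta * delta * n0 * n1 / (n0 + n1);
        }

        public static void main(String[] args) {
            // Clearly separated thresholds: non-zero error, the bins stay apart.
            System.out.println(computeDeltaError(0.50, 10, 0.25, 20));        // 0.0625 * 200/30 ≈ 0.4167
            // Gap far below float resolution: reported as a duplicate.
            System.out.println(computeDeltaError(0.5000000001, 10, 0.5, 20)); // 0.0
        }
    }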