Skip to content

Commit

Permalink
GH-15809: implement AIC and loglikelihood calculation for multinomial…
Browse files Browse the repository at this point in the history
… generic glm
  • Loading branch information
syzonyuliia-h2o authored and wendycwong committed Feb 5, 2024
1 parent 7b4d93f commit c20c2d5
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 4 deletions.
2 changes: 1 addition & 1 deletion h2o-core/src/main/java/hex/ModelMetricsBinomialGLM.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public ModelMetricsMultinomialGLM(Model model, Frame frame, long nobs, double ms
double sigma, ConfusionMatrix cm, float [] hr, double logloss,
double resDev, double nullDev, double aic, long nDof, long rDof,
MultinomialAUC auc, CustomMetric customMetric, double loglikelihood) {
super(model, frame, nobs, mse, domain, sigma, cm, hr, logloss, auc, customMetric);
super(model, frame, nobs, mse, domain, sigma, cm, hr, logloss, loglikelihood, aic, auc, customMetric);
_resDev = resDev;
_nullDev = nullDev;
_AIC = aic;
Expand Down
29 changes: 27 additions & 2 deletions h2o-core/src/main/java/hex/ModelMetricsMultinomial.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,20 @@ public class ModelMetricsMultinomial extends ModelMetricsSupervised {
public final float[] _hit_ratios; // Hit ratios
public final ConfusionMatrix _cm;
public final double _logloss;
public final double _loglikelihood;
public final double _aic;
public double _mean_per_class_error;
public MultinomialAUC _auc;

public ModelMetricsMultinomial(Model model, Frame frame, long nobs, double mse, String[] domain, double sigma, ConfusionMatrix cm, float[] hr, double logloss, MultinomialAUC auc, CustomMetric customMetric) {
public ModelMetricsMultinomial(Model model, Frame frame, long nobs, double mse, String[] domain, double sigma,
ConfusionMatrix cm, float[] hr, double logloss, double loglikelihood, double aic,
MultinomialAUC auc, CustomMetric customMetric) {
super(model, frame, nobs, mse, domain, sigma, customMetric);
_cm = cm;
_hit_ratios = hr;
_logloss = logloss;
_loglikelihood = loglikelihood;
_aic = aic;
_mean_per_class_error = cm==null || cm.tooLarge() ? Double.NaN : cm.mean_per_class_error();
_auc = auc;
}
Expand All @@ -35,6 +41,8 @@ public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(super.toString());
sb.append(" logloss: " + (float)_logloss + "\n");
sb.append(" loglikelihood: " + (float)_loglikelihood + "\n");
sb.append(" AIC: " + (float)_aic + "\n");
sb.append(" mean_per_class_error: " + (float)_mean_per_class_error + "\n");
sb.append(" hit ratios: " + Arrays.toString(_hit_ratios) + "\n");
sb.append(" AUC: "+auc()+ "\n");
Expand All @@ -59,6 +67,8 @@ public String toString() {
}

/** Returns the logloss value stored at construction time. */
public double logloss() { return _logloss; }
/** Returns the loglikelihood value stored at construction time. */
public double loglikelihood() { return _loglikelihood; }
/** Returns the AIC value stored at construction time. */
public double aic() { return _aic; }
/** Returns the mean per-class error; NaN when the confusion matrix was null or too large at construction. */
public double mean_per_class_error() { return _mean_per_class_error; }
/** Returns the confusion matrix passed to the constructor (may be null). */
@Override public ConfusionMatrix cm() { return _cm; }
/** Returns the hit-ratio array passed to the constructor (may be null). */
@Override public float[] hr() { return _hit_ratios; }
Expand Down Expand Up @@ -235,6 +245,7 @@ public static class MetricBuilderMultinomial<T extends MetricBuilderMultinomial<
double[/*K*/] _hits; // the number of hits for hitratio, length: K
int _K; // TODO: Let user set K
double _logloss;
protected double _loglikelihood;
boolean _calculateAuc;
AUC2.AUCBuilder[/*nclasses*/][/*nclasses*/] _ovoAucs;
AUC2.AUCBuilder[/*nclasses*/] _ovrAucs;
Expand Down Expand Up @@ -302,6 +313,13 @@ public MetricBuilderMultinomial( int nclasses, String[] domain, MultinomialAucTy
if(_calculateAuc) {
calculateAucsPerRow(ds, iact, w);
}


// Accumulate the per-row likelihood term only for models whose class name
// contains "Generic"; every other model type leaves _loglikelihood untouched.
// The null guard skips the accumulation when no model is supplied instead of
// throwing; the debug printing that ran on every row has been removed.
if (m != null && m.getClass().toString().contains("Generic")) {
  _loglikelihood += m.likelihood(w, yact[0], ds);
}
return ds; // Flow coding
}

Expand Down Expand Up @@ -335,6 +353,7 @@ private void calculateAucsPerRow(double ds[], int iact, double w){
ArrayUtils.add(_cm, mb._cm);
_hits = ArrayUtils.add(_hits, mb._hits);
_logloss += mb._logloss;
_loglikelihood += mb._loglikelihood;
if(_calculateAuc) {
for (int i = 0; i < _ovoAucs.length; i++) {
_ovrAucs[i].reduce(mb._ovrAucs[i]);
Expand All @@ -350,6 +369,8 @@ private void calculateAucsPerRow(double ds[], int iact, double w){
@Override public ModelMetrics makeModelMetrics(Model m, Frame f, Frame adaptedFrame, Frame preds) {
double mse = Double.NaN;
double logloss = Double.NaN;
double loglikelihood = Double.NaN;
double aic = Double.NaN;
float[] hr = new float[_K];
ConfusionMatrix cm = new ConfusionMatrix(_cm, _domain);
double sigma = weightedSigma();
Expand All @@ -360,10 +381,14 @@ private void calculateAucsPerRow(double ds[], int iact, double w){
}
mse = _sumsqe / _wcount;
logloss = _logloss / _wcount;
// m can be null here (this method checks m != null before addModelMetrics),
// so guard before dereferencing; loglikelihood and aic then stay NaN.
if (m != null && m.getClass().toString().contains("Generic")) {
  // Flip the sign of the sum accumulated per row, then derive AIC from the
  // sign-flipped value. NOTE(review): the original comment describes the
  // accumulator as a negative log-likelihood -- confirm against Model.likelihood.
  loglikelihood = -1 * _loglikelihood;
  aic = m.aic(loglikelihood);
}
}
MultinomialAUC auc = new MultinomialAUC(_ovrAucs,_ovoAucs, _domain, _wcount == 0, _aucType);
ModelMetricsMultinomial mm = new ModelMetricsMultinomial(m, f, _count, mse, _domain, sigma, cm,
hr, logloss, auc, _customMetric);
hr, logloss, loglikelihood, aic, auc, _customMetric);
if (m!=null) m.addModelMetrics(mm);
return mm;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public ModelMetricsMultinomialGeneric(Model model, Frame frame, long nobs, doubl
TwoDimTable confusion_matrix, TwoDimTable hit_ratio_table, double logloss, CustomMetric customMetric,
double mean_per_class_error, double r2, TwoDimTable multinomial_auc_table, TwoDimTable multinomial_aucpr_table,
MultinomialAucType type, final String description) {
super(model, frame, nobs, mse, domain, sigma, null, null, logloss, null, customMetric);
super(model, frame, nobs, mse, domain, sigma, null, null, logloss, 0, 0, null, customMetric);
_confusion_matrix_table = confusion_matrix;
_hit_ratio_table = hit_ratio_table;
_auc = new MultinomialAUC(multinomial_auc_table, multinomial_aucpr_table, domain, type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ public class ModelMetricsMultinomialGenericV3<I extends ModelMetricsMultinomialG
public S fillFromImpl(I modelMetrics) {
super.fillFromImpl(modelMetrics);
logloss = modelMetrics._logloss;
loglikelihood = modelMetrics.loglikelihood();
AIC = modelMetrics.aic();

r2 = modelMetrics.r2();

if (modelMetrics._hit_ratio_table != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ public class ModelMetricsMultinomialV3<I extends ModelMetricsMultinomial, S exte
@API(help="The logarithmic loss for this scoring run.", direction=API.Direction.OUTPUT)
public double logloss;

@API(help="The negative logarithmic likelihood for this scoring run.", direction=API.Direction.OUTPUT)
public double loglikelihood;

@API(help="The AIC for this scoring run.", direction=API.Direction.OUTPUT)
public double AIC;

@API(help="The mean misclassification error per class.", direction=API.Direction.OUTPUT)
public double mean_per_class_error;

Expand All @@ -43,6 +49,9 @@ public class ModelMetricsMultinomialV3<I extends ModelMetricsMultinomial, S exte
public S fillFromImpl(I modelMetrics) {
super.fillFromImpl(modelMetrics);
logloss = modelMetrics.logloss();
loglikelihood = modelMetrics.loglikelihood();
AIC = modelMetrics.aic();

r2 = modelMetrics.r2();

if (modelMetrics._hit_ratios != null) {
Expand Down

0 comments on commit c20c2d5

Please sign in to comment.