Skip to content

Commit

Permalink
GH-15809: enable loglikelihood and AIC calculation for multinomial fa…
Browse files Browse the repository at this point in the history
…mily
  • Loading branch information
syzonyuliia-h2o authored and wendycwong committed Feb 5, 2024
1 parent 455587d commit 50fb799
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 9 deletions.
26 changes: 24 additions & 2 deletions h2o-core/src/main/java/hex/ModelMetricsRegression.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,28 @@

public class ModelMetricsRegression extends ModelMetricsSupervised {
public final double _mean_residual_deviance;
public final double _AIC;
public final double _loglikelihood;
/**
* @return {@link #mean_residual_deviance()} for all algos except GLM, for which it means "total residual deviance".
**/
public double residual_deviance() { return _mean_residual_deviance; }
public double loglikelihood() { return _loglikelihood; }
public double aic() { return _AIC; }
@SuppressWarnings("unused")
public double mean_residual_deviance() { return _mean_residual_deviance; }
public final double _mean_absolute_error;
public double mae() { return _mean_absolute_error; }
public final double _root_mean_squared_log_error;
public double rmsle() { return _root_mean_squared_log_error; }
public ModelMetricsRegression(Model model, Frame frame, long nobs, double mse, double sigma, double mae,double rmsle, double meanResidualDeviance, CustomMetric customMetric) {
public ModelMetricsRegression(Model model, Frame frame, long nobs, double mse, double sigma, double mae,double rmsle,
double meanResidualDeviance, CustomMetric customMetric, double loglikelihood, double aic) {
super(model, frame, nobs, mse, null, sigma, customMetric);
_mean_residual_deviance = meanResidualDeviance;
_mean_absolute_error = mae;
_root_mean_squared_log_error = rmsle;
_loglikelihood = loglikelihood;
_AIC = aic;
}

public static ModelMetricsRegression getFromDKV(Model model, Frame frame) {
Expand All @@ -51,6 +58,8 @@ public String toString() {
}
sb.append(" mean absolute error: " + (float)_mean_absolute_error + "\n");
sb.append(" root mean squared log error: " + (float)_root_mean_squared_log_error + "\n");
sb.append(" loglikelihood: " + (float)_loglikelihood + "\n");
sb.append(" AIC: " + (float)_AIC + "\n");
return sb.toString();
}

Expand Down Expand Up @@ -117,6 +126,7 @@ public static class MetricBuilderRegression<T extends MetricBuilderRegression<T>
Distribution _dist;
double _abserror;
double _rmslerror;
protected double _loglikelihood;
public MetricBuilderRegression() {
super(1,null); //this will make _work = new float[2];
}
Expand Down Expand Up @@ -147,6 +157,10 @@ public MetricBuilderRegression(Distribution dist) {
_sumdeviance += _dist.deviance(w, yact[0], ds[0]);
}
}

if(m.getClass().toString().contains("Generic")) {
_loglikelihood += m.likelihood(w, yact[0], ds);
}

_count++;
_wcount += w;
Expand All @@ -160,6 +174,7 @@ public MetricBuilderRegression(Distribution dist) {
_sumdeviance += mb._sumdeviance;
_abserror += mb._abserror;
_rmslerror += mb._rmslerror;
_loglikelihood += mb._loglikelihood;
}

// Having computed a MetricBuilder, this method fills in a ModelMetrics
Expand All @@ -173,6 +188,8 @@ ModelMetricsRegression computeModelMetrics(Model m, Frame f, Frame adaptedFrame,
double mse = _sumsqe / _wcount;
double mae = _abserror/_wcount; //Mean Absolute Error
double rmsle = Math.sqrt(_rmslerror/_wcount); //Root Mean Squared Log Error
double loglikelihood = Double.NaN;
double aic = Double.NaN;
if (adaptedFrame ==null) adaptedFrame = f;
double meanResDeviance = 0;
if (m != null && m.isDistributionHuber()){
Expand All @@ -195,7 +212,12 @@ ModelMetricsRegression computeModelMetrics(Model m, Frame f, Frame adaptedFrame,
} else {
meanResDeviance = Double.NaN;
}
ModelMetricsRegression mm = new ModelMetricsRegression(m, f, _count, mse, weightedSigma(), mae, rmsle, meanResDeviance, _customMetric);
if(m.getClass().toString().contains("Generic")) {
loglikelihood = -1 * _loglikelihood ; // get likelihood from negative loglikelihood
aic = m.aic(loglikelihood);
}
ModelMetricsRegression mm = new ModelMetricsRegression(m, f, _count, mse, weightedSigma(), mae, rmsle,
meanResDeviance, _customMetric, loglikelihood, aic);
return mm;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class ModelMetricsRegressionCoxPH extends ModelMetricsRegression {
public ModelMetricsRegressionCoxPH(Model model, Frame frame, long nobs, double mse, double sigma, double mae,
double rmsle, double meanResidualDeviance, CustomMetric customMetric,
double concordance, long concordant, long discordant, long tied_y) {
super(model, frame, nobs, mse, sigma, mae, rmsle, meanResidualDeviance, customMetric);
super(model, frame, nobs, mse, sigma, mae, rmsle, meanResidualDeviance, customMetric, 0, 0);

this._concordance = concordance;
this._concordant = concordant;
Expand Down
7 changes: 2 additions & 5 deletions h2o-core/src/main/java/hex/ModelMetricsRegressionGLM.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,17 @@ public class ModelMetricsRegressionGLM extends ModelMetricsRegression implements
public final long _residualDegressOfFreedom;
public final double _resDev;
public final double _nullDev;
public final double _AIC;
public final double _loglikelihood;


public ModelMetricsRegressionGLM(Model model, Frame frame, long nobs, double mse, double sigma,
double mae, double rmsle, double resDev, double meanResDev,
double nullDev, double aic, long nDof, long rDof,
CustomMetric customMetric, double loglikelihood) {
super(model, frame, nobs, mse, sigma, mae, rmsle, meanResDev, customMetric);
super(model, frame, nobs, mse, sigma, mae, rmsle, meanResDev, customMetric, loglikelihood, aic);
_resDev = resDev;
_nullDev = nullDev;
_AIC = aic;
_nullDegressOfFreedom = nDof;
_residualDegressOfFreedom = rDof;
_loglikelihood = loglikelihood;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ public class ModelMetricsRegressionGeneric extends ModelMetricsRegression {

public ModelMetricsRegressionGeneric(Model model, Frame frame, long nobs, double mse, double sigma, double mae, double rmsle,
double meanResidualDeviance, CustomMetric customMetric, String description) {
super(model, frame, nobs, mse, sigma, mae, rmsle, meanResidualDeviance, customMetric);
super(model, frame, nobs, mse, sigma, mae, rmsle, meanResidualDeviance, customMetric, 0, 0);
_description = description;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,21 @@ public class ModelMetricsRegressionGenericV3<I extends ModelMetricsRegressionGen
@API(help="The root mean squared log error for this scoring run.", direction=API.Direction.OUTPUT)
public double rmsle;

@API(help="The negative logarithmic likelihood for this scoring run.", direction=API.Direction.OUTPUT)
public double loglikelihood;

@API(help="The AIC for this scoring run.", direction=API.Direction.OUTPUT)
public double AIC;


@Override
public S fillFromImpl(I modelMetrics) {
super.fillFromImpl(modelMetrics);
mae = modelMetrics._mean_absolute_error;
rmsle = modelMetrics._root_mean_squared_log_error;
mean_residual_deviance = modelMetrics._mean_residual_deviance;
loglikelihood = modelMetrics.loglikelihood();
AIC = modelMetrics.aic();
return (S) this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,22 @@ public class ModelMetricsRegressionV3<I extends ModelMetricsRegression, S extend
@API(help="The root mean squared log error for this scoring run.", direction=API.Direction.OUTPUT)
public double rmsle;

@API(help="The negative logarithmic likelihood for this scoring run.", direction=API.Direction.OUTPUT)
public double loglikelihood;

@API(help="The AIC for this scoring run.", direction=API.Direction.OUTPUT)
public double AIC;


@Override
public S fillFromImpl(I modelMetrics) {
super.fillFromImpl(modelMetrics);
r2 = modelMetrics.r2();
mae = modelMetrics._mean_absolute_error;
rmsle = modelMetrics._root_mean_squared_log_error;
mean_residual_deviance = modelMetrics._mean_residual_deviance;
loglikelihood = modelMetrics.loglikelihood();
AIC = modelMetrics.aic();
return (S) this;
}
}

0 comments on commit 50fb799

Please sign in to comment.