Skip to content

Commit

Permalink
GH-15809: fixes loglikelihood and aic for glm generic model (#16025)
Browse files Browse the repository at this point in the history
* GH-15809: implement AIC and Loglikelihood calculation for GenericModel

* GH-15809: add AIC and Loglikelihood to ModelMetricsBinomial

* GH-15809: implement AIC and Loglikelihood calculations for ModelMetricsBinomial

* GH-15809: add AIC and Loglikelihood to output metrics

* GH-15809: update test to check AIC and Loglikelihood calculation for loaded model

* GH-15809: correct betas source

* GH-15809: implement AIC and loglikelihood calculation for multinomial generic glm

* GH-15809: minor aic retrieval fix

* GH-15809: enable loglikelihood and AIC calculation for multinomial family

* GH-15809: remove prints

* GH-15809: refactor

* GH-15809: add new parameter to the constructor, and add new constructor

* GH-15809: add dispersion_estimated parameter to GLM mojo

* GH-15809: update and fix tests

* GH-15809: fix metrics exposure in python

* GH-15809: fix parameters

* GH-15809: add null check

* GH-15809: fix tests

* GH-15809: fix R tests

* GH-15809: fix reading new parameter in MOJO load

* GH-15809: fix writing new parameter in MOJO load

* GH-15809: fix value

* GH-15809: fix comments

* GH-15809: fix printing metrics

* GH-15809: remove commented code

* GH-15809: assign NaN instead of 0 as placeholder value for Loglikelihood

* GH-15809: default dispersion estimation set to 1

* GH-15809: clean test

* GH-15809: fix aic check in test

* GH-15809: additionally fix aic check in test

* GH-15809: additionally fix aic check in test

* GH-15809: fit test - add default parameters

* Fixed test discrepancies.

* only return AIC and loglikelihood for glm models

* fixed AIC problem when model is not glm

* Incorporate Tomas F review.

* replace m != null && m.getClass().toString().contains(generic) with score4Generic

---------

Co-authored-by: syzonyuliia <[email protected]>
Co-authored-by: wendycwong <[email protected]>
  • Loading branch information
3 people authored Feb 12, 2024
1 parent 8d4b331 commit a6bd451
Show file tree
Hide file tree
Showing 26 changed files with 301 additions and 46 deletions.
52 changes: 52 additions & 0 deletions h2o-algos/src/main/java/hex/generic/GenericModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import hex.*;
import hex.genmodel.*;
import hex.genmodel.algos.glm.GlmMojoModelBase;
import hex.genmodel.algos.kmeans.KMeansMojoModel;
import hex.genmodel.descriptor.ModelDescriptor;
import hex.genmodel.descriptor.ModelDescriptorBuilder;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.glm.GLMModel;
import hex.tree.isofor.ModelMetricsAnomaly;
import water.*;
import water.fvec.*;
Expand Down Expand Up @@ -42,6 +44,7 @@ public class GenericModel extends Model<GenericModel, GenericModelParameters, Ge
*/
private final String _algoName;
private final GenModelSource<?> _genModelSource;
private GLMModel.GLMParameters _glmParameters;

/**
* Full constructor
Expand All @@ -56,6 +59,26 @@ public GenericModel(Key<GenericModel> selfKey, GenericModelParameters parms, Gen
if (mojoModel._modelAttributes != null && mojoModel._modelAttributes.getModelParameters() != null) {
_parms._modelParameters = GenericModelParameters.convertParameters(mojoModel._modelAttributes.getModelParameters());
}
_glmParameters = null;
if(_algoName.toLowerCase().contains("glm")) {
GlmMojoModelBase glmModel = (GlmMojoModelBase) mojoModel;
// create GLM parameters instance
_glmParameters = new GLMModel.GLMParameters(
GLMModel.GLMParameters.Family.valueOf(getParamByName("family").toString()),
GLMModel.GLMParameters.Link.valueOf(getParamByName("link").toString()),
Arrays.stream(getParamByName("lambda").toString().trim().replaceAll("\\[", "")
.replaceAll("\\]", "").split(",\\s*"))
.mapToDouble(Double::parseDouble).toArray(),
Arrays.stream(getParamByName("alpha").toString().trim().replaceAll("\\[", "")
.replaceAll("\\]", "").split(",\\s*"))
.mapToDouble(Double::parseDouble).toArray(),
Double.parseDouble(getParamByName("tweedie_variance_power").toString()),
Double.parseDouble(getParamByName("tweedie_link_power").toString()),
null,
Double.parseDouble(getParamByName("theta").toString()),
glmModel.getDispersionEstimated()
);
}
}

public GenericModel(Key<GenericModel> selfKey, GenericModelParameters parms, GenericModelOutput output,
Expand Down Expand Up @@ -133,6 +156,35 @@ protected PredictScoreResult predictScoreImpl(Frame fr, Frame adaptFrm, String d
return super.predictScoreImpl(fr, adaptFrm, destination_key, j, computeMetrics, customMetricFunc);
}

/**
 * Looks up a model parameter by name among {@code _parms._modelParameters} and returns its
 * {@code actual_value}.
 *
 * @param name parameter name to search for (compared with {@link Objects#equals(Object, Object)})
 * @return the {@code actual_value} of a parameter whose name matches
 * @throws java.util.NoSuchElementException if the parameter array is missing or holds no parameter
 *         with the given name
 */
private Iced getParamByName(String name) {
  if (this._parms._modelParameters == null) {
    // Fail with a descriptive message instead of a bare NullPointerException from Arrays.stream.
    throw new java.util.NoSuchElementException(
        "Model parameters are not available; cannot look up parameter '" + name + "'.");
  }
  return Arrays.stream(this._parms._modelParameters)
      .filter(p -> Objects.equals(p.name, name))
      .findAny()
      // Same exception type as the former unchecked Optional.get(), but it names the missing parameter.
      .orElseThrow(() -> new java.util.NoSuchElementException(
          "Parameter '" + name + "' not found among the model parameters."))
      .actual_value;
}

/**
 * Computes AIC for a generic model wrapping a GLM MOJO as
 * {@code -2 * likelihood + 2 * (number of non-zero betas)}.
 *
 * @param likelihood value plugged into the formula above
 * @return the AIC, or {@code Double.NaN} when the wrapped algorithm is not named "glm"
 */
@Override
public double aic(double likelihood) {
  if (_algoName.equals("glm")) {
    // Only coefficients different from zero count towards the penalty term.
    long nonZeroBetas = 0;
    for (double beta : ((GlmMojoModelBase) this.genModel()).getBeta()) {
      if (beta != 0) {
        nonZeroBetas++;
      }
    }
    return -2 * likelihood + 2 * nonZeroBetas;
  }
  // AIC is only provided for GLM.
  return Double.NaN;
}

/**
 * Per-row likelihood contribution for a generic model wrapping a GLM MOJO, delegated to the
 * reconstructed GLM parameters.
 *
 * @param w row weight; a zero weight contributes {@code 0}
 * @param y actual response value
 * @param f predictions for the row
 * @return the value of {@code _glmParameters.likelihood(w, y, f)}, {@code 0} for zero-weight rows,
 *         or {@code Double.NaN} when the algorithm is not named "glm" or no GLM parameters are held
 */
@Override
public double likelihood(double w, double y, double[] f) {
  // calculate negative loglikelihood specifically for GLM
  // _glmParameters is a non-final field; guard against instances that never had it assigned
  // so scoring yields NaN rather than a NullPointerException.
  if (!_algoName.equals("glm") || _glmParameters == null) {
    return Double.NaN;
  }
  if (w == 0) {
    return 0;
  }
  // time-consuming calculation for the final scoring for GLM model
  return _glmParameters.likelihood(w, y, f);
}

PredictScoreResult predictScoreMojoImpl(Frame fr, String destination_key, Job<?> j, boolean computeMetrics) {
GenModel model = genModel();
String[] names = model.getOutputNames();
Expand Down
7 changes: 6 additions & 1 deletion h2o-algos/src/main/java/hex/glm/GLMModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,11 @@ public GLMParameters(Family f, Link l, double [] lambda, double [] alpha, double

/**
 * Constructor without an explicit dispersion estimate. Delegates to the nine-argument
 * constructor, passing {@code Double.NaN} as the last ({@code dispersion_estimated}) argument.
 */
public GLMParameters(Family f, Link l, double [] lambda, double [] alpha, double twVar, double twLnk,
String[] interactions, double theta){
this(f,l,lambda,alpha,twVar,twLnk,interactions, theta, Double.NaN);
}

public GLMParameters(Family f, Link l, double [] lambda, double [] alpha, double twVar, double twLnk,
String[] interactions, double theta, double dispersion_estimated){
this._lambda = lambda;
this._alpha = alpha;
this._tweedie_variance_power = twVar;
Expand All @@ -736,7 +741,7 @@ public GLMParameters(Family f, Link l, double [] lambda, double [] alpha, double
_link = l;
this._theta=theta;
this._invTheta = 1.0/theta;
this._dispersion_estimated = _init_dispersion_parameter;
this._dispersion_estimated = Double.isNaN(dispersion_estimated) ? _init_dispersion_parameter : dispersion_estimated;
}

public final double variance(double mu){
Expand Down
2 changes: 2 additions & 0 deletions h2o-algos/src/main/java/hex/glm/GLMMojoWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ protected void writeModelData() throws IOException {

if (GLMModel.GLMParameters.Family.tweedie.equals(model._parms._family))
writekv("tweedie_link_power", model._parms._tweedie_link_power);

writekv("dispersion_estimated", (model._parms._compute_p_values ? model._parms._dispersion_estimated : 1.0));
}

}
6 changes: 5 additions & 1 deletion h2o-core/src/main/java/hex/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -1381,7 +1381,11 @@ public double deviance(double w, double y, double f) {
}

/**
 * Per-row likelihood hook. This base implementation ignores its arguments.
 *
 * @param w row weight
 * @param y actual response value
 * @param f predictions for the row
 * @return always {@code Double.NaN} in this base implementation
 */
public double likelihood(double w, double y, double[] f) {
  return Double.NaN; // placeholder. This function is overridden in GLM and GenericModel.
}

/**
 * AIC hook. This base implementation ignores its argument.
 *
 * @param likelihood likelihood value from which a subclass may derive the AIC
 * @return always {@code Double.NaN} in this base implementation
 */
public double aic(double likelihood) {
return Double.NaN; // placeholder. This function is overridden in GenericModel.
}

public ScoringInfo[] scoring_history() { return scoringInfo; }
Expand Down
34 changes: 31 additions & 3 deletions h2o-core/src/main/java/hex/ModelMetricsBinomial.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MathUtils;

import java.util.Arrays;
Expand All @@ -19,19 +18,30 @@
public class ModelMetricsBinomial extends ModelMetricsSupervised {
public final AUC2 _auc;
public final double _logloss;
public final double _loglikelihood;
public final double _aic;
public double _mean_per_class_error;
public final GainsLift _gainsLift;

/**
 * Creates binomial metrics including log-likelihood and AIC.
 *
 * @param loglikelihood value stored in {@code _loglikelihood}
 * @param aic           value stored in {@code _aic}
 */
public ModelMetricsBinomial(Model model, Frame frame, long nobs, double mse, String[] domain,
                            double sigma, AUC2 auc, double logloss, double loglikelihood, double aic, GainsLift gainsLift,
                            CustomMetric customMetric) {
  super(model, frame, nobs, mse, domain, sigma, customMetric);
  _auc = auc;
  _logloss = logloss;
  _loglikelihood = loglikelihood;
  _aic = aic;
  _gainsLift = gainsLift;
  // Without a confusion matrix the mean per-class error is left as NaN.
  _mean_per_class_error = cm() == null ? Double.NaN : cm().mean_per_class_error();
}

/**
 * Constructor without log-likelihood and AIC arguments. Delegates to the full constructor,
 * passing {@code Double.NaN} for both values.
 */
public ModelMetricsBinomial(Model model, Frame frame, long nobs, double mse, String[] domain,
double sigma, AUC2 auc, double logloss, GainsLift gainsLift,
CustomMetric customMetric) {
this(model, frame, nobs, mse, domain, sigma, auc, logloss, Double.NaN, Double.NaN,
gainsLift, customMetric);
}

public static ModelMetricsBinomial getFromDKV(Model model, Frame frame) {
ModelMetrics mm = ModelMetrics.getFromDKV(model, frame);
if( !(mm instanceof ModelMetricsBinomial) )
Expand All @@ -49,6 +59,8 @@ public String toString() {
sb.append(" pr_auc: " + (float)_auc.pr_auc() + "\n");
}
sb.append(" logloss: " + (float)_logloss + "\n");
sb.append(" loglikelihood: " + (float)_loglikelihood + "\n");
sb.append(" AIC: " + (float)_aic + "\n");
sb.append(" mean_per_class_error: " + (float)_mean_per_class_error + "\n");
sb.append(" default threshold: " + (_auc == null ? 0.5 : (float)_auc.defaultThreshold()) + "\n");
if (cm() != null) sb.append(" CM: " + cm().toASCII());
Expand All @@ -57,6 +69,8 @@ public String toString() {
}

// Plain accessors for the scalar metrics stored on this object.
public double logloss() { return _logloss; }
// Returns the stored _loglikelihood field.
public double loglikelihood() { return _loglikelihood; }
// Returns the stored _aic field.
public double aic() { return _aic; }
public double mean_per_class_error() { return _mean_per_class_error; }
@Override public AUC2 auc_obj() { return _auc; }
@Override public ConfusionMatrix cm() {
Expand Down Expand Up @@ -161,6 +175,7 @@ private static class BinomialMetrics extends MRTask<BinomialMetrics> {

public static class MetricBuilderBinomial<T extends MetricBuilderBinomial<T>> extends MetricBuilderSupervised<T> {
protected double _logloss;
protected double _loglikelihood;
protected AUC2.AUCBuilder _auc;

public MetricBuilderBinomial( String[] domain ) { super(2,domain); _auc = new AUC2.AUCBuilder(AUC2.NBINS); }
Expand All @@ -177,6 +192,7 @@ public static class MetricBuilderBinomial<T extends MetricBuilderBinomial<T>> ex
if(w == 0 || Double.isNaN(w)) return ds;
int iact = (int)yact[0];
boolean quasibinomial = (m!=null && m._parms._distribution == DistributionFamily.quasibinomial);
boolean score4Generic = m != null && m.getClass().toString().contains("Generic");
if (quasibinomial) {
if (yact[0] != 0)
iact = _domain[0].equals(String.valueOf((int) yact[0])) ? 0 : 1; // actual response index needed for confusion matrix, AUC, etc.
Expand All @@ -197,6 +213,11 @@ public static class MetricBuilderBinomial<T extends MetricBuilderBinomial<T>> ex
// Compute log loss
_logloss += w * MathUtils.logloss(err);
}

if(score4Generic) { // only perform for generic model, will increase run time for training if performs
_loglikelihood += m.likelihood(w, yact[0], ds);
}

_count++;
_wcount += w;
assert !Double.isNaN(_sumsqe);
Expand All @@ -207,6 +228,7 @@ public static class MetricBuilderBinomial<T extends MetricBuilderBinomial<T>> ex
@Override public void reduce( T mb ) {
super.reduce(mb); // sumseq, count
_logloss += mb._logloss;
_loglikelihood += mb._loglikelihood;
_auc.reduce(mb._auc);
}

Expand Down Expand Up @@ -256,18 +278,24 @@ private ModelMetrics makeModelMetrics(final Model m, final Frame f, final Frame

/**
 * Turns the accumulated sums into a {@link ModelMetricsBinomial} and, when a model is given,
 * registers the metrics with it.
 *
 * <p>All scalar metrics stay {@code NaN} when no weighted observations were seen. Log-likelihood
 * and AIC are only filled in when the model's class name contains "Generic".
 *
 * @param m  model the metrics belong to; may be {@code null}
 * @param f  frame the metrics were computed on
 * @param gl gains/lift table passed through to the metrics object
 * @return the newly created metrics object
 */
private ModelMetrics makeModelMetrics(Model m, Frame f, GainsLift gl) {
  double mse = Double.NaN;
  double loglikelihood = Double.NaN;
  double aic = Double.NaN;
  double logloss = Double.NaN;
  double sigma = Double.NaN;
  final AUC2 auc;
  if (_wcount > 0) {
    sigma = weightedSigma();
    mse = _sumsqe / _wcount;
    logloss = _logloss / _wcount;
    if (m != null && m.getClass().toString().contains("Generic")) {
      // The builder accumulates the negative log-likelihood; flip the sign to report log-likelihood.
      loglikelihood = -1 * _loglikelihood;
      aic = m.aic(loglikelihood);
    }
    auc = new AUC2(_auc);
  } else {
    auc = new AUC2();
  }
  ModelMetricsBinomial mm = new ModelMetricsBinomial(m, f, _count, mse, _domain, sigma, auc, logloss, loglikelihood, aic, gl, _customMetric);
  if (m != null) m.addModelMetrics(mm);
  return mm;
}
Expand Down
4 changes: 2 additions & 2 deletions h2o-core/src/main/java/hex/ModelMetricsBinomialGLM.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public ModelMetricsBinomialGLM(Model model, Frame frame, long nobs, double mse,
double sigma, AUC2 auc, double logloss, double resDev, double nullDev,
double aic, long nDof, long rDof, GainsLift gainsLift,
CustomMetric customMetric, double loglikelihood) {
super(model, frame, nobs, mse, domain, sigma, auc, logloss, gainsLift, customMetric);
super(model, frame, nobs, mse, domain, sigma, auc, logloss, loglikelihood, aic, gainsLift, customMetric);
_resDev = resDev;
_nullDev = nullDev;
_AIC = aic;
Expand Down Expand Up @@ -70,7 +70,7 @@ public ModelMetricsMultinomialGLM(Model model, Frame frame, long nobs, double ms
double sigma, ConfusionMatrix cm, float [] hr, double logloss,
double resDev, double nullDev, double aic, long nDof, long rDof,
MultinomialAUC auc, CustomMetric customMetric, double loglikelihood) {
super(model, frame, nobs, mse, domain, sigma, cm, hr, logloss, auc, customMetric);
super(model, frame, nobs, mse, domain, sigma, cm, hr, logloss, loglikelihood, aic, auc, customMetric);
_resDev = resDev;
_nullDev = nullDev;
_AIC = aic;
Expand Down
35 changes: 33 additions & 2 deletions h2o-core/src/main/java/hex/ModelMetricsMultinomial.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,38 @@ public class ModelMetricsMultinomial extends ModelMetricsSupervised {
public final float[] _hit_ratios; // Hit ratios
public final ConfusionMatrix _cm;
public final double _logloss;
public final double _loglikelihood;
public final double _aic;
public double _mean_per_class_error;
public MultinomialAUC _auc;

/**
 * Creates multinomial metrics including log-likelihood and AIC.
 *
 * @param loglikelihood value stored in {@code _loglikelihood}
 * @param aic           value stored in {@code _aic}
 */
public ModelMetricsMultinomial(Model model, Frame frame, long nobs, double mse, String[] domain, double sigma,
                               ConfusionMatrix cm, float[] hr, double logloss, double loglikelihood, double aic,
                               MultinomialAUC auc, CustomMetric customMetric) {
  super(model, frame, nobs, mse, domain, sigma, customMetric);
  _cm = cm;
  _hit_ratios = hr;
  _logloss = logloss;
  _loglikelihood = loglikelihood;
  _aic = aic;
  // Mean per-class error is left as NaN when there is no confusion matrix or it is too large.
  _mean_per_class_error = cm==null || cm.tooLarge() ? Double.NaN : cm.mean_per_class_error();
  _auc = auc;
}

/**
 * Constructor without log-likelihood and AIC arguments. Delegates to the full constructor,
 * passing {@code Double.NaN} for both values.
 */
public ModelMetricsMultinomial(Model model, Frame frame, long nobs, double mse, String[] domain, double sigma,
ConfusionMatrix cm, float[] hr, double logloss, MultinomialAUC auc,
CustomMetric customMetric) {
this(model, frame, nobs, mse, domain, sigma, cm, hr, logloss, Double.NaN, Double.NaN, auc, customMetric);

}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(super.toString());
sb.append(" logloss: " + (float)_logloss + "\n");
sb.append(" loglikelihood: " + (float)_loglikelihood + "\n");
sb.append(" AIC: " + (float)_aic + "\n");
sb.append(" mean_per_class_error: " + (float)_mean_per_class_error + "\n");
sb.append(" hit ratios: " + Arrays.toString(_hit_ratios) + "\n");
sb.append(" AUC: "+auc()+ "\n");
Expand All @@ -59,6 +74,8 @@ public String toString() {
}

// Plain accessors for the scalar metrics stored on this object.
public double logloss() { return _logloss; }
// Returns the stored _loglikelihood field.
public double loglikelihood() { return _loglikelihood; }
// Returns the stored _aic field.
public double aic() { return _aic; }
public double mean_per_class_error() { return _mean_per_class_error; }
@Override public ConfusionMatrix cm() { return _cm; }
@Override public float[] hr() { return _hit_ratios; }
Expand Down Expand Up @@ -235,6 +252,7 @@ public static class MetricBuilderMultinomial<T extends MetricBuilderMultinomial<
double[/*K*/] _hits; // the number of hits for hitratio, length: K
int _K; // TODO: Let user set K
double _logloss;
protected double _loglikelihood;
boolean _calculateAuc;
AUC2.AUCBuilder[/*nclasses*/][/*nclasses*/] _ovoAucs;
AUC2.AUCBuilder[/*nclasses*/] _ovrAucs;
Expand Down Expand Up @@ -276,6 +294,7 @@ public MetricBuilderMultinomial( int nclasses, String[] domain, MultinomialAucTy
if(ArrayUtils.hasNaNs(ds)) return ds;
if(w == 0 || Double.isNaN(w)) return ds;
final int iact = (int)yact[0];
boolean score4Generic = m != null && m.getClass().toString().contains("Generic");
_count++;
_wcount += w;
_wY += w*iact;
Expand All @@ -302,6 +321,11 @@ public MetricBuilderMultinomial( int nclasses, String[] domain, MultinomialAucTy
if(_calculateAuc) {
calculateAucsPerRow(ds, iact, w);
}


if(score4Generic) { // only perform for generic model, will increase run time for training if perform
_loglikelihood += m.likelihood(w, yact[0], ds);
}
return ds; // Flow coding
}

Expand Down Expand Up @@ -335,6 +359,7 @@ private void calculateAucsPerRow(double ds[], int iact, double w){
ArrayUtils.add(_cm, mb._cm);
_hits = ArrayUtils.add(_hits, mb._hits);
_logloss += mb._logloss;
_loglikelihood += mb._loglikelihood;
if(_calculateAuc) {
for (int i = 0; i < _ovoAucs.length; i++) {
_ovrAucs[i].reduce(mb._ovrAucs[i]);
Expand All @@ -350,6 +375,8 @@ private void calculateAucsPerRow(double ds[], int iact, double w){
@Override public ModelMetrics makeModelMetrics(Model m, Frame f, Frame adaptedFrame, Frame preds) {
double mse = Double.NaN;
double logloss = Double.NaN;
double loglikelihood = Double.NaN;
double aic = Double.NaN;
float[] hr = new float[_K];
ConfusionMatrix cm = new ConfusionMatrix(_cm, _domain);
double sigma = weightedSigma();
Expand All @@ -360,10 +387,14 @@ private void calculateAucsPerRow(double ds[], int iact, double w){
}
mse = _sumsqe / _wcount;
logloss = _logloss / _wcount;
if(m != null && m.getClass().toString().contains("Generic")) {
loglikelihood = -1 * _loglikelihood ; // get likelihood from negative loglikelihood
aic = m.aic(loglikelihood);
}
}
MultinomialAUC auc = new MultinomialAUC(_ovrAucs,_ovoAucs, _domain, _wcount == 0, _aucType);
ModelMetricsMultinomial mm = new ModelMetricsMultinomial(m, f, _count, mse, _domain, sigma, cm,
hr, logloss, auc, _customMetric);
hr, logloss, loglikelihood, aic, auc, _customMetric);
if (m!=null) m.addModelMetrics(mm);
return mm;
}
Expand Down
10 changes: 10 additions & 0 deletions h2o-core/src/main/java/hex/ModelMetricsMultinomialGLMGeneric.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ public class ModelMetricsMultinomialGLMGeneric extends ModelMetricsMultinomialGe
public final double _loglikelihood;
public final TwoDimTable _coefficients_table;

/**
 * Shorter constructor that delegates to the longer one, passing {@code Double.NaN} for the two
 * additional arguments (the one following {@code nullDev} and the trailing one).
 * NOTE(review): these are presumably AIC and log-likelihood; confirm against the full constructor's
 * parameter list.
 */
public ModelMetricsMultinomialGLMGeneric(Model model, Frame frame, long nobs, double mse, String[] domain, double sigma,
TwoDimTable confusion_matrix, TwoDimTable hit_ratio_table, double logloss, CustomMetric customMetric,
double mean_per_class_error, long nullDegreesOfFreedom, long residualDegreesOfFreedom,
double resDev, double nullDev, TwoDimTable coefficients_table, double r2,
TwoDimTable multinomial_auc_table, TwoDimTable multinomial_aucpr_table, MultinomialAucType type,
final String description) {
this(model, frame, nobs, mse, domain, sigma, confusion_matrix, hit_ratio_table, logloss, customMetric,
mean_per_class_error, nullDegreesOfFreedom, residualDegreesOfFreedom, resDev, nullDev, Double.NaN,
coefficients_table, r2, multinomial_auc_table, multinomial_aucpr_table, type, description, Double.NaN);
}
public ModelMetricsMultinomialGLMGeneric(Model model, Frame frame, long nobs, double mse, String[] domain, double sigma,
TwoDimTable confusion_matrix, TwoDimTable hit_ratio_table, double logloss, CustomMetric customMetric,
double mean_per_class_error, long nullDegreesOfFreedom, long residualDegreesOfFreedom,
Expand Down
Loading

0 comments on commit a6bd451

Please sign in to comment.