Skip to content

Commit

Permalink
GH-15809: implement AIC and Loglikelihood calculation for GenericModel
Browse files Browse the repository at this point in the history
  • Loading branch information
syzonyuliia-h2o committed Feb 2, 2024
1 parent ba5aedf commit 04e9c92
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 0 deletions.
48 changes: 48 additions & 0 deletions h2o-algos/src/main/java/hex/generic/GenericModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import hex.*;
import hex.genmodel.*;
import hex.genmodel.algos.glm.GlmMojoModel;
import hex.genmodel.algos.kmeans.KMeansMojoModel;
import hex.genmodel.descriptor.ModelDescriptor;
import hex.genmodel.descriptor.ModelDescriptorBuilder;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.glm.GLMModel;
import hex.tree.isofor.ModelMetricsAnomaly;
import water.*;
import water.fvec.*;
Expand Down Expand Up @@ -78,6 +80,10 @@ private static MojoModel reconstructMojo(ByteVec mojoBytes) {
throw new IllegalStateException("Unreachable MOJO file: " + mojoBytes._key, e);
}
}
private Iced getParamByName(String name) {
return Arrays.stream(this._parms._modelParameters)
.filter(p -> Objects.equals(p.name, name)).findAny().get().actual_value;
}

@Override
public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
Expand Down Expand Up @@ -131,6 +137,48 @@ protected PredictScoreResult predictScoreImpl(Frame fr, Frame adaptFrm, String d
return predictScoreMojoImpl(fr, destination_key, j, computeMetrics);
} else
return super.predictScoreImpl(fr, adaptFrm, destination_key, j, computeMetrics, customMetricFunc);
// return super.predictScoreImpl(fr, adaptFrm, destination_key, j, true, customMetricFunc);
}


@Override
public double aic(double likelihood) {
// calculate negative loglikelihood specifically for GLM
if (!_algoName.equals("glm")) {
return 0;
} else {
double aic = -2 * likelihood + 2 * Arrays.stream(((GlmMojoModel) this.genModel()).getBeta()).filter(b -> b != 0).count();
System.out.println("Bettas for AIC: " + Arrays.stream(((GlmMojoModel) this.genModel()).getBeta()).filter(b -> b != 0).count());
System.out.println(Arrays.toString(((GlmMojoModel) this.genModel()).getBeta()));
System.out.println("Gen AIC: " + aic);
return aic;
}
}

@Override
public double likelihood(double w, double y, double[] f) {
// calculate negative loglikelihood specifically for GLM
if (w == 0 || !_algoName.equals("glm")) {
return 0;
} else {
// create GLM parameters instance
GLMModel.GLMParameters glmParameters = new GLMModel.GLMParameters(
GLMModel.GLMParameters.Family.valueOf(getParamByName("family").toString()),
GLMModel.GLMParameters.Link.valueOf(getParamByName("link").toString()),
Arrays.stream(getParamByName("lambda").toString().trim().replaceAll("\\[", "")
.replaceAll("\\]", "").split(",\\s*"))
.mapToDouble(Double::parseDouble).toArray(),
Arrays.stream(getParamByName("alpha").toString().trim().replaceAll("\\[", "")
.replaceAll("\\]", "").split(",\\s*"))
.mapToDouble(Double::parseDouble).toArray(),
Double.parseDouble(getParamByName("tweedie_variance_power").toString()),
Double.parseDouble(getParamByName("tweedie_link_power").toString()),
null,
Double.parseDouble(getParamByName("theta").toString())
);
// time-consuming calculation for the final scoring for GLM model
return glmParameters.likelihood(w, y, f);
}
}

PredictScoreResult predictScoreMojoImpl(Frame fr, String destination_key, Job<?> j, boolean computeMetrics) {
Expand Down
4 changes: 4 additions & 0 deletions h2o-core/src/main/java/hex/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -1384,6 +1384,10 @@ public double likelihood(double w, double y, double[] f) {
return 0.0; // place holder. This function is overridden in GLM.
}

public double aic(double likelihood) {
return 0.0; // place holder. This function is overridden in GLM.
}

public ScoringInfo[] scoring_history() { return scoringInfo; }

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ abstract class GlmMojoModelBase extends MojoModel {
super(columns, domains, responseColumn);
}

public double[] getBeta() {
return _beta;
}

void init() {
_versionSupportOffset = _mojo_version >= 1.1;
}
Expand Down

0 comments on commit 04e9c92

Please sign in to comment.