Skip to content

Commit

Permalink
GH-15809: refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
syzonyuliia-h2o committed Jan 26, 2024
1 parent d5b7fc3 commit 78a60a8
Showing 1 changed file with 26 additions and 20 deletions.
46 changes: 26 additions & 20 deletions h2o-algos/src/main/java/hex/generic/GenericModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public class GenericModel extends Model<GenericModel, GenericModelParameters, Ge
*/
private final String _algoName;
private final GenModelSource<?> _genModelSource;
private GLMModel.GLMParameters _glmParameters;

/**
* Full constructor
Expand All @@ -58,6 +59,26 @@ public GenericModel(Key<GenericModel> selfKey, GenericModelParameters parms, Gen
if (mojoModel._modelAttributes != null && mojoModel._modelAttributes.getModelParameters() != null) {
_parms._modelParameters = GenericModelParameters.convertParameters(mojoModel._modelAttributes.getModelParameters());
}
_glmParameters = null;
if(_algoName.toLowerCase().contains("glm")) {
GlmMojoModelBase glmModel = (GlmMojoModelBase) mojoModel;
// create GLM parameters instance
_glmParameters = new GLMModel.GLMParameters(
GLMModel.GLMParameters.Family.valueOf(getParamByName("family").toString()),
GLMModel.GLMParameters.Link.valueOf(getParamByName("link").toString()),
Arrays.stream(getParamByName("lambda").toString().trim().replaceAll("\\[", "")
.replaceAll("\\]", "").split(",\\s*"))
.mapToDouble(Double::parseDouble).toArray(),
Arrays.stream(getParamByName("alpha").toString().trim().replaceAll("\\[", "")
.replaceAll("\\]", "").split(",\\s*"))
.mapToDouble(Double::parseDouble).toArray(),
Double.parseDouble(getParamByName("tweedie_variance_power").toString()),
Double.parseDouble(getParamByName("tweedie_link_power").toString()),
null,
Double.parseDouble(getParamByName("theta").toString()),
glmModel.getDispersionEstimated()
);
}
}

public GenericModel(Key<GenericModel> selfKey, GenericModelParameters parms, GenericModelOutput output,
Expand All @@ -80,10 +101,6 @@ private static MojoModel reconstructMojo(ByteVec mojoBytes) {
throw new IllegalStateException("Unreachable MOJO file: " + mojoBytes._key, e);
}
}
private Iced getParamByName(String name) {
return Arrays.stream(this._parms._modelParameters)
.filter(p -> Objects.equals(p.name, name)).findAny().get().actual_value;
}

@Override
public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
Expand Down Expand Up @@ -140,6 +157,10 @@ protected PredictScoreResult predictScoreImpl(Frame fr, Frame adaptFrm, String d
// return super.predictScoreImpl(fr, adaptFrm, destination_key, j, true, customMetricFunc);
}

private Iced getParamByName(String name) {
return Arrays.stream(this._parms._modelParameters)
.filter(p -> Objects.equals(p.name, name)).findAny().get().actual_value;
}

@Override
public double aic(double likelihood) {
Expand All @@ -158,23 +179,8 @@ public double likelihood(double w, double y, double[] f) {
if (w == 0 || !_algoName.equals("glm")) {
return 0;
} else {
// create GLM parameters instance
GLMModel.GLMParameters glmParameters = new GLMModel.GLMParameters(
GLMModel.GLMParameters.Family.valueOf(getParamByName("family").toString()),
GLMModel.GLMParameters.Link.valueOf(getParamByName("link").toString()),
Arrays.stream(getParamByName("lambda").toString().trim().replaceAll("\\[", "")
.replaceAll("\\]", "").split(",\\s*"))
.mapToDouble(Double::parseDouble).toArray(),
Arrays.stream(getParamByName("alpha").toString().trim().replaceAll("\\[", "")
.replaceAll("\\]", "").split(",\\s*"))
.mapToDouble(Double::parseDouble).toArray(),
Double.parseDouble(getParamByName("tweedie_variance_power").toString()),
Double.parseDouble(getParamByName("tweedie_link_power").toString()),
null,
Double.parseDouble(getParamByName("theta").toString())
);
// time-consuming calculation for the final scoring for GLM model
return glmParameters.likelihood(w, y, f);
return _glmParameters.likelihood(w, y, f);
}
}

Expand Down

0 comments on commit 78a60a8

Please sign in to comment.