diff --git a/h2o-algos/src/main/java/hex/generic/GenericModel.java b/h2o-algos/src/main/java/hex/generic/GenericModel.java index cf0841a625a1..feaaf4836422 100644 --- a/h2o-algos/src/main/java/hex/generic/GenericModel.java +++ b/h2o-algos/src/main/java/hex/generic/GenericModel.java @@ -44,6 +44,7 @@ public class GenericModel extends Model _genModelSource; + private GLMModel.GLMParameters _glmParameters; /** * Full constructor @@ -58,6 +59,26 @@ public GenericModel(Key selfKey, GenericModelParameters parms, Gen if (mojoModel._modelAttributes != null && mojoModel._modelAttributes.getModelParameters() != null) { _parms._modelParameters = GenericModelParameters.convertParameters(mojoModel._modelAttributes.getModelParameters()); } + _glmParameters = null; + if(_algoName.toLowerCase().contains("glm")) { + GlmMojoModelBase glmModel = (GlmMojoModelBase) mojoModel; + // create GLM parameters instance + _glmParameters = new GLMModel.GLMParameters( + GLMModel.GLMParameters.Family.valueOf(getParamByName("family").toString()), + GLMModel.GLMParameters.Link.valueOf(getParamByName("link").toString()), + Arrays.stream(getParamByName("lambda").toString().trim().replaceAll("\\[", "") + .replaceAll("\\]", "").split(",\\s*")) + .mapToDouble(Double::parseDouble).toArray(), + Arrays.stream(getParamByName("alpha").toString().trim().replaceAll("\\[", "") + .replaceAll("\\]", "").split(",\\s*")) + .mapToDouble(Double::parseDouble).toArray(), + Double.parseDouble(getParamByName("tweedie_variance_power").toString()), + Double.parseDouble(getParamByName("tweedie_link_power").toString()), + null, + Double.parseDouble(getParamByName("theta").toString()), + glmModel.getDispersionEstimated() + ); + } } public GenericModel(Key selfKey, GenericModelParameters parms, GenericModelOutput output, @@ -80,10 +101,6 @@ private static MojoModel reconstructMojo(ByteVec mojoBytes) { throw new IllegalStateException("Unreachable MOJO file: " + mojoBytes._key, e); } } - private Iced getParamByName(String name) { - return Arrays.stream(this._parms._modelParameters) - .filter(p -> Objects.equals(p.name, name)).findAny().get().actual_value; - } @Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) { @@ -140,6 +157,10 @@ protected PredictScoreResult predictScoreImpl(Frame fr, Frame adaptFrm, String d // return super.predictScoreImpl(fr, adaptFrm, destination_key, j, true, customMetricFunc); } + private Iced getParamByName(String name) { + return Arrays.stream(this._parms._modelParameters) + .filter(p -> Objects.equals(p.name, name)).findAny().get().actual_value; + } @Override public double aic(double likelihood) { @@ -158,23 +179,8 @@ public double likelihood(double w, double y, double[] f) { if (w == 0 || !_algoName.equals("glm")) { return 0; } else { - // create GLM parameters instance - GLMModel.GLMParameters glmParameters = new GLMModel.GLMParameters( - GLMModel.GLMParameters.Family.valueOf(getParamByName("family").toString()), - GLMModel.GLMParameters.Link.valueOf(getParamByName("link").toString()), - Arrays.stream(getParamByName("lambda").toString().trim().replaceAll("\\[", "") - .replaceAll("\\]", "").split(",\\s*")) - .mapToDouble(Double::parseDouble).toArray(), - Arrays.stream(getParamByName("alpha").toString().trim().replaceAll("\\[", "") - .replaceAll("\\]", "").split(",\\s*")) - .mapToDouble(Double::parseDouble).toArray(), - Double.parseDouble(getParamByName("tweedie_variance_power").toString()), - Double.parseDouble(getParamByName("tweedie_link_power").toString()), - null, - Double.parseDouble(getParamByName("theta").toString()) - ); // time-consuming calculation for the final scoring for GLM model - return glmParameters.likelihood(w, y, f); + return _glmParameters.likelihood(w, y, f); } }