diff --git a/h2o-admissibleml/src/main/java/hex/Infogram/Infogram.java b/h2o-admissibleml/src/main/java/hex/Infogram/Infogram.java index ccd0faa36e1d..4731bcfd2fcd 100644 --- a/h2o-admissibleml/src/main/java/hex/Infogram/Infogram.java +++ b/h2o-admissibleml/src/main/java/hex/Infogram/Infogram.java @@ -1,7 +1,14 @@ package hex.Infogram; -import hex.*; -import water.*; +import hex.Model; +import hex.ModelBuilder; +import hex.ModelBuilderHelper; +import hex.ModelCategory; +import hex.genmodel.utils.DistributionFamily; +import water.DKV; +import water.H2O; +import water.Key; +import water.Scope; import water.exceptions.H2OModelBuilderIllegalArgumentException; import water.fvec.Frame; import water.util.ArrayUtils; @@ -9,7 +16,7 @@ import java.util.*; import java.util.stream.IntStream; -import hex.genmodel.utils.DistributionFamily; + import static hex.Infogram.InfogramModel.InfogramModelOutput.sortCMIRel; import static hex.Infogram.InfogramModel.InfogramParameters.Algorithm.AUTO; import static hex.Infogram.InfogramModel.InfogramParameters.Algorithm.gbm; @@ -182,61 +189,29 @@ private void validateInfoGramParameters() { _buildCore = _parms._protected_columns == null; if (_buildCore) { - if (_parms._net_information_threshold == -1) { // not set - _parms._cmi_threshold = 0.1; - _parms._net_information_threshold = 0.1; - } else if (_parms._net_information_threshold > 1 || _parms._net_information_threshold < 0) { + if (_parms._net_information_threshold > 1 || _parms._net_information_threshold < 0) { error("net_information_threshold", " should be set to be between 0 and 1."); } else { _parms._cmi_threshold = _parms._net_information_threshold; } - if (_parms._total_information_threshold == -1) { // not set - _parms._relevance_threshold = 0.1; - _parms._total_information_threshold = 0.1; - } else if (_parms._total_information_threshold < 0 || _parms._total_information_threshold > 1) { + if (_parms._total_information_threshold < 0 || _parms._total_information_threshold > 1) { error("total_information_threshold", " should be set to be between 0 and 1."); } else { _parms._relevance_threshold = _parms._total_information_threshold; } - - if (_parms._safety_index_threshold != -1) { - warn("safety_index_threshold", "Should not set safety_index_threshold for core infogram " + - "runs. Set net_information_threshold instead. Using default of 0.1 if not set"); - } - - if (_parms._relevance_index_threshold != -1) { - warn("relevance_index_threshold", "Should not set relevance_index_threshold for core " + - "infogram runs. Set total_information_threshold instead. Using default of 0.1 if not set"); - } } else { // fair infogram - if (_parms._safety_index_threshold == -1) { - _parms._cmi_threshold = 0.1; - _parms._safety_index_threshold = 0.1; - } else if (_parms._safety_index_threshold < 0 || _parms._safety_index_threshold > 1) { + if (_parms._safety_index_threshold < 0 || _parms._safety_index_threshold > 1) { error("safety_index_threshold", " should be set to be between 0 and 1."); } else { _parms._cmi_threshold = _parms._safety_index_threshold; } - if (_parms._relevance_index_threshold == -1) { - _parms._relevance_threshold = 0.1; - _parms._relevance_index_threshold = 0.1; - } else if (_parms._relevance_index_threshold < 0 || _parms._relevance_index_threshold > 1) { + if (_parms._relevance_index_threshold < 0 || _parms._relevance_index_threshold > 1) { error("relevance_index_threshold", " should be set to be between 0 and 1."); } else { _parms._relevance_threshold = _parms._relevance_index_threshold; } - - if (_parms._net_information_threshold != -1) { - warn("net_information_threshold", "Should not set net_information_threshold for fair " + - "infogram runs, set safety_index_threshold instead. Using default of 0.1 if not set"); - } - if (_parms._total_information_threshold != -1) { - warn("total_information_threshold", "Should not set total_information_threshold for fair" + - " infogram runs, set relevance_index_threshold instead. Using default of 0.1 if not set"); - } - if (AUTO.equals(_parms._algorithm)) _parms._algorithm = gbm; } diff --git a/h2o-admissibleml/src/main/java/hex/Infogram/InfogramModel.java b/h2o-admissibleml/src/main/java/hex/Infogram/InfogramModel.java index 9d8a09c127ec..bd723b14a186 100644 --- a/h2o-admissibleml/src/main/java/hex/Infogram/InfogramModel.java +++ b/h2o-admissibleml/src/main/java/hex/Infogram/InfogramModel.java @@ -5,11 +5,10 @@ import hex.*; import hex.genmodel.utils.DistributionFamily; import hex.glm.GLMModel; -import hex.schemas.*; +import hex.schemas.InfogramV3; import water.*; import water.fvec.Frame; import water.udf.CFuncRef; -import water.util.TwoDimTable; import java.lang.reflect.Field; import java.util.*; @@ -55,10 +54,10 @@ public static class InfogramParameters extends Model.Parameters { public String[] _protected_columns = null; // store features to be excluded from final model public double _cmi_threshold = 0.1; // default set by Deep public double _relevance_threshold = 0.1; // default set by Deep - public double _total_information_threshold = -1; // relevance threshold for core infogram - public double _net_information_threshold = -1; // cmi threshold for core infogram - public double _safety_index_threshold = -1; // cmi threshold for safe infogram - public double _relevance_index_threshold = -1; // relevance threshold for safe infogram + public double _total_information_threshold = 0.1; // relevance threshold for core infogram + public double _net_information_threshold = 0.1; // cmi threshold for core infogram + public double _safety_index_threshold = 0.1; // cmi threshold for safe infogram + public double _relevance_index_threshold = 0.1; // relevance threshold for safe infogram public double _data_fraction = 1.0; // fraction of data to use to calculate infogram public Model.Parameters _infogram_algorithm_parameters; // store parameters of chosen algorithm public int _top_n_features = 50; // if 0 consider all predictors, otherwise, consider topk predictors diff --git a/h2o-admissibleml/src/main/java/hex/schemas/InfogramV3.java b/h2o-admissibleml/src/main/java/hex/schemas/InfogramV3.java index 1ca6c2dff9c0..55ac12d2032a 100644 --- a/h2o-admissibleml/src/main/java/hex/schemas/InfogramV3.java +++ b/h2o-admissibleml/src/main/java/hex/schemas/InfogramV3.java @@ -15,12 +15,10 @@ import water.api.SchemaServer; import water.api.schemas3.KeyV3; import water.api.schemas3.ModelParametersSchemaV3; -import static hex.util.DistributionUtils.distributionToFamily; + import java.util.*; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; -import java.util.Properties; + +import static hex.util.DistributionUtils.distributionToFamily; public class InfogramV3 extends ModelBuilderSchema { public static final class InfogramParametersV3 extends ModelParametersSchemaV3 { @@ -134,44 +132,43 @@ public static final class InfogramParametersV3 extends ModelParametersSchemaV3