Skip to content

Commit

Permalink
fix alpha warnings + HyperSpaceWalkerTest
Browse files Browse the repository at this point in the history
  • Loading branch information
sebhrusen committed Feb 3, 2024
1 parent 83a4d6f commit 5f12f8a
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 24 deletions.
6 changes: 3 additions & 3 deletions h2o-algos/src/main/java/hex/glm/GLMModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -1031,10 +1031,10 @@ public DistributionFamily getDistributionFamily() {
}

@Override
public void addSearchFailureDetails(Grid.SearchFailure searchFailure, Grid grid) {
super.addSearchFailureDetails(searchFailure, grid);
public void addSearchWarnings(Grid.SearchFailure searchFailure, Grid grid) {
super.addSearchWarnings(searchFailure, grid);
if (ArrayUtils.contains(grid.getHyperNames(), "alpha")) {
// maybe we should find a way to raise this warning at the very beginning of grid search, similar to validation ini ModelBuilder#init().
// maybe we should find a way to raise this warning at the very beginning of grid search, similar to validation in ModelBuilder#init().
searchFailure.addWarning("Adding alpha array to hyperparameter runs slower with gridsearch. "+
"This is due to the fact that the algo has to run initialization for every alpha value. "+
"Setting the alpha array as a model parameter will skip the initialization and run faster overall.");
Expand Down
2 changes: 1 addition & 1 deletion h2o-core/src/main/java/hex/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ private Parameters getDefaults() {
* @param searchFailure
* @param grid
*/
public void addSearchFailureDetails(Grid.SearchFailure searchFailure, Grid grid) {}
public void addSearchWarnings(Grid.SearchFailure searchFailure, Grid grid) {}
}

public ModelMetrics addModelMetrics(final ModelMetrics mm) {
Expand Down
3 changes: 2 additions & 1 deletion h2o-core/src/main/java/hex/grid/Grid.java
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ private void appendFailedModelParameters(final Key<Model> modelKey, final MP par
_failures.put(searchedKey, searchFailure);
}
searchFailure.appendFailedModelParameters(params, rawParams, failureDetails, stackTrace);
if (params != null) params.addSearchFailureDetails(searchFailure, this);
if (params != null) params.addSearchWarnings(searchFailure, this);
}

static boolean isJobCanceled(final Throwable t) {
Expand Down Expand Up @@ -406,6 +406,7 @@ public SearchFailure getFailures() {
final Collection<SearchFailure> values = _failures.values();
// Original failures should be left intact. Also avoid mutability from outer space.
final SearchFailure searchFailure = new SearchFailure(_params != null ? _params.getClass() : null);
if (_params != null) _params.addSearchWarnings(searchFailure, this);

for (SearchFailure f : values) {
searchFailure.appendFailedModelParameters(f._failed_params, f._failed_raw_params, f._failure_details,
Expand Down
18 changes: 18 additions & 0 deletions h2o-core/src/test/java/hex/HyperSpaceWalkerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.junit.Test;
import water.TestUtil;
import water.test.dummy.DummyModelParameters;
import water.util.ReflectionUtils;

import java.util.HashMap;
import java.util.Map;
Expand All @@ -16,6 +17,17 @@ public class HyperSpaceWalkerTest extends TestUtil {
@BeforeClass public static void stall() { stall_till_cloudsize(1); }

static public class DummyXGBoostModelParameters extends DummyModelParameters {

private static final DummyXGBoostModelParameters DEFAULTS;

static {
try {
DEFAULTS = DummyXGBoostModelParameters.class.newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
}

public int _max_depth;
public double _min_rows;
public double _sample_rate;
Expand All @@ -26,6 +38,12 @@ static public class DummyXGBoostModelParameters extends DummyModelParameters {
public float _reg_alpha;
public float _scale_pos_weight;
public float _max_delta_step;

@Override
public Object getParameterDefaultValue(String name) {
// tricking the default logic here as this parameters class is not properly registered, so we can't obtain the defaults the usual way.
return ReflectionUtils.getFieldValue(DEFAULTS, name);
}
}


Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from builtins import range
import contextlib
from io import StringIO
import sys
sys.path.insert(1,"../../../")
import h2o
from tests import pyunit_utils
from h2o.estimators.glm import H2OGeneralizedLinearEstimator
from h2o.grid.grid_search import H2OGridSearch

try: # redirect python output
from StringIO import StringIO # for python 3
except ImportError:
from io import StringIO # for python 2

# This test is used to make sure when a user tries to set alpha in the hyper-parameter of gridsearch, a warning
# should appear to tell the user to set the alpha array as an parameter in the algorithm.
def grid_alpha_search():
Expand All @@ -25,23 +22,17 @@ def grid_alpha_search():
hyper_parameters = {'alpha': [0, 0.5]} # set hyper_parameters for grid search

print("Create models with lambda_search")
buffer = StringIO() # redirect output
sys.stderr=buffer
model_h2o_grid_search = H2OGridSearch(H2OGeneralizedLinearEstimator(family="tweedie", Lambda=0.5),
hyper_parameters)
model_h2o_grid_search.train(x=x, y=y, training_frame=hdf)
sys.stderr=sys.__stderr__ # redirect printout back to normal path
err = StringIO()
with contextlib.redirect_stderr(err):
model_h2o_grid_search = H2OGridSearch(H2OGeneralizedLinearEstimator(family="tweedie", Lambda=0.5),
hyper_parameters)
model_h2o_grid_search.train(x=x, y=y, training_frame=hdf)

# check and make sure we get the correct warning message
warn_phrase = "Adding alpha array to hyperparameter runs slower with gridsearch."
try: # for python 2.7
assert len(buffer.buflist)==warnNumber
print(buffer.buflist[0])
assert warn_phrase in buffer.buflist[0]
except: # for python 3.
warns = buffer.getvalue()
print("*** captured warning message: {0}".format(warns))
assert warn_phrase in warns
warns = err.getvalue()
print("*** captured warning message: {0}".format(warns))
assert warn_phrase in warns

if __name__ == "__main__":
pyunit_utils.standalone_test(grid_alpha_search)
Expand Down

0 comments on commit 5f12f8a

Please sign in to comment.