Skip to content

Commit

Permalink
grid pipeline support
Browse files Browse the repository at this point in the history
  • Loading branch information
sebhrusen committed Jan 29, 2024
1 parent 6eebae0 commit 3c39c68
Show file tree
Hide file tree
Showing 20 changed files with 440 additions and 251 deletions.
12 changes: 12 additions & 0 deletions h2o-algos/src/main/java/hex/glm/GLMModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMModel.GLMParameters.Family;
import hex.glm.GLMModel.GLMParameters.Link;
import hex.grid.Grid;
import hex.util.EffectiveParametersUtils;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
Expand Down Expand Up @@ -1029,6 +1030,17 @@ public DistributionFamily getDistributionFamily() {
return familyToDistribution(_family);
}

@Override
public void addSearchFailureDetails(Grid.SearchFailure searchFailure, Grid grid) {
super.addSearchFailureDetails(searchFailure, grid);
if (ArrayUtils.contains(grid.getHyperNames(), "alpha")) {
// maybe we should find a way to raise this warning at the very beginning of grid search, similar to validation ini ModelBuilder#init().
searchFailure.addWarning("Adding alpha array to hyperparameter runs slower with gridsearch. "+
"This is due to the fact that the algo has to run initialization for every alpha value. "+
"Setting the alpha array as a model parameter will skip the initialization and run faster overall.");
}
}

public void updateTweedieParams(double tweedieVariancePower, double tweedieLinkPower, double dispersion){
_tweedie_variance_power = tweedieVariancePower;
_tweedie_link_power = tweedieLinkPower;
Expand Down
21 changes: 9 additions & 12 deletions h2o-algos/src/test/java/hex/grid/GridTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import hex.faulttolerance.Recovery;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMModel;
import hex.grid.HyperSpaceWalker.BaseWalker.WalkerFactory;
import hex.tree.CompressedTree;
import hex.tree.gbm.GBMModel;
import hex.tree.uplift.UpliftDRFModel;
Expand All @@ -19,8 +18,6 @@
import water.exceptions.H2OGridException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.test.dummy.DummyAction;
import water.test.dummy.DummyModelParameters;
import water.test.dummy.MessageInstallAction;

Expand Down Expand Up @@ -72,7 +69,7 @@ public void testParallelModelTimeConstraint() {

Job<Grid> gridSearch = GridSearch.startGridSearch(
null, params, hyperParms,
new GridSearch.SimpleParametersBuilderFactory(),
new SimpleParametersBuilderFactory(),
searchCriteria, 2
);

Expand Down Expand Up @@ -110,7 +107,7 @@ public void testParallelUserStopRequest() {

Job<Grid> gridSearch = GridSearch.startGridSearch(
dest, params, hyperParms,
new GridSearch.SimpleParametersBuilderFactory(),
new SimpleParametersBuilderFactory(),
new HyperSpaceSearchCriteria.CartesianSearchCriteria(),
2
);
Expand Down Expand Up @@ -370,7 +367,7 @@ public void gridSearchRecoveryModels() throws IOException, InterruptedException
Scope.track(trainingFrame);
Job<Grid> gs = GridSearch.startGridSearch(
null, gridKey, params, hyperParms,
new GridSearch.SimpleParametersBuilderFactory(),
new SimpleParametersBuilderFactory(),
new HyperSpaceSearchCriteria.CartesianSearchCriteria(),
recovery1, 1
);
Expand Down Expand Up @@ -400,7 +397,7 @@ public void gridSearchRecoveryModels() throws IOException, InterruptedException
null, gridKey,
loadedGrid1.getParams(),
loadedGrid1.getHyperParams(),
new GridSearch.SimpleParametersBuilderFactory(),
new SimpleParametersBuilderFactory(),
loadedGrid1.getSearchCriteria(),
recovery2,
loadedGrid1.getParallelism()
Expand Down Expand Up @@ -457,7 +454,7 @@ public void gridSearchWithRecoverySuccess() throws IOException, InterruptedExcep
Key gridKey = Key.make("gridSearchWithRecovery_GRID");
Job<Grid> gs = GridSearch.startGridSearch(
null, gridKey, params, hyperParms,
new GridSearch.SimpleParametersBuilderFactory<>(),
new SimpleParametersBuilderFactory<>(),
new HyperSpaceSearchCriteria.CartesianSearchCriteria(),
recovery, GridSearch.SEQUENTIAL_MODEL_BUILDING
);
Expand Down Expand Up @@ -543,7 +540,7 @@ public void gridSearchWithRecoveryCancelGBM() throws IOException, InterruptedExc
Key gridKey = Key.make("gridSearchWithRecovery_GRID");
Job<Grid> gs = GridSearch.startGridSearch(
null, gridKey, params, hyperParms,
new GridSearch.SimpleParametersBuilderFactory<>(),
new SimpleParametersBuilderFactory<>(),
new HyperSpaceSearchCriteria.CartesianSearchCriteria(),
recovery, GridSearch.SEQUENTIAL_MODEL_BUILDING
);
Expand Down Expand Up @@ -586,7 +583,7 @@ public void gridSearchWithRecoveryCancelGLM() throws IOException, InterruptedExc
Key gridKey = Key.make("gridSearchWithRecoveryGlm");
Job<Grid> gs = GridSearch.startGridSearch(
null, gridKey, params, hyperParms,
new GridSearch.SimpleParametersBuilderFactory<>(),
new SimpleParametersBuilderFactory<>(),
new HyperSpaceSearchCriteria.CartesianSearchCriteria(),
recovery, GridSearch.SEQUENTIAL_MODEL_BUILDING
);
Expand Down Expand Up @@ -850,7 +847,7 @@ public void test_parallel_random_search_with_max_models_being_less_than_parallel
params._train = trainingFrame._key;
params._response_column = "species";

GridSearch.SimpleParametersBuilderFactory simpleParametersBuilderFactory = new GridSearch.SimpleParametersBuilderFactory();
SimpleParametersBuilderFactory simpleParametersBuilderFactory = new SimpleParametersBuilderFactory();
HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria hyperSpaceSearchCriteria = new HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria();
int custom_max_model = 2;
hyperSpaceSearchCriteria.set_max_models(custom_max_model);
Expand Down Expand Up @@ -886,7 +883,7 @@ public void test_parallel_random_search_with_max_models_being_greater_than_paral
params._train = trainingFrame._key;
params._response_column = "species";

GridSearch.SimpleParametersBuilderFactory simpleParametersBuilderFactory = new GridSearch.SimpleParametersBuilderFactory();
SimpleParametersBuilderFactory simpleParametersBuilderFactory = new SimpleParametersBuilderFactory();
HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria hyperSpaceSearchCriteria = new HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria();
int custom_max_model = 3;
hyperSpaceSearchCriteria.set_max_models(custom_max_model);
Expand Down
6 changes: 3 additions & 3 deletions h2o-algos/src/test/java/hex/grid/SequentialWalkerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public void test_SequentialWalker() {
new SequentialWalker<>(
gbmParameters,
hyperParams,
new GridSearch.SimpleParametersBuilderFactory<>(),
new SimpleParametersBuilderFactory<>(),
new HyperSpaceSearchCriteria.SequentialSearchCriteria()
),
GridSearch.SEQUENTIAL_MODEL_BUILDING
Expand Down Expand Up @@ -85,7 +85,7 @@ public void test_SequentialWalker_getHyperParams() {
SequentialWalker walker = new SequentialWalker<>(
gbmParameters,
hyperParams,
new GridSearch.SimpleParametersBuilderFactory<>(),
new SimpleParametersBuilderFactory<>(),
new HyperSpaceSearchCriteria.SequentialSearchCriteria()
);
Map<String, Object[]> exp = new HashMap<>();
Expand Down Expand Up @@ -124,7 +124,7 @@ public void test_SequentialWalker_supports_early_stopping() {
new SequentialWalker<>(
gbmParameters,
hyperParams,
new GridSearch.SimpleParametersBuilderFactory<>(),
new SimpleParametersBuilderFactory<>(),
new HyperSpaceSearchCriteria.SequentialSearchCriteria(StoppingCriteria.create()
.stoppingRounds(1)
.stoppingMetric(StoppingMetric.AUC)
Expand Down
7 changes: 2 additions & 5 deletions h2o-automl/src/main/java/ai/h2o/automl/ModelingStep.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,8 @@
import hex.ModelContainer;
import hex.ScoreKeeper.StoppingMetric;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import hex.grid.*;
import hex.grid.HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria;
import hex.grid.HyperSpaceWalker;
import hex.leaderboard.Leaderboard;
import jsr166y.CountedCompleter;
import org.apache.commons.lang.builder.ToStringBuilder;
Expand Down Expand Up @@ -72,7 +69,7 @@ protected <MP extends Model.Parameters> Job<Grid> startSearch(
HyperSpaceWalker.BaseWalker.WalkerFactory.create(
baseParams,
hyperParams,
new GridSearch.SimpleParametersBuilderFactory<>(),
new SimpleParametersBuilderFactory<>(),
searchCriteria
))
.withParallelism(GridSearch.SEQUENTIAL_MODEL_BUILDING)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import hex.grid.HyperSpaceSearchCriteria.SequentialSearchCriteria;
import hex.grid.HyperSpaceSearchCriteria.StoppingCriteria;
import hex.grid.SequentialWalker;
import hex.grid.SimpleParametersBuilderFactory;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostModel.XGBoostParameters;
import water.Job;
Expand Down Expand Up @@ -324,7 +325,7 @@ protected Job<Models> startTraining(Key result, double maxRuntimeSecs) {
new SequentialWalker<>(
params,
hyperParams,
new GridSearch.SimpleParametersBuilderFactory<>(),
new SimpleParametersBuilderFactory<>(),
new SequentialSearchCriteria(StoppingCriteria.create()
.maxRuntimeSecs((int)maxRuntimeSecs)
.stoppingMetric(params._stopping_metric)
Expand Down
8 changes: 6 additions & 2 deletions h2o-core/src/main/java/hex/ModelParametersBuilderFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public interface ModelParametersBuilderFactory<MP extends Model.Parameters> {
* @return this parameters builder
*/
ModelParametersBuilder<MP> get(MP initialParams);



/**
* Returns mapping from input parameter specification to
Expand All @@ -38,8 +40,10 @@ public interface ModelParametersBuilderFactory<MP extends Model.Parameters> {
*
* @param <MP> type of produced model parameters object
*/
interface ModelParametersBuilder<MP extends Model.Parameters> {

interface ModelParametersBuilder<MP extends Model.Parameters> {

boolean isAssignable(String name);

ModelParametersBuilder<MP> set(String name, Object value);

MP build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package hex;

import water.util.PojoUtils.FieldNaming;

/**
* This {@link ModelParametersBuilderFactory} delegates the hyper-parameters building logic
* to the initial {@link Model.Parameters} instance itself, using the {@link Parameterizable} methods.
* This allows better control for complex parameters objects that may this way accept nested hyper-parameters.
*/
public class ModelParametersDelegateBuilderFactory<MP extends Model.Parameters> implements ModelParametersBuilderFactory<MP> {

protected final FieldNaming fieldNaming;

public ModelParametersDelegateBuilderFactory() {
this(FieldNaming.CONSISTENT);
}

public ModelParametersDelegateBuilderFactory(FieldNaming fieldNaming) {
this.fieldNaming = fieldNaming;
}

@Override
public ModelParametersBuilder<MP> get(MP initialParams) {
return new DelegateParamsBuilder<>(initialParams, fieldNaming);
}

@Override
public FieldNaming getFieldNamingStrategy() {
return fieldNaming;
}

public static class DelegateParamsBuilder<MP extends Model.Parameters>
implements ModelParametersBuilder<MP> {

protected final MP params;
protected final FieldNaming fieldNaming;


protected DelegateParamsBuilder(MP params, FieldNaming fieldNaming) {
this.params = params;
this.fieldNaming = fieldNaming;
}

@Override
public boolean isAssignable(String name) {
return this.params.isParameterAssignable(fieldNaming.toDest(name));
}

@Override
public ModelParametersBuilder<MP> set(String name, Object value) {
this.params.setParameter(fieldNaming.toDest(name), value);
return this;
}

@Override
public MP build() {
return params;
}
}
}
103 changes: 103 additions & 0 deletions h2o-core/src/main/java/hex/ModelParametersGenericBuilderFactory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package hex;

import water.util.Log;
import water.util.PojoUtils;
import water.util.PojoUtils.FieldNaming;

import java.util.HashMap;
import java.util.Map;

/**
* A {@link ModelParametersBuilderFactory} that can dynamically generate parameters for any kind of model algorithm,
* as soon as one of the hyper-parameter is named {@value #ALGO_PARAM},
* in which case it is recommended to obtain a new builder using a {@link CommonModelParameters} instance,
* that will be used to provide the standard params for all type of algos.
*
* Otherwise, if there's no {@value #ALGO_PARAM} hyper-parameter, this factory behaves similarly to {@link ModelParametersBuilderFactory}.
*
* TODO: future improvement. When griding over multiple algos, we may want to apply different values for an hyper-parameter with the same name on algo-A and algo-B.
* In this case, we should be able to handle hyper-parameters differently based on naming convention. For example using `$` to prefix the param with the algo:
* - GBM$_max_depth = [3, 5, 7, 9, 11]
* - XGBoost$_max_depth = [5, 10, 15]
* as soon as the algo is defined, then the params are assigned this way:
* - if `_my_param` is provided, check if `Algo$_my_param` is also provided: if so then apply only the latter, otherwise apply the former.
*/
public class ModelParametersGenericBuilderFactory extends ModelParametersDelegateBuilderFactory<Model.Parameters> {

public static final String ALGO_PARAM = "algo";

/**
* A generic class containing only common {@link Model.Parameters} that can be used as initial common parameters
* when searching over multiple algos.
*/
public static class CommonModelParameters extends Model.Parameters {
@Override
public String algoName() {
return null;
}

@Override
public String fullName() {
return null;
}

@Override
public String javaName() {
return null;
}

@Override
public long progressUnits() {
return 0;
}
}

public ModelParametersGenericBuilderFactory() {
super();
}

public ModelParametersGenericBuilderFactory(FieldNaming fieldNaming) {
super(fieldNaming);
}

@Override
public ModelParametersBuilder<Model.Parameters> get(Model.Parameters initialParams) {
return new GenericParamsBuilder(initialParams, fieldNaming);
}

public static class GenericParamsBuilder extends DelegateParamsBuilder<Model.Parameters> {

private final Map<String, Object> hyperParams = new HashMap<>();

public GenericParamsBuilder(Model.Parameters params, FieldNaming fieldNaming) {
super(params, fieldNaming);
}

@Override
public ModelParametersBuilder<Model.Parameters> set(String name, Object value) {
hyperParams.put(name, value);
return this;
}

@Override
public Model.Parameters build() {
Model.Parameters result = params;
String algo = null;
if (hyperParams.containsKey(ALGO_PARAM)) {
algo = (String) hyperParams.get(ALGO_PARAM);
result = ModelBuilder.makeParameters(algo);
//add values from init params
PojoUtils.copyProperties(result, params, FieldNaming.CONSISTENT);
}
for (Map.Entry<String, Object> e : hyperParams.entrySet()) {
if (ALGO_PARAM.equals(e.getKey())) continue;
if (algo == null || result.hasParameter(fieldNaming.toDest(e.getKey()))) { // no check for `result.hasParameter` in case of strict algo, so that we can fail on invalid param
result.setParameter(fieldNaming.toDest(e.getKey()), e.getValue());
} else { // algo hyper-param was provided and this hyper-param is incompatible with it
Log.debug("Ignoring hyper-parameter `"+e.getKey()+"` unsupported by `"+algo+"`.");
}
}
return result;
}
}
}
2 changes: 1 addition & 1 deletion h2o-core/src/main/java/hex/faulttolerance/Recovery.java
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ void autoRecover() {
Grid grid = Grid.importBinary(recoveryFile(resultKey), true);
GridSearch.resumeGridSearch(
jobKey, grid,
new GridSearchHandler.DefaultModelParametersBuilderFactory(),
new GridSearchHandler.APIModelParametersBuilderFactory(),
(Recovery<Grid>) this
);
} else {
Expand Down
Loading

0 comments on commit 3c39c68

Please sign in to comment.