Skip to content

Commit

Permalink
AutoML pipeline support
Browse files Browse the repository at this point in the history
  • Loading branch information
sebhrusen committed Jan 29, 2024
1 parent c0e024b commit a13da23
Show file tree
Hide file tree
Showing 17 changed files with 483 additions and 51 deletions.
61 changes: 55 additions & 6 deletions h2o-automl/src/main/java/ai/h2o/automl/AutoML.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,21 @@
import ai.h2o.automl.leaderboard.ModelProvider;
import ai.h2o.automl.leaderboard.ModelStep;
import ai.h2o.automl.preprocessing.PreprocessingStep;
import ai.h2o.automl.preprocessing.PreprocessingStepDefinition;
import hex.Model;
import hex.ScoreKeeper.StoppingMetric;
import hex.genmodel.utils.DistributionFamily;
import hex.leaderboard.*;
import hex.pipeline.DataTransformer;
import hex.pipeline.PipelineModel.PipelineParameters;
import hex.splitframe.ShuffleSplitFrame;
import org.apache.log4j.Logger;
import water.*;
import water.automl.api.schemas3.AutoMLV99;
import water.exceptions.H2OAutoMLException;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.logging.Logger;
import water.logging.LoggerFactory;
import water.nbhm.NonBlockingHashMap;
import water.util.*;

Expand Down Expand Up @@ -61,7 +63,7 @@ public enum Constraint {

private static final boolean verifyImmutability = true; // check that trainingFrame hasn't been messed with
private static final ThreadLocal<SimpleDateFormat> timestampFormatForKeys = ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyyMMdd_HHmmss"));
private static final Logger log = LoggerFactory.getLogger(AutoML.class);
private static final Logger log = Logger.getLogger(AutoML.class);

private static LeaderboardExtensionsProvider createLeaderboardExtensionProvider(AutoML automl) {
final Key<AutoML> amlKey = automl._key;
Expand Down Expand Up @@ -166,9 +168,11 @@ public double[] getClassDistribution() {
private Vec[] _originalTrainingFrameVecs;
private String[] _originalTrainingFrameNames;
private long[] _originalTrainingFrameChecksums;
private transient NonBlockingHashMap<Key, String> _trackedKeys = new NonBlockingHashMap<>();
private transient Map<Key, String> _trackedKeys = new NonBlockingHashMap<>();
private transient ModelingStep[] _executionPlan;
private transient PreprocessingStep[] _preprocessing;
private transient PipelineParameters _pipelineParams;
private transient Map<String, Object[]> _pipelineHyperParams;
transient StepResultState[] _stepsResults;

private boolean _useAutoBlending;
Expand Down Expand Up @@ -218,6 +222,7 @@ public AutoML(Key<AutoML> key, Date startTime, AutoMLBuildSpec buildSpec) {
prepareData();
initLeaderboard();
initPreprocessing();
initPipeline();
_modelingStepsExecutor = new ModelingStepsExecutor(_leaderboard, _eventLog, _runCountdown);
} catch (Exception e) {
delete(); //cleanup potentially leaked keys
Expand Down Expand Up @@ -387,11 +392,53 @@ private void initLeaderboard() {
}
_leaderboard.setExtensionsProvider(createLeaderboardExtensionProvider(this));
}

/**
 * Initializes the pipeline support for this AutoML run.
 * When pipeline mode is enabled and preprocessing steps are defined, collects all
 * {@link DataTransformer}s (and their hyper-parameters) contributed by the configured
 * {@link PreprocessingStepDefinition}s into a {@link PipelineParameters} template.
 * If pipeline mode is disabled, or no step contributes any transformer,
 * {@code _pipelineParams} (and {@code _pipelineHyperParams}) are left/reset to null.
 */
private void initPipeline() {
  final AutoMLBuildModels build = _buildSpec.build_models;
  if (build.preprocessing == null || !build._pipelineEnabled) {
    _pipelineParams = null; // pipeline mode not applicable: preprocessing (if any) is handled the classic way
    return;
  }
  _pipelineParams = new PipelineParameters();
  final List<DataTransformer> transformers = new ArrayList<>();
  final Map<String, Object[]> hyperParams = new NonBlockingHashMap<>();
  for (PreprocessingStepDefinition def : build.preprocessing) {
    final PreprocessingStep step = def.newPreprocessingStep(this);
    transformers.addAll(Arrays.asList(step.pipelineTransformers()));
    final Map<String, Object[]> stepHyperParams = step.pipelineTransformersHyperParams();
    if (stepHyperParams != null) hyperParams.putAll(stepHyperParams);
  }
  if (transformers.isEmpty()) {
    // no step contributed a transformer: fully disable pipeline mode
    _pipelineParams = null;
    _pipelineHyperParams = null;
  } else {
    _pipelineParams._transformers = transformers.toArray(new DataTransformer[0]);
    _pipelineHyperParams = hyperParams;
  }

  // TODO: a transformer can reference a model (e.g. TE), and several transformers can refer
  //  to the same model, so be careful when deleting a transformer (resp. an entire pipeline):
  //  we may delete something still in use by another transformer (resp. pipeline).
  //  --> ref count?

  // TODO: in AutoML, the same transformations are likely to occur on multiple (sometimes all) models,
  //  especially if the transformers parameters are not tuned.
  //  But it also depends if the transformers are context(CV)-sensitive (e.g. Target Encoding).
  //  See `CachingTransformer` for some thoughts about this.
}

/**
 * @return the {@link PipelineParameters} template assembled by {@code initPipeline()},
 *         or {@code null} when pipeline mode is disabled or no transformer was collected.
 */
PipelineParameters getPipelineParams() {
return _pipelineParams;
}

/**
 * @return the hyper-parameters contributed by the pipeline transformers (collected in {@code initPipeline()}),
 *         or {@code null} when pipeline mode is disabled or no transformer was collected.
 */
Map<String, Object[]> getPipelineHyperParams() {
return _pipelineHyperParams;
}

private void initPreprocessing() {
_preprocessing = _buildSpec.build_models.preprocessing == null
final AutoMLBuildModels build = _buildSpec.build_models;
_preprocessing = build.preprocessing == null || build._pipelineEnabled
? null
: Arrays.stream(_buildSpec.build_models.preprocessing)
: Arrays.stream(build.preprocessing)
.map(def -> def.newPreprocessingStep(this))
.toArray(PreprocessingStep[]::new);
}
Expand Down Expand Up @@ -491,9 +538,11 @@ public void run() {
eventLog().info(Stage.Workflow, "AutoML build started: " + EventLogEntry.dateTimeFormat.get().format(_runCountdown.start_time()))
.setNamedValue("start_epoch", _runCountdown.start_time(), EventLogEntry.epochFormat.get());
try {
Scope.enter();
learn();
} finally {
stop();
Scope.exit();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ public static final class AutoMLBuildModels extends Iced {
public double exploitation_ratio = -1;
public AutoMLCustomParameters algo_parameters = new AutoMLCustomParameters();
public PreprocessingStepDefinition[] preprocessing;
public boolean _pipelineEnabled = true; // currently used for testing until ready: to be removed
}

public static final class AutoMLCustomParameters extends Iced {
Expand Down
97 changes: 74 additions & 23 deletions h2o-automl/src/main/java/ai/h2o/automl/ModelingStep.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,19 @@
import hex.ModelContainer;
import hex.ScoreKeeper.StoppingMetric;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.*;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import hex.grid.HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria;
import hex.grid.HyperSpaceWalker;
import hex.leaderboard.Leaderboard;
import hex.ModelParametersDelegateBuilderFactory;
import hex.pipeline.PipelineModel.PipelineParameters;
import jsr166y.CountedCompleter;
import org.apache.commons.lang.builder.ToStringBuilder;
import water.*;
import water.KeyGen.ConstantKeyGen;
import water.KeyGen.PatternKeyGen;
import water.exceptions.H2OIllegalArgumentException;
import water.util.ArrayUtils;
import water.util.Countdown;
Expand All @@ -37,6 +44,8 @@
* Parent class defining common properties and common logic for actual {@link AutoML} training steps.
*/
public abstract class ModelingStep<M extends Model> extends Iced<ModelingStep> {

protected static final String PIPELINE_KEY_PREFIX = "Pipeline_";

protected enum SeedPolicy {
/** No seed will be used (= random). */
Expand All @@ -61,20 +70,11 @@ protected <MP extends Model.Parameters> Job<Grid> startSearch(
assert baseParams != null;
assert hyperParams.size() > 0;
assert searchCriteria != null;
applyPreprocessing(baseParams);
aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+resultKey+" hyperparameter search")
GridSearch.Builder builder = makeGridBuilder(resultKey, baseParams, hyperParams, searchCriteria);
aml().trackKeys(builder.dest());
aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+builder.dest()+" hyperparameter search")
.setNamedValue("start_"+_provider+"_"+_id, new Date(), EventLogEntry.epochFormat.get());
return GridSearch.create(
resultKey,
HyperSpaceWalker.BaseWalker.WalkerFactory.create(
baseParams,
hyperParams,
new SimpleParametersBuilderFactory<>(),
searchCriteria
))
.withParallelism(GridSearch.SEQUENTIAL_MODEL_BUILDING)
.withMaxConsecutiveFailures(aml()._maxConsecutiveModelFailures)
.start();
return builder.start();
}

@SuppressWarnings("unchecked")
Expand All @@ -84,11 +84,8 @@ protected <MP extends Model.Parameters> Job<M> startModel(
) {
assert resultKey != null;
assert params != null;
Job<M> job = new Job<>(resultKey, ModelBuilder.javaName(_algo.urlName()), _description);
applyPreprocessing(params);
ModelBuilder builder = ModelBuilder.make(_algo.urlName(), job, (Key<Model>) resultKey);
builder._parms = params;
aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+resultKey+" model training")
ModelBuilder builder = makeBuilder(resultKey, params);
aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+builder.dest()+" model training")
.setNamedValue("start_"+_provider+"_"+_id, new Date(), EventLogEntry.epochFormat.get());
builder.init(false); // validate parameters
if (builder._messages.length > 0) {
Expand All @@ -102,6 +99,38 @@ protected <MP extends Model.Parameters> Job<M> startModel(
}
return builder.trainModelOnH2ONode();
}

/**
 * Prepares a {@link GridSearch.Builder} for a hyper-parameter search on this step.
 * Applies the AutoML preprocessing to the base params, then optionally wraps them into
 * {@link PipelineParameters}; in that case the grid key gets the {@code PIPELINE_KEY_PREFIX},
 * while {@code applyPipeline} also rewrites {@code hyperParams} in place for the wrapped space.
 *
 * @param resultKey destination key for the grid (re-keyed with the pipeline prefix when a pipeline is used).
 * @param baseParams estimator parameters used as the base of the search (mutated by preprocessing).
 * @param hyperParams hyper-parameter space; may be rewritten in place by {@code applyPipeline}.
 * @param searchCriteria search strategy (e.g. random discrete) driving the walker.
 * @return a configured builder (sequential model building, AutoML max-consecutive-failures policy), not yet started.
 */
protected <MP extends Model.Parameters> GridSearch.Builder makeGridBuilder(Key<Grid> resultKey,
MP baseParams,
Map<String, Object[]> hyperParams,
HyperSpaceSearchCriteria searchCriteria) {
applyPreprocessing(baseParams);
// NOTE: applyPipeline must see the un-prefixed resultKey; in grid mode the estimator keys are
// recovered at runtime by stripping PIPELINE_KEY_PREFIX (see the PatternKeyGen in applyPipeline).
Model.Parameters finalParams = applyPipeline(resultKey, baseParams, hyperParams);
if (finalParams instanceof PipelineParameters) resultKey = Key.make(PIPELINE_KEY_PREFIX+resultKey);
return GridSearch.create(
resultKey,
HyperSpaceWalker.BaseWalker.WalkerFactory.create(
finalParams,
hyperParams,
// delegate factory: lets "estimator."-prefixed hyper-params reach the wrapped estimator params
new ModelParametersDelegateBuilderFactory<>(),
searchCriteria
))
.withParallelism(GridSearch.SEQUENTIAL_MODEL_BUILDING)
.withMaxConsecutiveFailures(aml()._maxConsecutiveModelFailures);
}


/**
 * Prepares a {@link ModelBuilder} for a single model on this step.
 * Applies the AutoML preprocessing, then optionally wraps the params into a pipeline;
 * the pipeline (if any) is trained under a {@code PIPELINE_KEY_PREFIX}-ed key while the inner
 * estimator keeps the original {@code resultKey} (via the ConstantKeyGen set in {@code applyPipeline}).
 *
 * @param resultKey destination key for the model (re-keyed with the pipeline prefix when a pipeline is used).
 * @param params estimator parameters (mutated by preprocessing).
 * @return a builder ready to be validated ({@code init(false)}) and started by the caller.
 */
protected <MP extends Model.Parameters> ModelBuilder makeBuilder(Key<M> resultKey, MP params) {
applyPreprocessing(params);
Model.Parameters finalParams = applyPipeline(resultKey, params, null);
if (finalParams instanceof PipelineParameters) resultKey = Key.make(PIPELINE_KEY_PREFIX+resultKey);

Job<M> job = new Job<>(resultKey, ModelBuilder.javaName(_algo.urlName()), _description);
ModelBuilder builder = ModelBuilder.make(finalParams.algoName(), job, (Key<Model>) resultKey);
builder._parms = finalParams;
// NOTE(review): presumably keeps a pristine copy of the parameters before init() mutates _parms — confirm _input_parms contract
builder._input_parms = finalParams.clone();
return builder;
}

private boolean validParameters(Model.Parameters parms, String[] fields) {
try {
Expand Down Expand Up @@ -360,8 +389,6 @@ protected void setCommonModelBuilderParams(Model.Parameters params) {
setClassBalancingParams(params);
params._custom_metric_func = buildSpec.build_control.custom_metric_func;

params._keep_cross_validation_models = buildSpec.build_control.keep_cross_validation_models;
params._keep_cross_validation_fold_assignment = buildSpec.build_control.nfolds != 0 && buildSpec.build_control.keep_cross_validation_fold_assignment;
params._export_checkpoints_dir = buildSpec.build_control.export_checkpoints_dir;

/** Using _main_model_time_budget_factor to determine if and how we should restrict the time for the main model.
Expand All @@ -374,6 +401,8 @@ protected void setCommonModelBuilderParams(Model.Parameters params) {
protected void setCrossValidationParams(Model.Parameters params) {
AutoMLBuildSpec buildSpec = aml().getBuildSpec();
params._keep_cross_validation_predictions = aml().getBlendingFrame() == null || buildSpec.build_control.keep_cross_validation_predictions;
params._keep_cross_validation_models = buildSpec.build_control.keep_cross_validation_models;
params._keep_cross_validation_fold_assignment = buildSpec.build_control.nfolds != 0 && buildSpec.build_control.keep_cross_validation_fold_assignment;
params._fold_column = buildSpec.input_spec.fold_column;

if (buildSpec.input_spec.fold_column == null) {
Expand Down Expand Up @@ -413,6 +442,29 @@ protected void applyPreprocessing(Model.Parameters params) {
}
}

/**
 * Wraps the given estimator parameters into the AutoML-wide {@link PipelineParameters} when pipeline
 * mode is active ({@code aml().getPipelineParams() != null}); otherwise returns {@code params} untouched.
 *
 * @param resultKey key of the model/grid about to be trained; used to generate the inner estimator key.
 * @param params estimator parameters to embed in the pipeline.
 * @param hyperParams grid hyper-parameter space, or {@code null} for a single model.
 *        When non-null, the map is REWRITTEN IN PLACE: every entry is re-keyed with the
 *        {@code "estimator."} prefix (so it targets the wrapped estimator), and the pipeline's own
 *        transformer hyper-parameters are merged in.
 * @return the pipeline parameters wrapping {@code params}, or {@code params} itself when no pipeline is configured.
 */
protected Model.Parameters applyPipeline(Key resultKey, Model.Parameters params, Map<String, Object[]> hyperParams) {
if (aml().getPipelineParams() == null) return params;
// clone the shared template so each step can customize it independently
PipelineParameters pparams = (PipelineParameters) aml().getPipelineParams().clone();
setCommonModelBuilderParams(pparams);
// keep seed and time budget consistent between the pipeline wrapper and its estimator
pparams._seed = params._seed;
pparams._max_runtime_secs = params._max_runtime_secs;
pparams._estimatorParams = params;
pparams._estimatorKeyGen = hyperParams == null
? new ConstantKeyGen(resultKey)
: new PatternKeyGen("{0}|s/"+PIPELINE_KEY_PREFIX+"//") // in case of grid, remove the Pipeline prefix to obtain the estimator key, this allows naming compatibility with the classic (non-pipeline) mode.
;
if (hyperParams != null) {
Map<String, Object[]> pipelineHyperParams = new HashMap<>();
for (Map.Entry<String, Object[]> e : hyperParams.entrySet()) {
pipelineHyperParams.put("estimator."+e.getKey(), e.getValue());
}
// replace the original entries with their "estimator."-prefixed equivalents, then add the pipeline's own space
hyperParams.clear();
hyperParams.putAll(pipelineHyperParams);
hyperParams.putAll(aml().getPipelineHyperParams());
}
return pparams;
}

/**
 * @return the default (empty) preprocessing configuration; subclasses may override
 *         to customize how preprocessing steps apply to this modeling step.
 */
protected PreprocessingConfig getPreprocessingConfig() {
return new PreprocessingConfig();
}
Expand Down Expand Up @@ -638,7 +690,6 @@ protected Job<Grid> hyperparameterSearch(Key<Grid> key, Model.Parameters basePar
setSearchCriteria(searchCriteria, baseParms);

if (null == key) key = makeKey(_provider, true);
aml().trackKeys(key);

Log.debug("Hyperparameter search: " + _provider + ", time remaining (ms): " + aml().timeRemainingMs());
aml().eventLog().debug(Stage.ModelTraining, searchCriteria.max_runtime_secs() == 0
Expand Down Expand Up @@ -758,7 +809,7 @@ protected Job<Models> startJob() {
final Key<Models> selectionKey = Key.make(key+"_select");
final EventLog selectionEventLog = EventLog.getOrMake(selectionKey);
// EventLog selectionEventLog = aml().eventLog();
final LeaderboardHolder selectionLeaderboard = makeLeaderboard(selectionKey.toString(), selectionEventLog);
final LeaderboardHolder selectionLeaderboard = makeLeaderboard(selectionKey.toString(), selectionEventLog);

{
result.delete_and_lock(job);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ StepResultState submit(ModelingStep step, Job parentJob) {
}
} catch (Exception e) {
resultState.addState(new StepResultState(step.getGlobalId(), e));
Log.err(e);
} finally {
step.onDone(job);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ enum ResultStatus {
enum Resolution {
sameAsMain, // resolves to the same state as the main step (ignoring other sub-step states).
optimistic, // success if any success, otherwise cancelled if any cancelled, otherwise failed if any failure, otherwise skipped.
pessimistic, // failures if any failure, otherwise cancelled if any cancelled, otherwise success it any success, otherwise skipped.
pessimistic, // failed if any failure, otherwise cancelled if any cancelled, otherwise success if any success, otherwise skipped.
}

private final String _id;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,12 @@ protected PreprocessingConfig getPreprocessingConfig() {
return config;
}

@Override
/**
 * No-op override for Stacked Ensemble steps: the ensemble itself is never wrapped into a pipeline,
 * as each base model already embeds its transformations and applies them at prediction time.
 */
@Override
protected Model.Parameters applyPipeline(Key resultKey, Model.Parameters params, Map<String, Object[]> hyperParams) {
return params; // no pipeline in SE, base models handle the transformations when making predictions.
}

@Override
@SuppressWarnings("unchecked")
public boolean canRun() {
Key<Model>[] keys = getBaseModels();
Expand Down Expand Up @@ -122,13 +127,18 @@ protected boolean hasDoppelganger(Key<Model>[] baseModelsKeys) {
protected abstract Key<Model>[] getBaseModels();

protected String getModelType(Key<Model> key) {
ModelingStep step = aml().session().getModelingStep(key);
if (step == null) { // dirty case
String keyStr = key.toString();
return keyStr.substring(0, keyStr.indexOf('_'));
int lookupStart = keyStr.startsWith(PIPELINE_KEY_PREFIX) ? PIPELINE_KEY_PREFIX.length() : 0;
return keyStr.substring(lookupStart, keyStr.indexOf('_', lookupStart));
} else {
return step.getAlgo().name();
}
}

protected boolean isStackedEnsemble(Key<Model> key) {
ModelingStep step = aml().session().getModelingStep(key);
return step != null && step.getAlgo() == Algo.StackedEnsemble;
return Algo.StackedEnsemble.name().equals(getModelType(key));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ public Map<String, Object[]> prepareSearchParameters() {
XGBoostParameters.Booster.gbtree,
XGBoostParameters.Booster.dart
});

// searchParams.put("_booster$weights", new Integer[] {2, 1});

searchParams.put("_reg_lambda", new Float[]{0.001f, 0.01f, 0.1f, 1f, 10f, 100f});
searchParams.put("_reg_alpha", new Float[]{0.001f, 0.01f, 0.1f, 0.5f, 1f});

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package ai.h2o.automl.preprocessing;

import ai.h2o.automl.ModelingStep;
import hex.Model;
import hex.pipeline.DataTransformer;

import java.util.Map;

public interface PreprocessingStep<T> {

Expand Down Expand Up @@ -34,4 +36,8 @@ interface Completer extends Runnable {}
*/
void remove();

DataTransformer[] pipelineTransformers();

Map<String, Object[]> pipelineTransformersHyperParams();

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ai.h2o.automl.preprocessing;

import ai.h2o.automl.AutoML;
import hex.pipeline.DataTransformer;
import water.Iced;

public class PreprocessingStepDefinition extends Iced<PreprocessingStepDefinition> {
Expand Down
Loading

0 comments on commit a13da23

Please sign in to comment.