Skip to content

Commit

Permalink
Revert "GH-15857: cleanup legacy TE integration in ModelBuilder and A…
Browse files Browse the repository at this point in the history
…utoML (#16061)"

This reverts commit a8f309b.
  • Loading branch information
valenad1 committed Mar 8, 2024
1 parent 6830131 commit 0af6327
Show file tree
Hide file tree
Showing 74 changed files with 1,008 additions and 1,273 deletions.
42 changes: 34 additions & 8 deletions h2o-automl/src/main/java/ai/h2o/automl/AutoML.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import ai.h2o.automl.leaderboard.ModelGroup;
import ai.h2o.automl.leaderboard.ModelProvider;
import ai.h2o.automl.leaderboard.ModelStep;
import ai.h2o.automl.preprocessing.PipelineStep;
import ai.h2o.automl.preprocessing.PipelineStepDefinition;
import ai.h2o.automl.preprocessing.PreprocessingStep;
import ai.h2o.automl.preprocessing.PreprocessingStepDefinition;
import hex.Model;
import hex.ScoreKeeper.StoppingMetric;
import hex.genmodel.utils.DistributionFamily;
Expand Down Expand Up @@ -170,6 +170,7 @@ public double[] getClassDistribution() {
private long[] _originalTrainingFrameChecksums;
private transient Map<Key, String> _trackedKeys = new NonBlockingHashMap<>();
private transient ModelingStep[] _executionPlan;
private transient PreprocessingStep[] _preprocessing;
private transient PipelineParameters _pipelineParams;
private transient Map<String, Object[]> _pipelineHyperParams;
transient StepResultState[] _stepsResults;
Expand Down Expand Up @@ -220,6 +221,7 @@ public AutoML(Key<AutoML> key, Date startTime, AutoMLBuildSpec buildSpec) {

prepareData();
initLeaderboard();
initPreprocessing();
initPipeline();
_modelingStepsExecutor = new ModelingStepsExecutor(_leaderboard, _eventLog, _runCountdown);
} catch (Exception e) {
Expand Down Expand Up @@ -393,12 +395,12 @@ private void initLeaderboard() {

private void initPipeline() {
final AutoMLBuildModels build = _buildSpec.build_models;
_pipelineParams = build.preprocessing == null ? null : new PipelineParameters();
_pipelineParams = build.preprocessing == null || !build._pipelineEnabled ? null : new PipelineParameters();
if (_pipelineParams == null) return;
List<DataTransformer> transformers = new ArrayList<>();
Map<String, Object[]> hyperParams = new NonBlockingHashMap<>();
for (PipelineStepDefinition def : build.preprocessing) {
PipelineStep step = def.newPipelineStep(this);
for (PreprocessingStepDefinition def : build.preprocessing) {
PreprocessingStep step = def.newPreprocessingStep(this);
transformers.addAll(Arrays.asList(step.pipelineTransformers()));
Map<String, Object[]> hp = step.pipelineTransformersHyperParams();
if (hp != null) hyperParams.putAll(hp);
Expand All @@ -407,13 +409,13 @@ private void initPipeline() {
_pipelineParams = null;
_pipelineHyperParams = null;
} else {
_pipelineParams.setTransformers(transformers.toArray(new DataTransformer[0]));
_pipelineParams._transformers = transformers.toArray(new DataTransformer[0]);
_pipelineHyperParams = hyperParams;
trackKeys(transformers.stream().map(DataTransformer::getKey).toArray(Key[]::new));
}

//TODO: given that a transformer can reference a model (e.g. TE),
// and multiple transformers can refer to the same model,
// then we should be careful when deleting a transformer (resp. an entire pipeline)
// as we may delete sth that is still in use by another transformer (resp. pipeline).
// --> ref count?
Expand All @@ -432,6 +434,19 @@ Map<String, Object[]> getPipelineHyperParams() {
return _pipelineHyperParams;
}

private void initPreprocessing() {
final AutoMLBuildModels build = _buildSpec.build_models;
_preprocessing = build.preprocessing == null || build._pipelineEnabled
? null
: Arrays.stream(build.preprocessing)
.map(def -> def.newPreprocessingStep(this))
.toArray(PreprocessingStep[]::new);
}

  /**
   * @return the legacy preprocessing steps to apply around model training,
   *         or {@code null} when preprocessing is disabled or delegated to the pipeline implementation
   *         (see {@code initPreprocessing()}).
   */
  PreprocessingStep[] getPreprocessing() {
    return _preprocessing;
  }

ModelingStep[] getExecutionPlan() {
if (_executionPlan == null) {
_executionPlan = session().getModelingStepsRegistry().getOrderedSteps(selectModelingPlan(null), this);
Expand Down Expand Up @@ -793,6 +808,9 @@ private void prepareData() {

private void learn() {
List<ModelingStep> completed = new ArrayList<>();
if (_preprocessing != null) {
for (PreprocessingStep preprocessingStep : _preprocessing) preprocessingStep.prepare();
}
for (ModelingStep step : getExecutionPlan()) {
if (!exceededSearchLimits(step)) {
StepResultState state = _modelingStepsExecutor.submit(step, job());
Expand All @@ -812,6 +830,9 @@ private void learn() {
}
}
}
if (_preprocessing != null) {
for (PreprocessingStep preprocessingStep : _preprocessing) preprocessingStep.dispose();
}
_actualModelingSteps = session().getModelingStepsRegistry().createDefinitionPlanFromSteps(completed.toArray(new ModelingStep[0]));
eventLog().info(Stage.Workflow, "Actual modeling steps: "+Arrays.toString(_actualModelingSteps));
}
Expand Down Expand Up @@ -871,6 +892,11 @@ protected Futures remove_impl(Futures fs, boolean cascade) {
if (leaderboard() != null) leaderboard().remove(fs, cascade);
if (eventLog() != null) eventLog().remove(fs, cascade);
if (session() != null) session().remove(fs, cascade);
if (cascade && _preprocessing != null) {
for (PreprocessingStep preprocessingStep : _preprocessing) {
preprocessingStep.remove();
}
}
for (Key key : _trackedKeys.keySet()) Keyed.remove(key, fs, true);

return super.remove_impl(fs, cascade);
Expand Down
5 changes: 3 additions & 2 deletions h2o-automl/src/main/java/ai/h2o/automl/AutoMLBuildSpec.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package ai.h2o.automl;

import ai.h2o.automl.preprocessing.PipelineStepDefinition;
import ai.h2o.automl.preprocessing.PreprocessingStepDefinition;
import hex.Model;
import hex.ScoreKeeper.StoppingMetric;
import hex.genmodel.utils.DistributionFamily;
Expand Down Expand Up @@ -180,7 +180,8 @@ public static final class AutoMLBuildModels extends Iced {
public StepDefinition[] modeling_plan;
public double exploitation_ratio = -1;
public AutoMLCustomParameters algo_parameters = new AutoMLCustomParameters();
public PipelineStepDefinition[] preprocessing;
public PreprocessingStepDefinition[] preprocessing;
public boolean _pipelineEnabled = false; // currently used for testing until ready: to be removed
}

public static final class AutoMLCustomParameters extends Iced {
Expand Down
74 changes: 26 additions & 48 deletions h2o-automl/src/main/java/ai/h2o/automl/ModelingStep.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import ai.h2o.automl.events.EventLog;
import ai.h2o.automl.events.EventLogEntry;
import ai.h2o.automl.events.EventLogEntry.Stage;
import ai.h2o.automl.preprocessing.PreprocessingConfig;
import ai.h2o.automl.preprocessing.PreprocessingStep;
import hex.Model;
import hex.Model.Parameters.FoldAssignmentScheme;
import hex.ModelBuilder;
Expand All @@ -22,22 +24,21 @@
import hex.grid.HyperSpaceWalker;
import hex.leaderboard.Leaderboard;
import hex.ModelParametersDelegateBuilderFactory;
import hex.pipeline.DataTransformer;
import hex.pipeline.PipelineModel.PipelineParameters;
import jsr166y.CountedCompleter;
import org.apache.commons.lang.builder.ToStringBuilder;
import water.*;
import water.KeyGen.ConstantKeyGen;
import water.KeyGen.PatternKeyGen;
import water.exceptions.H2OIllegalArgumentException;
import water.util.*;
import water.util.ArrayUtils;
import water.util.Countdown;
import water.util.EnumUtils;
import water.util.Log;

import java.util.*;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import static hex.pipeline.PipelineModel.ESTIMATOR_PARAM;

/**
* Parent class defining common properties and common logic for actual {@link AutoML} training steps.
Expand Down Expand Up @@ -70,6 +71,7 @@ protected <MP extends Model.Parameters> Job<Grid> startSearch(
assert hyperParams.size() > 0;
assert searchCriteria != null;
GridSearch.Builder builder = makeGridBuilder(resultKey, baseParams, hyperParams, searchCriteria);
aml().trackKeys(builder.dest());
aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+builder.dest()+" hyperparameter search")
.setNamedValue("start_"+_provider+"_"+_id, new Date(), EventLogEntry.epochFormat.get());
return builder.start();
Expand Down Expand Up @@ -102,12 +104,9 @@ protected <MP extends Model.Parameters> GridSearch.Builder makeGridBuilder(Key<G
MP baseParams,
Map<String, Object[]> hyperParams,
HyperSpaceSearchCriteria searchCriteria) {
applyPreprocessing(baseParams);
Model.Parameters finalParams = applyPipeline(resultKey, baseParams, hyperParams);
if (finalParams instanceof PipelineParameters) {
resultKey = Key.make(PIPELINE_KEY_PREFIX+resultKey);
aml().trackKeys(((PipelineParameters)finalParams)._transformers);
}
aml().trackKeys(resultKey);
if (finalParams instanceof PipelineParameters) resultKey = Key.make(PIPELINE_KEY_PREFIX+resultKey);
return GridSearch.create(
resultKey,
HyperSpaceWalker.BaseWalker.WalkerFactory.create(
Expand All @@ -122,6 +121,7 @@ protected <MP extends Model.Parameters> GridSearch.Builder makeGridBuilder(Key<G


protected <MP extends Model.Parameters> ModelBuilder makeBuilder(Key<M> resultKey, MP params) {
applyPreprocessing(params);
Model.Parameters finalParams = applyPipeline(resultKey, params, null);
if (finalParams instanceof PipelineParameters) resultKey = Key.make(PIPELINE_KEY_PREFIX+resultKey);

Expand Down Expand Up @@ -433,44 +433,18 @@ protected void setCustomParams(Model.Parameters params) {
if (customParams == null) return;
customParams.applyCustomParameters(_algo, params);
}


/**
* If some algo/provider needs to modify the pipeline dynamically, it's recommended to override this.
*/
protected void filterPipelineTransformers(List<DataTransformer> transformers, Map<String, Object[]> transformerHyperParams) {}

protected final void removeTransformersType(Class<? extends DataTransformer> toRemove, List<DataTransformer> transformers, Map<String, Object[]> transformerHyperParams) {
List<String> teIds = transformers.stream()
.filter(toRemove::isInstance)
.map(DataTransformer::name)
.collect(Collectors.toList());
transformers.removeIf(dt -> teIds.contains(dt.name()));
transformerHyperParams.keySet().removeIf(k -> teIds.contains(k.split("\\.", 2)[0]));
}

/**
* Transforms the simple model parameters and potential hyper-parameters into pipeline parameters.
*
* @param resultKey: the key of the final pipe
* @param params: parameters for the model being built in this step.
* @param hyperParams: hyper-parameters for the grid being built in this step (can be null if simple model).
* @return the final pipeline parameters that will be used to build the models in this step.
*/

  /**
   * Applies each registered legacy preprocessing step to the given model parameters.
   * Every step returns a {@link PreprocessingStep.Completer}; its execution is deferred
   * until this modeling step finishes (registered in {@code _onDone}), presumably to
   * clean up or finalize what the step set up — confirm against the step implementations.
   *
   * @param params parameters of the model about to be trained; may be mutated by the steps.
   */
  protected void applyPreprocessing(Model.Parameters params) {
    if (aml().getPreprocessing() == null) return;  // preprocessing disabled or handled elsewhere
    for (PreprocessingStep preprocessingStep : aml().getPreprocessing()) {
      PreprocessingStep.Completer complete = preprocessingStep.apply(params, getPreprocessingConfig());
      _onDone.add(j -> complete.run());
    }
  }

protected Model.Parameters applyPipeline(Key resultKey, Model.Parameters params, Map<String, Object[]> hyperParams) {
if (aml().getPipelineParams() == null) return params;
PipelineParameters pparams = aml().getPipelineParams().freshCopy();
List<DataTransformer> transformers = new ArrayList<>(Arrays.asList(pparams.getTransformers())); // need to convert to ArrayList as `filterPipelineTransformers` may remove items below
Map<String, Object[]> transformersHyperParams = new HashMap<>(aml().getPipelineHyperParams());
filterPipelineTransformers(transformers, transformersHyperParams);
Key[] defaultTransformersKeys = pparams._transformers;
pparams.setTransformers(transformers.toArray(new DataTransformer[0]));
if (defaultTransformersKeys.length != pparams._transformers.length) {
for (Key k : defaultTransformersKeys) {
if (!ArrayUtils.contains(pparams._transformers, k)) ((DataTransformer)k.get()).cleanup();
}
}
if (pparams._transformers.length == 0) return params;
PipelineParameters pparams = (PipelineParameters) aml().getPipelineParams().clone();
setCommonModelBuilderParams(pparams);
pparams._seed = params._seed;
pparams._max_runtime_secs = params._max_runtime_secs;
Expand All @@ -482,15 +456,19 @@ protected Model.Parameters applyPipeline(Key resultKey, Model.Parameters params,
if (hyperParams != null) {
Map<String, Object[]> pipelineHyperParams = new HashMap<>();
for (Map.Entry<String, Object[]> e : hyperParams.entrySet()) {
pipelineHyperParams.put(ESTIMATOR_PARAM+"."+e.getKey(), e.getValue());
pipelineHyperParams.put("estimator."+e.getKey(), e.getValue());
}
hyperParams.clear();
hyperParams.putAll(pipelineHyperParams);
hyperParams.putAll(transformersHyperParams);
hyperParams.putAll(aml().getPipelineHyperParams());
}
return pparams;
}

  /**
   * @return the default (empty) preprocessing configuration;
   *         algo-specific steps override this to tune how preprocessing behaves for their algo
   *         (e.g. disabling or restricting Target Encoding).
   */
  protected PreprocessingConfig getPreprocessingConfig() {
    return new PreprocessingConfig();
  }

/**
* Configures early-stopping for the model or set of models to be built.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
package ai.h2o.automl.modeling;

import ai.h2o.automl.*;
import ai.h2o.targetencoding.pipeline.transformers.TargetEncoderFeatureTransformer;
import hex.Model;
import ai.h2o.automl.preprocessing.PreprocessingConfig;
import ai.h2o.automl.preprocessing.TargetEncoding;
import hex.deeplearning.DeepLearningModel;
import hex.deeplearning.DeepLearningModel.DeepLearningParameters;
import hex.pipeline.DataTransformer;
import water.Key;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;


public class DeepLearningStepsProvider
Expand All @@ -26,17 +22,14 @@ static abstract class DeepLearningModelStep extends ModelingStep.ModelStep<DeepL
public DeepLearningModelStep(String id, AutoML autoML) {
super(NAME, Algo.DeepLearning, id, autoML);
}

@Override
protected Model.Parameters applyPipeline(Key resultKey, Model.Parameters params, Map<String, Object[]> hyperParams) {
return super.applyPipeline(resultKey, params, hyperParams);
}

@Override
protected void filterPipelineTransformers(List<DataTransformer> transformers, Map<String, Object[]> transformerHyperParams) {
// legacy behavior: TE was not applied for deep learning as it is not useful for this algo.
removeTransformersType(TargetEncoderFeatureTransformer.class, transformers, transformerHyperParams);
}

    @Override
    protected PreprocessingConfig getPreprocessingConfig() {
      //TE useless for DNN
      // NOTE(review): CONFIG_PREPARE_CV_ONLY is set from aml().isCVEnabled() — the key name
      // suggests TE is then only prepared for CV folds; confirm semantics in TargetEncoding.
      PreprocessingConfig config = super.getPreprocessingConfig();
      config.put(TargetEncoding.CONFIG_PREPARE_CV_ONLY, aml().isCVEnabled());
      return config;
    }
}

static abstract class DeepLearningGridStep extends ModelingStep.GridStep<DeepLearningModel> {
Expand All @@ -53,6 +46,14 @@ public DeepLearningParameters prepareModelParameters() {
return params;
}

    @Override
    protected PreprocessingConfig getPreprocessingConfig() {
      //TE useless for DNN
      // Same restriction as the single-model step: Target Encoding brings no benefit to
      // deep learning, so its preparation is limited via CONFIG_PREPARE_CV_ONLY when CV is on.
      PreprocessingConfig config = super.getPreprocessingConfig();
      config.put(TargetEncoding.CONFIG_PREPARE_CV_ONLY, aml().isCVEnabled());
      return config;
    }

public Map<String, Object[]> prepareSearchParameters() {
Map<String, Object[]> searchParams = new HashMap<>();
searchParams.put("_rho", new Double[] { 0.9, 0.95, 0.99 });
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package ai.h2o.automl.modeling;

import ai.h2o.automl.*;
import ai.h2o.automl.preprocessing.PreprocessingConfig;
import ai.h2o.automl.preprocessing.TargetEncoding;
import hex.Model;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMModel;
import hex.glm.GLMModel.GLMParameters;

Expand Down Expand Up @@ -30,6 +33,15 @@ public GLMParameters prepareModelParameters() {
params._lambda_search = true;
return params;
}

    @Override
    protected PreprocessingConfig getPreprocessingConfig() {
      //GLM (the exception as usual) doesn't support targetencoding if CV is enabled
      // because it is initializing its lambdas + other params before CV (preventing changes in train frame during CV).
      PreprocessingConfig config = super.getPreprocessingConfig();
      config.put(TargetEncoding.CONFIG_PREPARE_CV_ONLY, aml().isCVEnabled());
      return config;
    }
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import ai.h2o.automl.*;
import ai.h2o.automl.WorkAllocations.Work;
import ai.h2o.automl.events.EventLogEntry;
import ai.h2o.automl.preprocessing.PreprocessingConfig;
import ai.h2o.automl.preprocessing.TargetEncoding;
import hex.KeyValue;
import hex.Model;
import hex.ensemble.Metalearner;
Expand Down Expand Up @@ -63,11 +65,19 @@ protected void setClassBalancingParams(Model.Parameters params) {
}

@Override
protected Model.Parameters applyPipeline(Key resultKey, Model.Parameters params, Map<String, Object[]> hyperParams) {
return params; // no pipeline in SE, base models handle the transformations when making predictions.
protected PreprocessingConfig getPreprocessingConfig() {
//SE should not have TE applied, the base models already do it.
PreprocessingConfig config = super.getPreprocessingConfig();
config.put(TargetEncoding.CONFIG_ENABLED, false);
return config;
}

@Override
@Override
protected Model.Parameters applyPipeline(Key resultKey, Model.Parameters params, Map<String, Object[]> hyperParams) {
return params; // no pipeline in SE, base models handle the transformations when making predictions.
}

@Override
@SuppressWarnings("unchecked")
public boolean canRun() {
Key<Model>[] keys = getBaseModels();
Expand Down
Loading

0 comments on commit 0af6327

Please sign in to comment.