diff --git a/h2o-automl/src/main/java/ai/h2o/automl/AutoML.java b/h2o-automl/src/main/java/ai/h2o/automl/AutoML.java index 2d0332d60b09..b3ba0a701c4a 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/AutoML.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/AutoML.java @@ -12,19 +12,21 @@ import ai.h2o.automl.leaderboard.ModelProvider; import ai.h2o.automl.leaderboard.ModelStep; import ai.h2o.automl.preprocessing.PreprocessingStep; +import ai.h2o.automl.preprocessing.PreprocessingStepDefinition; import hex.Model; import hex.ScoreKeeper.StoppingMetric; import hex.genmodel.utils.DistributionFamily; import hex.leaderboard.*; +import hex.pipeline.DataTransformer; +import hex.pipeline.PipelineModel.PipelineParameters; import hex.splitframe.ShuffleSplitFrame; +import org.apache.log4j.Logger; import water.*; import water.automl.api.schemas3.AutoMLV99; import water.exceptions.H2OAutoMLException; import water.exceptions.H2OIllegalArgumentException; import water.fvec.Frame; import water.fvec.Vec; -import water.logging.Logger; -import water.logging.LoggerFactory; import water.nbhm.NonBlockingHashMap; import water.util.*; @@ -61,7 +63,7 @@ public enum Constraint { private static final boolean verifyImmutability = true; // check that trainingFrame hasn't been messed with private static final ThreadLocal timestampFormatForKeys = ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyyMMdd_HHmmss")); - private static final Logger log = LoggerFactory.getLogger(AutoML.class); + private static final Logger log = Logger.getLogger(AutoML.class); private static LeaderboardExtensionsProvider createLeaderboardExtensionProvider(AutoML automl) { final Key amlKey = automl._key; @@ -166,9 +168,11 @@ public double[] getClassDistribution() { private Vec[] _originalTrainingFrameVecs; private String[] _originalTrainingFrameNames; private long[] _originalTrainingFrameChecksums; - private transient NonBlockingHashMap _trackedKeys = new NonBlockingHashMap<>(); + private transient Map _trackedKeys 
= new NonBlockingHashMap<>(); private transient ModelingStep[] _executionPlan; private transient PreprocessingStep[] _preprocessing; + private transient PipelineParameters _pipelineParams; + private transient Map _pipelineHyperParams; transient StepResultState[] _stepsResults; private boolean _useAutoBlending; @@ -218,6 +222,7 @@ public AutoML(Key key, Date startTime, AutoMLBuildSpec buildSpec) { prepareData(); initLeaderboard(); initPreprocessing(); + initPipeline(); _modelingStepsExecutor = new ModelingStepsExecutor(_leaderboard, _eventLog, _runCountdown); } catch (Exception e) { delete(); //cleanup potentially leaked keys @@ -387,11 +392,53 @@ private void initLeaderboard() { } _leaderboard.setExtensionsProvider(createLeaderboardExtensionProvider(this)); } + + private void initPipeline() { + final AutoMLBuildModels build = _buildSpec.build_models; + _pipelineParams = build.preprocessing == null || !build._pipelineEnabled ? null : new PipelineParameters(); + if (_pipelineParams == null) return; + List transformers = new ArrayList<>(); + Map hyperParams = new NonBlockingHashMap<>(); + for (PreprocessingStepDefinition def : build.preprocessing) { + PreprocessingStep step = def.newPreprocessingStep(this); + transformers.addAll(Arrays.asList(step.pipelineTransformers())); + Map hp = step.pipelineTransformersHyperParams(); + if (hp != null) hyperParams.putAll(hp); + } + if (transformers.isEmpty()) { + _pipelineParams = null; + _pipelineHyperParams = null; + } else { + _pipelineParams._transformers = transformers.toArray(new DataTransformer[0]); + _pipelineHyperParams = hyperParams; + } + + //TODO: given that a transformer can reference a model (e.g. TE), + // and multiple transformers can refer + // to the same model, + // then we should be careful when deleting a transformer (resp. an entire pipeline) + // as we may delete sth that is still in use by another transformer (resp. pipeline). + // --> ref count? 
+ + //TODO: in AutoML, the same transformations are likely to occur on multiple (sometimes all) models, + // especially if the transformers parameters are not tuned. + // But it also depends if the transformers are context(CV)-sensitive (e.g. Target Encoding). + // See `CachingTransformer` for some thoughts about this. + } + + PipelineParameters getPipelineParams() { + return _pipelineParams; + } + + Map getPipelineHyperParams() { + return _pipelineHyperParams; + } private void initPreprocessing() { - _preprocessing = _buildSpec.build_models.preprocessing == null + final AutoMLBuildModels build = _buildSpec.build_models; + _preprocessing = build.preprocessing == null || build._pipelineEnabled ? null - : Arrays.stream(_buildSpec.build_models.preprocessing) + : Arrays.stream(build.preprocessing) .map(def -> def.newPreprocessingStep(this)) .toArray(PreprocessingStep[]::new); } @@ -491,9 +538,11 @@ public void run() { eventLog().info(Stage.Workflow, "AutoML build started: " + EventLogEntry.dateTimeFormat.get().format(_runCountdown.start_time())) .setNamedValue("start_epoch", _runCountdown.start_time(), EventLogEntry.epochFormat.get()); try { + Scope.enter(); learn(); } finally { stop(); + Scope.exit(); } } diff --git a/h2o-automl/src/main/java/ai/h2o/automl/AutoMLBuildSpec.java b/h2o-automl/src/main/java/ai/h2o/automl/AutoMLBuildSpec.java index 6c742337daf3..e46a53efaa18 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/AutoMLBuildSpec.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/AutoMLBuildSpec.java @@ -181,6 +181,7 @@ public static final class AutoMLBuildModels extends Iced { public double exploitation_ratio = -1; public AutoMLCustomParameters algo_parameters = new AutoMLCustomParameters(); public PreprocessingStepDefinition[] preprocessing; + public boolean _pipelineEnabled = true; // currently used for testing until ready: to be removed } public static final class AutoMLCustomParameters extends Iced { diff --git 
a/h2o-automl/src/main/java/ai/h2o/automl/ModelingStep.java b/h2o-automl/src/main/java/ai/h2o/automl/ModelingStep.java index 6095530bed85..85b9011c7682 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/ModelingStep.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/ModelingStep.java @@ -17,12 +17,19 @@ import hex.ModelContainer; import hex.ScoreKeeper.StoppingMetric; import hex.genmodel.utils.DistributionFamily; -import hex.grid.*; +import hex.grid.Grid; +import hex.grid.GridSearch; +import hex.grid.HyperSpaceSearchCriteria; import hex.grid.HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria; +import hex.grid.HyperSpaceWalker; import hex.leaderboard.Leaderboard; +import hex.ModelParametersDelegateBuilderFactory; +import hex.pipeline.PipelineModel.PipelineParameters; import jsr166y.CountedCompleter; import org.apache.commons.lang.builder.ToStringBuilder; import water.*; +import water.KeyGen.ConstantKeyGen; +import water.KeyGen.PatternKeyGen; import water.exceptions.H2OIllegalArgumentException; import water.util.ArrayUtils; import water.util.Countdown; @@ -37,6 +44,8 @@ * Parent class defining common properties and common logic for actual {@link AutoML} training steps. */ public abstract class ModelingStep extends Iced { + + protected static final String PIPELINE_KEY_PREFIX = "Pipeline_"; protected enum SeedPolicy { /** No seed will be used (= random). 
*/ @@ -61,20 +70,11 @@ protected Job startSearch( assert baseParams != null; assert hyperParams.size() > 0; assert searchCriteria != null; - applyPreprocessing(baseParams); - aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+resultKey+" hyperparameter search") + GridSearch.Builder builder = makeGridBuilder(resultKey, baseParams, hyperParams, searchCriteria); + aml().trackKeys(builder.dest()); + aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+builder.dest()+" hyperparameter search") .setNamedValue("start_"+_provider+"_"+_id, new Date(), EventLogEntry.epochFormat.get()); - return GridSearch.create( - resultKey, - HyperSpaceWalker.BaseWalker.WalkerFactory.create( - baseParams, - hyperParams, - new SimpleParametersBuilderFactory<>(), - searchCriteria - )) - .withParallelism(GridSearch.SEQUENTIAL_MODEL_BUILDING) - .withMaxConsecutiveFailures(aml()._maxConsecutiveModelFailures) - .start(); + return builder.start(); } @SuppressWarnings("unchecked") @@ -84,11 +84,8 @@ protected Job startModel( ) { assert resultKey != null; assert params != null; - Job job = new Job<>(resultKey, ModelBuilder.javaName(_algo.urlName()), _description); - applyPreprocessing(params); - ModelBuilder builder = ModelBuilder.make(_algo.urlName(), job, (Key) resultKey); - builder._parms = params; - aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+resultKey+" model training") + ModelBuilder builder = makeBuilder(resultKey, params); + aml().eventLog().info(Stage.ModelTraining, "AutoML: starting "+builder.dest()+" model training") .setNamedValue("start_"+_provider+"_"+_id, new Date(), EventLogEntry.epochFormat.get()); builder.init(false); // validate parameters if (builder._messages.length > 0) { @@ -102,6 +99,38 @@ protected Job startModel( } return builder.trainModelOnH2ONode(); } + + protected GridSearch.Builder makeGridBuilder(Key resultKey, + MP baseParams, + Map hyperParams, + HyperSpaceSearchCriteria searchCriteria) { + applyPreprocessing(baseParams); + 
Model.Parameters finalParams = applyPipeline(resultKey, baseParams, hyperParams); + if (finalParams instanceof PipelineParameters) resultKey = Key.make(PIPELINE_KEY_PREFIX+resultKey); + return GridSearch.create( + resultKey, + HyperSpaceWalker.BaseWalker.WalkerFactory.create( + finalParams, + hyperParams, + new ModelParametersDelegateBuilderFactory<>(), + searchCriteria + )) + .withParallelism(GridSearch.SEQUENTIAL_MODEL_BUILDING) + .withMaxConsecutiveFailures(aml()._maxConsecutiveModelFailures); + } + + + protected ModelBuilder makeBuilder(Key resultKey, MP params) { + applyPreprocessing(params); + Model.Parameters finalParams = applyPipeline(resultKey, params, null); + if (finalParams instanceof PipelineParameters) resultKey = Key.make(PIPELINE_KEY_PREFIX+resultKey); + + Job job = new Job<>(resultKey, ModelBuilder.javaName(_algo.urlName()), _description); + ModelBuilder builder = ModelBuilder.make(finalParams.algoName(), job, (Key) resultKey); + builder._parms = finalParams; + builder._input_parms = finalParams.clone(); + return builder; + } private boolean validParameters(Model.Parameters parms, String[] fields) { try { @@ -360,8 +389,6 @@ protected void setCommonModelBuilderParams(Model.Parameters params) { setClassBalancingParams(params); params._custom_metric_func = buildSpec.build_control.custom_metric_func; - params._keep_cross_validation_models = buildSpec.build_control.keep_cross_validation_models; - params._keep_cross_validation_fold_assignment = buildSpec.build_control.nfolds != 0 && buildSpec.build_control.keep_cross_validation_fold_assignment; params._export_checkpoints_dir = buildSpec.build_control.export_checkpoints_dir; /** Using _main_model_time_budget_factor to determine if and how we should restrict the time for the main model. 
@@ -374,6 +401,8 @@ protected void setCommonModelBuilderParams(Model.Parameters params) { protected void setCrossValidationParams(Model.Parameters params) { AutoMLBuildSpec buildSpec = aml().getBuildSpec(); params._keep_cross_validation_predictions = aml().getBlendingFrame() == null || buildSpec.build_control.keep_cross_validation_predictions; + params._keep_cross_validation_models = buildSpec.build_control.keep_cross_validation_models; + params._keep_cross_validation_fold_assignment = buildSpec.build_control.nfolds != 0 && buildSpec.build_control.keep_cross_validation_fold_assignment; params._fold_column = buildSpec.input_spec.fold_column; if (buildSpec.input_spec.fold_column == null) { @@ -413,6 +442,29 @@ protected void applyPreprocessing(Model.Parameters params) { } } + protected Model.Parameters applyPipeline(Key resultKey, Model.Parameters params, Map hyperParams) { + if (aml().getPipelineParams() == null) return params; + PipelineParameters pparams = (PipelineParameters) aml().getPipelineParams().clone(); + setCommonModelBuilderParams(pparams); + pparams._seed = params._seed; + pparams._max_runtime_secs = params._max_runtime_secs; + pparams._estimatorParams = params; + pparams._estimatorKeyGen = hyperParams == null + ? new ConstantKeyGen(resultKey) + : new PatternKeyGen("{0}|s/"+PIPELINE_KEY_PREFIX+"//") // in case of grid, remove the Pipeline prefix to obtain the estimator key, this allows naming compatibility with the classic mode. 
+ ; + if (hyperParams != null) { + Map pipelineHyperParams = new HashMap<>(); + for (Map.Entry e : hyperParams.entrySet()) { + pipelineHyperParams.put("estimator."+e.getKey(), e.getValue()); + } + hyperParams.clear(); + hyperParams.putAll(pipelineHyperParams); + hyperParams.putAll(aml().getPipelineHyperParams()); + } + return pparams; + } + protected PreprocessingConfig getPreprocessingConfig() { return new PreprocessingConfig(); } @@ -638,7 +690,6 @@ protected Job hyperparameterSearch(Key key, Model.Parameters basePar setSearchCriteria(searchCriteria, baseParms); if (null == key) key = makeKey(_provider, true); - aml().trackKeys(key); Log.debug("Hyperparameter search: " + _provider + ", time remaining (ms): " + aml().timeRemainingMs()); aml().eventLog().debug(Stage.ModelTraining, searchCriteria.max_runtime_secs() == 0 @@ -758,7 +809,7 @@ protected Job startJob() { final Key selectionKey = Key.make(key+"_select"); final EventLog selectionEventLog = EventLog.getOrMake(selectionKey); // EventLog selectionEventLog = aml().eventLog(); -final LeaderboardHolder selectionLeaderboard = makeLeaderboard(selectionKey.toString(), selectionEventLog); + final LeaderboardHolder selectionLeaderboard = makeLeaderboard(selectionKey.toString(), selectionEventLog); { result.delete_and_lock(job); diff --git a/h2o-automl/src/main/java/ai/h2o/automl/ModelingStepsExecutor.java b/h2o-automl/src/main/java/ai/h2o/automl/ModelingStepsExecutor.java index 727d2637cf8a..dcbbc4ea4c02 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/ModelingStepsExecutor.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/ModelingStepsExecutor.java @@ -106,6 +106,7 @@ StepResultState submit(ModelingStep step, Job parentJob) { } } catch (Exception e) { resultState.addState(new StepResultState(step.getGlobalId(), e)); + Log.err(e); } finally { step.onDone(job); } diff --git a/h2o-automl/src/main/java/ai/h2o/automl/StepResultState.java b/h2o-automl/src/main/java/ai/h2o/automl/StepResultState.java index 
5e408bb746e0..65f5fbaf2c05 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/StepResultState.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/StepResultState.java @@ -19,7 +19,7 @@ enum ResultStatus { enum Resolution { sameAsMain, // resolves to the same state as the main step (ignoring other sub-step states). optimistic, // success if any success, otherwise cancelled if any cancelled, otherwise failed if any failure, otherwise skipped. - pessimistic, // failures if any failure, otherwise cancelled if any cancelled, otherwise success it any success, otherwise skipped. + pessimistic, // failed if any failure, otherwise cancelled if any cancelled, otherwise success if any success, otherwise skipped. } private final String _id; diff --git a/h2o-automl/src/main/java/ai/h2o/automl/modeling/StackedEnsembleStepsProvider.java b/h2o-automl/src/main/java/ai/h2o/automl/modeling/StackedEnsembleStepsProvider.java index a16adf2b62e9..86c537f761c8 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/modeling/StackedEnsembleStepsProvider.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/modeling/StackedEnsembleStepsProvider.java @@ -72,7 +72,12 @@ protected PreprocessingConfig getPreprocessingConfig() { return config; } - @Override + @Override + protected Model.Parameters applyPipeline(Key resultKey, Model.Parameters params, Map hyperParams) { + return params; // no pipeline in SE, base models handle the transformations when making predictions. + } + + @Override @SuppressWarnings("unchecked") public boolean canRun() { Key[] keys = getBaseModels(); @@ -122,13 +127,18 @@ protected boolean hasDoppelganger(Key[] baseModelsKeys) { protected abstract Key[] getBaseModels(); protected String getModelType(Key key) { + ModelingStep step = aml().session().getModelingStep(key); + if (step == null) { // dirty case String keyStr = key.toString(); - return keyStr.substring(0, keyStr.indexOf('_')); + int lookupStart = keyStr.startsWith(PIPELINE_KEY_PREFIX) ? 
PIPELINE_KEY_PREFIX.length() : 0; + return keyStr.substring(lookupStart, keyStr.indexOf('_', lookupStart)); + } else { + return step.getAlgo().name(); + } } protected boolean isStackedEnsemble(Key key) { - ModelingStep step = aml().session().getModelingStep(key); - return step != null && step.getAlgo() == Algo.StackedEnsemble; + return Algo.StackedEnsemble.name().equals(getModelType(key)); } @Override diff --git a/h2o-automl/src/main/java/ai/h2o/automl/modeling/XGBoostSteps.java b/h2o-automl/src/main/java/ai/h2o/automl/modeling/XGBoostSteps.java index bb8089fe3c67..01033767ceb3 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/modeling/XGBoostSteps.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/modeling/XGBoostSteps.java @@ -201,7 +201,8 @@ public Map prepareSearchParameters() { XGBoostParameters.Booster.gbtree, XGBoostParameters.Booster.dart }); - +// searchParams.put("_booster$weights", new Integer[] {2, 1}); + searchParams.put("_reg_lambda", new Float[]{0.001f, 0.01f, 0.1f, 1f, 10f, 100f}); searchParams.put("_reg_alpha", new Float[]{0.001f, 0.01f, 0.1f, 0.5f, 1f}); diff --git a/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingStep.java b/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingStep.java index e3a32a361c71..59e0b84e2d98 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingStep.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingStep.java @@ -1,7 +1,9 @@ package ai.h2o.automl.preprocessing; -import ai.h2o.automl.ModelingStep; import hex.Model; +import hex.pipeline.DataTransformer; + +import java.util.Map; public interface PreprocessingStep { @@ -34,4 +36,8 @@ interface Completer extends Runnable {} */ void remove(); + DataTransformer[] pipelineTransformers(); + + Map pipelineTransformersHyperParams(); + } diff --git a/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingStepDefinition.java 
b/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingStepDefinition.java index 568599a16633..834dd97693d8 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingStepDefinition.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/PreprocessingStepDefinition.java @@ -1,6 +1,7 @@ package ai.h2o.automl.preprocessing; import ai.h2o.automl.AutoML; +import hex.pipeline.DataTransformer; import water.Iced; public class PreprocessingStepDefinition extends Iced { diff --git a/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/TargetEncoding.java b/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/TargetEncoding.java index b242a97f0665..78a61d18d396 100644 --- a/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/TargetEncoding.java +++ b/h2o-automl/src/main/java/ai/h2o/automl/preprocessing/TargetEncoding.java @@ -9,13 +9,17 @@ import ai.h2o.targetencoding.TargetEncoderModel.DataLeakageHandlingStrategy; import ai.h2o.targetencoding.TargetEncoderModel.TargetEncoderParameters; import ai.h2o.targetencoding.TargetEncoderPreprocessor; +import ai.h2o.targetencoding.pipeline.transformers.TargetEncoderFeatureTransformer; import hex.Model; import hex.Model.Parameters.FoldAssignmentScheme; import hex.ModelPreprocessor; +import hex.pipeline.DataTransformer; +import hex.pipeline.transformers.KFoldColumnGenerator; import water.DKV; import water.Key; import water.fvec.Frame; import water.fvec.Vec; +import water.nbhm.NonBlockingHashMap; import water.rapids.ast.prims.advmath.AstKFold; import water.util.ArrayUtils; @@ -200,6 +204,39 @@ TargetEncoderPreprocessor getTEPreprocessor() { TargetEncoderModel getTEModel() { return _teModel; } + + @Override + public DataTransformer[] pipelineTransformers() { + List dts = new ArrayList<>(); + TargetEncoderParameters teParams = (TargetEncoderParameters) getDefaultParams().clone(); + Frame train = _aml.getTrainingFrame(); + Set teColumns = selectColumnsToEncode(train, teParams); + if 
(teColumns.isEmpty()) return new DataTransformer[0]; + + String[] keep = teParams.getNonPredictors(); + teParams._ignored_columns = Arrays.stream(train.names()) + .filter(col -> !teColumns.contains(col) && !ArrayUtils.contains(keep, col)) + .toArray(String[]::new); + if (_aml.isCVEnabled()) { + dts.add(new KFoldColumnGenerator() + .id("add_fold_column") + .description("If cross-validation is enabled, generates (if needed) a fold column used by Target Encoder and for the final estimator")); + teParams._data_leakage_handling = DataLeakageHandlingStrategy.KFold; + } + dts.add(new TargetEncoderFeatureTransformer(teParams) + .id("default_TE") + .description("Applies Target Encoding to selected categorical features")); + return dts.toArray(new DataTransformer[0]); + } + + @Override + public Map pipelineTransformersHyperParams() { + Map hp = new HashMap<>(); + hp.put("default_TE._enabled", new Boolean[] {Boolean.TRUE, Boolean.FALSE}); + hp.put("default_TE._keep_original_categorical_columns", new Boolean[] {Boolean.TRUE, Boolean.FALSE}); + hp.put("default_TE._blending", new Boolean[] {Boolean.TRUE, Boolean.FALSE}); + return hp; + } private static void register(Frame fr, String keyPrefix, boolean force) { Key key = fr._key; @@ -210,10 +247,10 @@ private static void register(Frame fr, String keyPrefix, boolean force) { } public static Vec createFoldColumn(Frame fr, - FoldAssignmentScheme fold_assignment, - int nfolds, - String responseColumn, - long seed) { + FoldAssignmentScheme fold_assignment, + int nfolds, + String responseColumn, + long seed) { Vec foldColumn; switch (fold_assignment) { default: diff --git a/h2o-automl/src/main/java/water/automl/api/schemas3/AutoMLBuildSpecV99.java b/h2o-automl/src/main/java/water/automl/api/schemas3/AutoMLBuildSpecV99.java index bfc74f362d99..cf9fc621fcee 100644 --- a/h2o-automl/src/main/java/water/automl/api/schemas3/AutoMLBuildSpecV99.java +++ b/h2o-automl/src/main/java/water/automl/api/schemas3/AutoMLBuildSpecV99.java @@ -293,6 
+293,10 @@ public static final class AutoMLBuildModelsV99 extends SchemaV3 { + + private String[] _columns; + private String _interaction_column; + + private String[] _interaction_domain; + + protected FeatureInteractionTransformer() {} + + public FeatureInteractionTransformer(String[] columns) { + this(columns, null); + } + + public FeatureInteractionTransformer(String[] columns, String interactionColumn) { + _columns = columns; + _interaction_column = interactionColumn; + } + + @Override + protected void doPrepare(PipelineContext context) { + assert context != null; + assert context._params != null; + Frame train = new Frame(context.getTrain()); + // FIXME: InteractionSupport should be improved to not systematically modify frames in-place + int interactionCol = InteractionSupport.addFeatureInteraction(train, _columns); + _interaction_domain = train.vec(interactionCol).domain(); + train.remove(interactionCol); + } + + @Override + protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { + InteractionSupport.addFeatureInteraction(fr, _columns, _interaction_domain); //FIXME: same as above. Also should be able to specify the interaction column name. 
+ return fr; + } +} diff --git a/h2o-extensions/target-encoder/src/main/java/ai/h2o/targetencoding/pipeline/transformers/TargetEncoderFeatureTransformer.java b/h2o-extensions/target-encoder/src/main/java/ai/h2o/targetencoding/pipeline/transformers/TargetEncoderFeatureTransformer.java new file mode 100644 index 000000000000..b5b1a0680a86 --- /dev/null +++ b/h2o-extensions/target-encoder/src/main/java/ai/h2o/targetencoding/pipeline/transformers/TargetEncoderFeatureTransformer.java @@ -0,0 +1,64 @@ +package ai.h2o.targetencoding.pipeline.transformers; + +import ai.h2o.targetencoding.TargetEncoderModel; +import ai.h2o.targetencoding.TargetEncoderModel.TargetEncoderParameters; +import hex.Model; +import hex.pipeline.transformers.ModelAsFeatureTransformer; +import hex.pipeline.PipelineContext; +import water.Key; +import water.fvec.Frame; + +import static ai.h2o.targetencoding.TargetEncoderModel.DataLeakageHandlingStrategy.KFold; + +public class TargetEncoderFeatureTransformer extends ModelAsFeatureTransformer { + + public TargetEncoderFeatureTransformer(TargetEncoderParameters params) { + super(params); + } + + public TargetEncoderFeatureTransformer(TargetEncoderParameters params, Key modelKey) { + super(params, modelKey); + } + + @Override + public boolean isCVSensitive() { + return _params._data_leakage_handling == KFold; + } + + @Override + protected void prepareModelParams(PipelineContext context) { + super.prepareModelParams(context); + // TODO: future improvement: move some of the decision logic in `ai.h2o.automl.preprocessing.TargetEncoding` to this class + // especially the logic related with the dynamic column selection based on cardinality. + // By parametrizing this here, it allows us to consider parameters like `_columnCardinalityThreshold` as hyper-parameters in pipeline grids. 
+ } + + @Override + protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) { + assert type != null; + assert context != null || type == FrameType.Test; + validateTransform(); + switch (type) { + case Training: + if (useFoldTransform(context._params)) { + return getModel().transformTraining(fr, context._params._cv_fold); + } else { + return getModel().transformTraining(fr); + } + case Validation: + if (useFoldTransform(context._params)) { + return getModel().transformTraining(fr); + } else { + return getModel().transform(fr); + } + case Test: + default: + return getModel().transform(fr); + } + } + + private boolean useFoldTransform(Model.Parameters params) { + return isCVSensitive() && params._cv_fold >= 0; + } + +}