Skip to content

Commit

Permalink
fix hyperparameter setting from schema
Browse files Browse the repository at this point in the history
  • Loading branch information
sebhrusen committed Feb 2, 2024
1 parent cfeb762 commit 3f0417e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
2 changes: 1 addition & 1 deletion h2o-core/src/main/java/hex/faulttolerance/Recovery.java
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ void autoRecover() {
Grid grid = Grid.importBinary(recoveryFile(resultKey), true);
GridSearch.resumeGridSearch(
jobKey, grid,
new GridSearchHandler.APIModelParametersBuilderFactory(),
new GridSearchHandler.SchemaModelParametersBuilderFactory(),
(Recovery<Grid>) this
);
} else {
Expand Down
14 changes: 8 additions & 6 deletions h2o-core/src/main/java/water/api/GridSearchHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ private S resumeGrid(String algoURLName, Properties parms) {
Recovery<Grid> recovery = getRecovery(gss);
Job<Grid> gsJob = GridSearch.resumeGridSearch(
jobKey, grid,
new APIModelParametersBuilderFactory<MP, P>(),
new SchemaModelParametersBuilderFactory<MP, P>(),
recovery
);
gss.hyper_parameters = null;
Expand Down Expand Up @@ -138,7 +138,7 @@ private S trainGrid(String algoURLName, Properties parms) {
destKey,
params,
sortedMap,
new APIModelParametersBuilderFactory<MP, P>(),
new SchemaModelParametersBuilderFactory<MP, P>(),
(HyperSpaceSearchCriteria) gss.search_criteria.createAndFillImpl(),
recovery,
GridSearch.getParallelismLevel(gss.parallelism)
Expand Down Expand Up @@ -204,7 +204,7 @@ private Recovery<Grid> getRecovery(GridSearchSchema gss) {
}
}

public static class APIModelParametersBuilderFactory<MP extends Model.Parameters, PS extends ModelParametersSchemaV3>
public static class SchemaModelParametersBuilderFactory<MP extends Model.Parameters, PS extends ModelParametersSchemaV3>
implements ModelParametersBuilderFactory<MP> {

@Override
Expand All @@ -214,7 +214,7 @@ public ModelParametersBuilder<MP> get(MP initialParams) {

@Override
public PojoUtils.FieldNaming getFieldNamingStrategy() {
return PojoUtils.FieldNaming.DEST_HAS_UNDERSCORES;
return ModelParametersFromSchemaBuilder.NAMING;
}
}

Expand All @@ -229,6 +229,8 @@ public PojoUtils.FieldNaming getFieldNamingStrategy() {
*/
public static class ModelParametersFromSchemaBuilder<MP extends Model.Parameters, PS extends ModelParametersSchemaV3>
implements ModelParametersBuilderFactory.ModelParametersBuilder<MP> {

private final static PojoUtils.FieldNaming NAMING = PojoUtils.FieldNaming.DEST_HAS_UNDERSCORES;

final private MP params;
final private PS paramsSchema;
Expand All @@ -242,7 +244,7 @@ public ModelParametersFromSchemaBuilder(MP initialParams) {

@Override
public boolean isAssignable(String name) {
return params.isParameterAssignable(name);
return params.isParameterAssignable(NAMING.toDest(name));
}

public ModelParametersFromSchemaBuilder<MP, PS> set(String name, Object value) {
Expand All @@ -262,7 +264,7 @@ public ModelParametersFromSchemaBuilder<MP, PS> set(String name, Object value) {

public MP build() {
PojoUtils
.copyProperties(params, paramsSchema, PojoUtils.FieldNaming.DEST_HAS_UNDERSCORES, null,
.copyProperties(params, paramsSchema, NAMING, null,
fields.toArray(new String[fields.size()]));
// FIXME: handle these train/valid fields in different way
// See: ModelParametersSchemaV3#fillImpl
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import water.Key;
import water.Scope;
import water.TestUtil;
import water.api.GridSearchHandler.APIModelParametersBuilderFactory;
import water.api.GridSearchHandler.SchemaModelParametersBuilderFactory;
import water.fvec.Frame;

import java.util.Arrays;
Expand Down Expand Up @@ -49,7 +49,7 @@ public void getTargetEncodingMapByTrainingTEBuilder() {

TargetEncoderParameters parameters = new TargetEncoderParameters();

APIModelParametersBuilderFactory<TargetEncoderParameters, TargetEncoderParametersV3> modelParametersBuilderFactory = new APIModelParametersBuilderFactory<>();
SchemaModelParametersBuilderFactory<TargetEncoderParameters, TargetEncoderParametersV3> modelParametersBuilderFactory = new SchemaModelParametersBuilderFactory<>();

RandomDiscreteValueSearchCriteria hyperSpaceSearchCriteria = new RandomDiscreteValueSearchCriteria();
RandomDiscreteValueWalker<TargetEncoderParameters> walker = new RandomDiscreteValueWalker<>(parameters, hpGrid, modelParametersBuilderFactory, hyperSpaceSearchCriteria);
Expand Down Expand Up @@ -105,8 +105,8 @@ public void regularGSOverTEParameters_parallel() {
parameters._response_column = responseColumn;
parameters._ignored_columns = ignoredColumns(trainingFrame, "home.dest", "embarked", parameters._response_column);

APIModelParametersBuilderFactory<TargetEncoderParameters, TargetEncoderParametersV3> modelParametersBuilderFactory =
new APIModelParametersBuilderFactory<>();
SchemaModelParametersBuilderFactory<TargetEncoderParameters, TargetEncoderParametersV3> modelParametersBuilderFactory =
new SchemaModelParametersBuilderFactory<>();

RandomDiscreteValueSearchCriteria hyperSpaceSearchCriteria = new RandomDiscreteValueSearchCriteria();
RandomDiscreteValueWalker<TargetEncoderParameters> walker = new RandomDiscreteValueWalker<>(
Expand Down

0 comments on commit 3f0417e

Please sign in to comment.