Skip to content

Commit

Permalink
Fix java tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasfryda committed Jan 26, 2024
1 parent 6b3b15c commit 5211045
Showing 1 changed file with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public void test_all_registered_steps() {
.collect(Collectors.toList());
ModelingStep[] modelingSteps = registry.getOrderedSteps(allSteps.toArray(new StepDefinition[0]), aml);
// 2 groups by default (1 for models, 1 for grids), hence the 2*2 SEs + 10 optional SEs
assertEquals((1/*completion*/)+(1+3/*DL*/) + (2/*DRF*/) + (5+1+1/*GBM*/) + (1/*GLM*/) + (2*2+10/*SE*/) + (3+1+2/*XGB*/),
assertEquals((1/*completion*/)+(1+3/*DL*/) + (2/*DRF*/) + (5+1+1/*GBM*/) + (1/*GLM*/) + (2*2+10/*SE*/) + (3+1+2/*XGB*/+1/*gblinear*/),
modelingSteps.length);
assertEquals(1, Stream.of(modelingSteps).filter(s -> "completion".equals(s.getProvider())).filter(ModelingStep.DynamicStep.class::isInstance).count());
assertEquals(1, Stream.of(modelingSteps).filter(s -> Algo.DeepLearning.name().equals(s.getProvider())).filter(ModelingStep.ModelStep.class::isInstance).count());
Expand All @@ -93,7 +93,7 @@ public void test_all_registered_steps() {
assertEquals(1, Stream.of(modelingSteps).filter(s -> Algo.GLM.name().equals(s.getProvider())).filter(ModelingStep.ModelStep.class::isInstance).count());
assertEquals(14, Stream.of(modelingSteps).filter(s -> Algo.StackedEnsemble.name().equals(s.getProvider())).filter(ModelingStep.ModelStep.class::isInstance).count());
assertEquals(3, Stream.of(modelingSteps).filter(s -> Algo.XGBoost.name().equals(s.getProvider())).filter(ModelingStep.ModelStep.class::isInstance).count());
assertEquals(1, Stream.of(modelingSteps).filter(s -> Algo.XGBoost.name().equals(s.getProvider())).filter(ModelingStep.GridStep.class::isInstance).count());
assertEquals(2, Stream.of(modelingSteps).filter(s -> Algo.XGBoost.name().equals(s.getProvider())).filter(ModelingStep.GridStep.class::isInstance).count());
assertEquals(2, Stream.of(modelingSteps).filter(s -> Algo.XGBoost.name().equals(s.getProvider())).filter(ModelingStep.SelectionStep.class::isInstance).count());

List<String> orderedStepIds = Arrays.stream(modelingSteps).flatMap(s -> Stream.of(s._provider, s._id)).collect(Collectors.toList());
Expand All @@ -112,7 +112,7 @@ public void test_all_registered_steps() {
Algo.DeepLearning.name(), "grid_1", Algo.DeepLearning.name(), "grid_2", Algo.DeepLearning.name(), "grid_3",
Algo.GBM.name(), "grid_1",
Algo.StackedEnsemble.name(), "best_of_family_2", Algo.StackedEnsemble.name(), "all_2",
Algo.XGBoost.name(), "grid_1",
Algo.XGBoost.name(), "grid_1", Algo.XGBoost.name(), "grid_gblinear",
Algo.GBM.name(), "lr_annealing",
Algo.StackedEnsemble.name(), "monotonic",
Algo.StackedEnsemble.name(), "best_of_family", Algo.StackedEnsemble.name(), "all",
Expand Down Expand Up @@ -158,7 +158,7 @@ public void test_all_grids() {
.toArray(StepDefinition[]::new);
ModelingStepsRegistry registry = new ModelingStepsRegistry();
ModelingStep[] modelingSteps = registry.getOrderedSteps(allGridSteps, aml);
assertEquals((3/*DL*/) + (1/*GBM*/) + (1/*XGB*/),
assertEquals((3/*DL*/) + (1/*GBM*/) + (1/*XGB*/+1/*gblinear*/),
modelingSteps.length);
}

Expand All @@ -173,7 +173,7 @@ public void test_all_defaults_plus_grids() {
ModelingStepsRegistry registry = new ModelingStepsRegistry();
ModelingStep[] modelingSteps = registry.getOrderedSteps(allGridSteps, aml);
// by default, 1 group for default models, 1 group for grids, hence the 2*2 SEs
assertEquals((1+3/*DL*/) + (2/*DRF*/) + (5+1/*GBM*/) + (1/*GLM*/) + (2*2/*SE*/) + (3+1/*XGB*/),
assertEquals((1+3/*DL*/) + (2/*DRF*/) + (5+1/*GBM*/) + (1/*GLM*/) + (2*2/*SE*/) + (3+1/*XGB*/+1/*gblinear*/),
modelingSteps.length);
}

Expand Down

0 comments on commit 5211045

Please sign in to comment.