Skip to content

Commit

Permalink
GH-15993: Custom metric as a hyperparam for grid search (#15999)
Browse files Browse the repository at this point in the history
* Custom metric as a hyperparam for grid search

* Add test for custom_increasing

* Hide GAM's custom_metric_func until it's tested

* Revert changes in h2o-py/h2o/utils/shared_utils.py

* Fix python tests by hiding custom metrics when not calculated
  • Loading branch information
tomasfryda authored Jan 18, 2024
1 parent 1565818 commit 1639968
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 33 deletions.
1 change: 0 additions & 1 deletion h2o-algos/src/main/java/hex/schemas/GAMV3.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ public static final class GAMParametersV3 extends ModelParametersSchemaV3<GAMMod
"max_after_balance_size",
"max_confusion_matrix_size",
"max_runtime_secs",
"custom_metric_func",
"num_knots", // array: number of knots for each predictor
"spline_orders", // order of I-splines
"knot_ids", // string array storing frame keys that contains knot location
Expand Down
4 changes: 4 additions & 0 deletions h2o-core/src/main/java/hex/ModelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -2040,6 +2040,10 @@ private TwoDimTable makeCrossValidationSummaryTable(Key[] cvmodels) {
excluded.add("cm");
excluded.add("auc_obj");
excluded.add("aucpr");
if (null == _parms._custom_metric_func) { // hide custom metrics when not available
excluded.add("custom");
excluded.add("custom_increasing");
}
List<Method> methods = new ArrayList<>();
{
Model m = DKV.getGet(cvmodels[0]);
Expand Down
13 changes: 9 additions & 4 deletions h2o-core/src/main/java/hex/ModelMetrics.java
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,16 @@ protected StringBuilder appendToStringMetrics(StringBuilder sb) {

public final Model model() { return _model==null ? (_model=DKV.getGet(_modelKey)) : _model; }
public final Frame frame() { return _frame==null ? (_frame=DKV.getGet(_frameKey)) : _frame; }

public double custom() { return _custom_metric == null ? Double.NaN : _custom_metric.value; }
public double custom_increasing() { return _custom_metric == null ? Double.NaN : _custom_metric.value; } // same as custom but informs stopping criteria that higher is better

public double mse() { return _MSE; }
public double rmse() { return Math.sqrt(_MSE);}
public ConfusionMatrix cm() { return null; }
public float[] hr() { return null; }
public AUC2 auc_obj() { return null; }

public static ModelMetrics defaultModelMetrics(Model model) {
return model._output._cross_validation_metrics != null ? model._output._cross_validation_metrics
: model._output._validation_metrics != null ? model._output._validation_metrics
Expand All @@ -141,9 +144,7 @@ public static double getMetricFromModelMetric(ModelMetrics mm, String criterion)
criterion = criterion.toLowerCase();

if ("custom".equals(criterion)){
if (null == mm._custom_metric)
return Double.NaN;
return mm._custom_metric.value;
return mm.custom();
}

// Constructing confusion matrix based on criterion
Expand Down Expand Up @@ -273,6 +274,10 @@ public static Set<String> getAllowedMetrics(Key<Model> key) {
excluded.add("remove");
excluded.add("nobs");
if (m!=null) {
if (null == m._custom_metric) { // hide custom metrics when not available
excluded.add("custom");
excluded.add("custom_increasing");
}
for (Method meth : m.getClass().getMethods()) {
if (excluded.contains(meth.getName())) continue;
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
``custom_metric_func``
----------------------

- Available in: GBM, DRF, Deeplearning, Stacked Ensembles, GLM, XGBoost
- Available in: GBM, GLM, DRF, Deeplearning, Stacked Ensembles, XGBoost
- Hyperparameter: no

Description
Expand Down
2 changes: 1 addition & 1 deletion h2o-docs/src/product/data-science/coxph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ Common parameters

**NOTE**: In Flow, if you click the **Build a model** button from the ``Parse`` cell, the training frame is entered automatically.

- `use_all_factor_levels <algo-params/coxph.html>`__: Specify whether to use all factor levels in the possible set of predictors; if you enable this option, sufficient regularization is required. By default, the first factor level is skipped. This option defaults to ``True`` (enabled).
- `use_all_factor_levels <algo-params/use_all_factor_levels.html>`__: Specify whether to use all factor levels in the possible set of predictors; if you enable this option, sufficient regularization is required. By default, the first factor level is skipped. This option defaults to ``True`` (enabled).

- `weights_column <algo-params/weights_column.html>`__: Specify a column to use for the observation weights, which are used for bias correction. The specified ``weights_column`` must be included in the specified ``training_frame``.

Expand Down
19 changes: 0 additions & 19 deletions h2o-py/h2o/estimators/gam.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def __init__(self,
max_after_balance_size=5.0, # type: float
max_confusion_matrix_size=20, # type: int
max_runtime_secs=0.0, # type: float
custom_metric_func=None, # type: Optional[str]
num_knots=None, # type: Optional[List[int]]
spline_orders=None, # type: Optional[List[int]]
knot_ids=None, # type: Optional[List[str]]
Expand Down Expand Up @@ -329,9 +328,6 @@ def __init__(self,
:param max_runtime_secs: Maximum allowed runtime in seconds for model training. Use 0 to disable.
Defaults to ``0.0``.
:type max_runtime_secs: float
:param custom_metric_func: Reference to custom evaluation function, format: `language:keyName=funcName`
Defaults to ``None``.
:type custom_metric_func: str, optional
:param num_knots: Number of knots for gam predictors. If specified, must specify one for each gam predictor.
For monotone I-splines, mininum = 2, for cs spline, minimum = 3. For thin plate, minimum is size of
polynomial basis + 2.
Expand Down Expand Up @@ -438,7 +434,6 @@ def __init__(self,
self.max_after_balance_size = max_after_balance_size
self.max_confusion_matrix_size = max_confusion_matrix_size
self.max_runtime_secs = max_runtime_secs
self.custom_metric_func = custom_metric_func
self.num_knots = num_knots
self.spline_orders = spline_orders
self.knot_ids = knot_ids
Expand Down Expand Up @@ -1271,20 +1266,6 @@ def max_runtime_secs(self, max_runtime_secs):
assert_is_type(max_runtime_secs, None, numeric)
self._parms["max_runtime_secs"] = max_runtime_secs

@property
def custom_metric_func(self):
"""
Reference to custom evaluation function, format: `language:keyName=funcName`
Type: ``str``.
"""
return self._parms.get("custom_metric_func")

@custom_metric_func.setter
def custom_metric_func(self, custom_metric_func):
assert_is_type(custom_metric_func, None, str)
self._parms["custom_metric_func"] = custom_metric_func

@property
def num_knots(self):
"""
Expand Down
28 changes: 28 additions & 0 deletions h2o-py/tests/pyunit_utils/utils_model_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,34 @@ def metric(self, l):
return l[0] / l[1]


class CustomRmseRegressionFunc:
def map(self, pred, act, w, o, model):
err = act[0] - pred[0]
return [err * err, 1]

def reduce(self, l, r):
return [l[0] + r[0], l[1] + r[1]]

def metric(self, l):
# Use Java API directly
import java.lang.Math as math
return math.sqrt(l[0] / l[1])


class CustomNegativeRmseRegressionFunc: # used to test custom_increasing
def map(self, pred, act, w, o, model):
err = act[0] - pred[0]
return [err * err, 1]

def reduce(self, l, r):
return [l[0] + r[0], l[1] + r[1]]

def metric(self, l):
# Use Java API directly
import java.lang.Math as math
return -math.sqrt(l[0] / l[1])


class CustomRmseFunc:
def map(self, pred, act, w, o, model):
idx = int(act[0])
Expand Down
70 changes: 70 additions & 0 deletions h2o-py/tests/testdir_algos/grid/pyunit_grid_custom_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import sys, os

sys.path.insert(1, os.path.join("..", "..", ".."))
import h2o
from tests import pyunit_utils
from collections import OrderedDict
from h2o.grid.grid_search import H2OGridSearch
from h2o.estimators.gbm import H2OGradientBoostingEstimator
from tests.pyunit_utils import CustomMaeFunc, CustomRmseRegressionFunc, CustomNegativeRmseRegressionFunc


def custom_mae_mm():
return h2o.upload_custom_metric(CustomMaeFunc, func_name="mae", func_file="mm_mae.py")


def custom_nrmse_mm():
return h2o.upload_custom_metric(CustomNegativeRmseRegressionFunc, func_name="nrmse", func_file="mm_nrmse.py")


def grid_custom_metric():
train = h2o.import_file(path=pyunit_utils.locate("smalldata/iris/iris_wheader.csv"))
# Run GBM Grid Search using Cross Validation with parallelization enabled
ntrees_opts = [1, 3, 5, 10]
hyper_parameters = OrderedDict()
hyper_parameters["ntrees"] = ntrees_opts
hyper_parameters["stopping_metric"] = "custom"
print("GBM grid with the following hyper_parameters:", hyper_parameters)

gs = H2OGridSearch(H2OGradientBoostingEstimator(custom_metric_func=custom_mae_mm()),
hyper_params=hyper_parameters,
parallelism=1)
gs.train(y=3, training_frame=train, nfolds=3)

assert len(gs.models) == 4
print(gs.get_grid(sort_by="rmse"))

print(gs.get_grid(sort_by="mae"))

# Should be ok - just one definition of custom metric
print(gs.get_grid(sort_by="custom"))


def grid_custom_increasing_metric():
train = h2o.import_file(path=pyunit_utils.locate("smalldata/iris/iris_wheader.csv"))
# Run GBM Grid Search using Cross Validation with parallelization enabled
ntrees_opts = [1, 3, 5, 10]
hyper_parameters = OrderedDict()
hyper_parameters["ntrees"] = ntrees_opts
hyper_parameters["stopping_metric"] = "custom_increasing"
print("GBM grid with the following hyper_parameters:", hyper_parameters)

gs = H2OGridSearch(H2OGradientBoostingEstimator(custom_metric_func=custom_nrmse_mm()),
hyper_params=hyper_parameters,
parallelism=1)
gs.train(y=3, training_frame=train, nfolds=3)

assert len(gs.models) == 4
print(gs.get_grid(sort_by="rmse"))

print(gs.get_grid(sort_by="mae"))

# Should be ok - just one definition of custom metric
print(gs.get_grid(sort_by="custom_increasing", decreasing=True))


if __name__ == "__main__":
pyunit_utils.run_tests([grid_custom_metric, grid_custom_increasing_metric])
else:
grid_custom_metric()
grid_custom_increasing_metric()
7 changes: 0 additions & 7 deletions h2o-r/h2o-package/R/gam.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@
#' @param max_after_balance_size Maximum relative size of the training data after balancing class counts (can be less than 1.0). Requires
#' balance_classes. Defaults to 5.0.
#' @param max_runtime_secs Maximum allowed runtime in seconds for model training. Use 0 to disable. Defaults to 0.
#' @param custom_metric_func Reference to custom evaluation function, format: `language:keyName=funcName`
#' @param num_knots Number of knots for gam predictors. If specified, must specify one for each gam predictor. For monotone
#' I-splines, mininum = 2, for cs spline, minimum = 3. For thin plate, minimum is size of polynomial basis + 2.
#' @param spline_orders Order of I-splines or NBSplineTypeI M-splines used for gam predictors. If specified, must be the same size as
Expand Down Expand Up @@ -211,7 +210,6 @@ h2o.gam <- function(x,
class_sampling_factors = NULL,
max_after_balance_size = 5.0,
max_runtime_secs = 0,
custom_metric_func = NULL,
num_knots = NULL,
spline_orders = NULL,
knot_ids = NULL,
Expand Down Expand Up @@ -360,8 +358,6 @@ h2o.gam <- function(x,
parms$max_after_balance_size <- max_after_balance_size
if (!missing(max_runtime_secs))
parms$max_runtime_secs <- max_runtime_secs
if (!missing(custom_metric_func))
parms$custom_metric_func <- custom_metric_func
if (!missing(num_knots))
parms$num_knots <- num_knots
if (!missing(spline_orders))
Expand Down Expand Up @@ -470,7 +466,6 @@ h2o.gam <- function(x,
class_sampling_factors = NULL,
max_after_balance_size = 5.0,
max_runtime_secs = 0,
custom_metric_func = NULL,
num_knots = NULL,
spline_orders = NULL,
knot_ids = NULL,
Expand Down Expand Up @@ -624,8 +619,6 @@ h2o.gam <- function(x,
parms$max_after_balance_size <- max_after_balance_size
if (!missing(max_runtime_secs))
parms$max_runtime_secs <- max_runtime_secs
if (!missing(custom_metric_func))
parms$custom_metric_func <- custom_metric_func
if (!missing(num_knots))
parms$num_knots <- num_knots
if (!missing(spline_orders))
Expand Down

0 comments on commit 1639968

Please sign in to comment.