Skip to content

Commit

Permalink
GH-16063 UpliftDRF - Add error in the explain method (#16064)
Browse files Browse the repository at this point in the history
Add error in exmplain for UpliftDRF models
  • Loading branch information
maurever authored Feb 9, 2024
1 parent 72ae632 commit 8d4b331
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 2 deletions.
13 changes: 13 additions & 0 deletions h2o-py/h2o/explanation/_explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1930,6 +1930,17 @@ def _get_xy(model):
return x, y


def _get_treatment(model):
# type: (h2o.model.ModelBase) -> str
"""
Get treatment column.
:param model: H2O Model
:returns: treatment column name
"""
treat = model.actual_params.get("treatment_column")
return treat


def _consolidate_varimps(model):
# type (h2o.model.ModelBase) -> Dict
"""
Expand Down Expand Up @@ -3037,6 +3048,8 @@ def _process_models_input(
models_with_varimp = [model for model in models if _has_varimp(model)]
tree_models_to_show = _get_tree_models(models, 1 if is_aml else float("inf"))
y = _get_xy(models_to_show[0])[1]
if any(_get_treatment(x) is not None for x in models_to_show):
raise ValueError("Uplift models currently cannot be used with explain function.")
classification = frame[y].isfactor()[0]
multinomial_classification = classification and frame[y].nlevels()[0] > 2
targets = [None]
Expand Down
53 changes: 53 additions & 0 deletions h2o-py/tests/testdir_algos/uplift/pyunit_uplift_rf_explain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import sys, os

sys.path.insert(1, os.path.join("..", "..", ".."))
import h2o
from tests import pyunit_utils
from h2o.estimators import H2OUpliftRandomForestEstimator


def uplift_random_forest_explain():
print("Uplift Distributed Random Forest explain test")
seed = 12345

treatment_column = "treatment"
response_column = "outcome"
x_names = ["feature_"+str(x) for x in range(1, 3)]

train_h2o = h2o.upload_file(pyunit_utils.locate("smalldata/uplift/upliftml_train.csv"))
train_h2o[treatment_column] = train_h2o[treatment_column].asfactor()
train_h2o[response_column] = train_h2o[response_column].asfactor()

valid_h2o = h2o.upload_file(pyunit_utils.locate("smalldata/uplift/upliftml_test.csv"))
valid_h2o[treatment_column] = valid_h2o[treatment_column].asfactor()
valid_h2o[response_column] = valid_h2o[response_column].asfactor()

ntrees = 2
max_depth = 2
min_rows = 10
sample_rate = 0.8

uplift_model = H2OUpliftRandomForestEstimator(
ntrees=ntrees,
max_depth=max_depth,
treatment_column=treatment_column,
min_rows=min_rows,
seed=seed,
sample_rate=sample_rate,
score_each_iteration=True
)

uplift_model.train(y=response_column, x=x_names, training_frame=train_h2o, validation_frame=valid_h2o)
print(uplift_model)

# should throw error
try:
uplift_model.explain(valid_h2o)
except ValueError:
assert True, "The explain function should fail with UpliftDRF."


if __name__ == "__main__":
pyunit_utils.standalone_test(uplift_random_forest_explain)
else:
uplift_random_forest_explain()
6 changes: 4 additions & 2 deletions h2o-r/h2o-package/R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,10 @@ case_insensitive_match_arg <- function(arg, choices) {
.self
},
get_model = function(model_id) {
return(memoised_models$get_model(model_id))
model <- memoised_models$get_model(model_id)
if (!is.null(model@allparameters$treatment_column))
stop("Uplift models have not supported in explain yet.")
return(model)
}
)
)
Expand Down Expand Up @@ -314,7 +317,6 @@ case_insensitive_match_arg <- function(arg, choices) {
object$model_ids <- head(object$model_ids, n = min(top_n_from_AutoML, length(object$model_ids)))
}
}

return(object)
}

Expand Down
39 changes: 39 additions & 0 deletions h2o-r/tests/testdir_algos/uplift/runit_uplit_rf_explain.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
setwd(normalizePath(dirname(R.utils::commandArgs(asValues = TRUE)$"f")))
source("../../../scripts/h2o-r-test-setup.R")


test.uplift <- function() {
ntrees <- 2
max_depth <- 2
min_rows <- 10
sample_rate <- 0.8
seed <- 42
set.seed(seed)
x <- c("feature_1", "feature_2", "feature_3")
y <- "outcome"
treatment_col <- "treatment"

# Test data preparation for each implementation
train <- h2o.importFile(path=locate("smalldata/uplift/upliftml_train.csv"),
col.types=list(by.col.name=c(treatment_col, y), types=c("factor", "factor")))
test <- h2o.importFile(path=locate("smalldata/uplift/upliftml_test.csv"),
col.types=list(by.col.name=c(treatment_col, y), types=c("factor", "factor")))

model <- h2o.upliftRandomForest(
x = x,
y = y,
training_frame = train,
validation_frame = test,
treatment_column = treatment_col,
ntrees = ntrees,
max_depth = max_depth,
min_rows = min_rows,
sample_rate = sample_rate,
score_each_iteration=TRUE,
seed = seed)

print(model)
expect_error(h2o.explain(model, test))
}

doTest("Uplift Random Forest Test: Test H2O RF uplift", test.uplift)

0 comments on commit 8d4b331

Please sign in to comment.