Skip to content

Commit

Permalink
core pipeline API
Browse files Browse the repository at this point in the history
  • Loading branch information
sebhrusen committed Jan 29, 2024
1 parent 5620485 commit 6eebae0
Show file tree
Hide file tree
Showing 118 changed files with 4,385 additions and 424 deletions.
37 changes: 25 additions & 12 deletions h2o-admissibleml/src/main/java/hex/Infogram/Infogram.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package hex.Infogram;

import hex.*;
import hex.Infogram.InfogramModel.InfogramModelOutput;
import hex.Infogram.InfogramModel.InfogramParameters;
import hex.ModelMetrics.MetricBuilder;
import water.*;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
Expand All @@ -18,8 +21,8 @@
import static water.util.ArrayUtils.sort;
import static water.util.ArrayUtils.sum;

public class Infogram extends ModelBuilder<hex.Infogram.InfogramModel, hex.Infogram.InfogramModel.InfogramParameters,
hex.Infogram.InfogramModel.InfogramModelOutput> {
public class Infogram extends ModelBuilder<hex.Infogram.InfogramModel, InfogramParameters,
InfogramModelOutput> {
static final double NORMALIZE_ADMISSIBLE_INDEX = 1.0/Math.sqrt(2.0);
boolean _buildCore; // true to find core predictors, false to find admissible predictors
String[] _topKPredictors; // contain the names of top predictors to consider for infogram
Expand All @@ -45,14 +48,14 @@ public class Infogram extends ModelBuilder<hex.Infogram.InfogramModel, hex.Infog
Model.Parameters.FoldAssignmentScheme _foldAssignmentOrig = null;
String _foldColumnOrig = null;

public Infogram(boolean startup_once) { super(new hex.Infogram.InfogramModel.InfogramParameters(), startup_once);}
public Infogram(boolean startup_once) { super(new InfogramParameters(), startup_once);}

public Infogram(hex.Infogram.InfogramModel.InfogramParameters parms) {
public Infogram(InfogramParameters parms) {
super(parms);
init(false);
}

public Infogram(hex.Infogram.InfogramModel.InfogramParameters parms, Key<hex.Infogram.InfogramModel> key) {
public Infogram(InfogramParameters parms, Key<hex.Infogram.InfogramModel> key) {
super(parms, key);
init(false);
}
Expand All @@ -71,18 +74,23 @@ protected int nModelsInParallel(int folds) {
* This is called before cross-validation is carried out
*/
@Override
public void computeCrossValidation() {
protected void cv_init() {
super.cv_init();
info("cross-validation", "cross-validation infogram information is stored in frame with key" +
" labeled as admissible_score_key_cv and the admissible features in admissible_features_cv.");
if (error_count() > 0) {
throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(Infogram.this);
}
super.computeCrossValidation();
}

@Override
protected MetricBuilder makeCVMetricBuilder(ModelBuilder<InfogramModel, InfogramParameters, InfogramModelOutput> cvModelBuilder, Futures fs) {
return null; //infogram does not support scoring
}

// find the best alpha/lambda values used to build the main model moving forward by looking at the devianceValid
@Override
public void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) {
protected void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) {
int nBuilders = cvModelBuilders.length;
double[][] cmiRaw = new double[nBuilders][];
List<List<String>> columns = new ArrayList<>();
Expand All @@ -103,7 +111,12 @@ public void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) {
}
_cvDone = true; // cv is done and we are going to build main model next
}


@Override
protected void cv_mainModelScores(int N, MetricBuilder[] mbs, ModelBuilder<InfogramModel, InfogramParameters, InfogramModelOutput>[] cvModelBuilders) {
//infogram does not support scoring
}

public void calculateMeanInfogramInfo(double[][] cmiRaw, List<List<String>> columns,
long[] nObs) {
int nFolds = cmiRaw.length;
Expand Down Expand Up @@ -304,7 +317,7 @@ public final void buildModel() {
try {
boolean validPresent = _parms.valid() != null;
prepareModelTrainingFrame(); // generate training frame with predictors and sensitive features (if specified)
InfogramModel model = new hex.Infogram.InfogramModel(dest(), _parms, new hex.Infogram.InfogramModel.InfogramModelOutput(Infogram.this));
InfogramModel model = new hex.Infogram.InfogramModel(dest(), _parms, new InfogramModelOutput(Infogram.this));
_model = model.delete_and_lock(_job);
_model._output._start_time = System.currentTimeMillis();
_cmiRaw = new double[_numModels];
Expand Down Expand Up @@ -359,7 +372,7 @@ public final void buildModel() {
* relevance >= relevance_threshold. Derive _admissible_index as distance from point with cmi = 1 and
* relevance = 1. In addition, all arrays are sorted on _admissible_index.
*/
private void copyCMIRelevance(InfogramModel.InfogramModelOutput modelOutput) {
private void copyCMIRelevance(InfogramModelOutput modelOutput) {
modelOutput._cmi_raw = new double[_cmi.length];
System.arraycopy(_cmiRaw, 0, modelOutput._cmi_raw, 0, modelOutput._cmi_raw.length);
modelOutput._admissible_index = new double[_cmi.length];
Expand All @@ -375,7 +388,7 @@ private void copyCMIRelevance(InfogramModel.InfogramModelOutput modelOutput) {
modelOutput._admissible_index, modelOutput._admissible, modelOutput._all_predictor_names);
}

public void copyCMIRelevanceValid(InfogramModel.InfogramModelOutput modelOutput) {
public void copyCMIRelevanceValid(InfogramModelOutput modelOutput) {
modelOutput._cmi_raw_valid = new double[_cmiValid.length];
System.arraycopy(_cmiRawValid, 0, modelOutput._cmi_raw_valid, 0, modelOutput._cmi_raw_valid.length);
modelOutput._admissible_index_valid = new double[_cmiValid.length];
Expand Down
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/deeplearning/DeepLearning.java
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ static DataInfo makeDataInfo(Frame train, Frame valid, DeepLearningParameters pa
}
}

@Override public void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) {
@Override protected void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) {
_parms._overwrite_with_best_model = false;

if( _parms._stopping_rounds == 0 && _parms._max_runtime_secs == 0) return; // No exciting changes to stopping conditions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ public ModelMetrics makeModelMetrics(Frame fr, Frame adaptFrm) {
@Override
public ModelMetrics.MetricBuilder<?> getMetricBuilder() {
throw new UnsupportedOperationException("Stacked Ensemble model doesn't implement MetricBuilder infrastructure code, " +
"retrieve your metrics by calling getOrMakeMetrics method.");
"retrieve your metrics by calling makeModelMetrics method.");
}
}

Expand Down
69 changes: 15 additions & 54 deletions h2o-algos/src/main/java/hex/glm/GLM.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import hex.util.LinearAlgebraUtils;
import hex.util.LinearAlgebraUtils.BMulTask;
import hex.util.LinearAlgebraUtils.FindMaxIndex;
import jsr166y.CountedCompleter;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import water.*;
Expand Down Expand Up @@ -119,7 +118,8 @@ public boolean isSupervised() {
public ModelCategory[] can_build() {
return new ModelCategory[]{
ModelCategory.Regression,
ModelCategory.Binomial,
ModelCategory.Binomial,
ModelCategory.Multinomial
};
}

Expand Down Expand Up @@ -148,13 +148,12 @@ public ModelCategory[] can_build() {
* (builds N+1 models, all have train+validation metrics, the main model has N-fold cross-validated validation metrics)
*/
@Override
public void computeCrossValidation() {
protected void cv_init() {
// init computes global list of lambdas
init(true);
_cvRuns = true;
if (error_count() > 0)
throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(GLM.this);
super.computeCrossValidation();
}


Expand Down Expand Up @@ -293,7 +292,7 @@ private double[] alignSubModelsAcrossCVModels(ModelBuilder[] cvModelBuilders) {
* 4. unlock the n-folds models (they are changed here, so the unlocking happens here)
*/
@Override
public void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) {
protected void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) {
setMaxRuntimeSecsForMainModel();
double bestTestDev = Double.POSITIVE_INFINITY;
double[] alphasAndLambdas = alignSubModelsAcrossCVModels(cvModelBuilders);
Expand Down Expand Up @@ -372,12 +371,6 @@ public void cv_computeAndSetOptimalParameters(ModelBuilder[] cvModelBuilders) {
break;
}
}
for (int i = 0; i < cvModelBuilders.length; ++i) {
GLM g = (GLM) cvModelBuilders[i];
if (g._toRemove != null)
for (Key k : g._toRemove)
Keyed.remove(k);
}

for (int i = 0; i < cvModelBuilders.length; ++i) {
GLM g = (GLM) cvModelBuilders[i];
Expand Down Expand Up @@ -1543,11 +1536,11 @@ private void buildModel() {

protected static final long WORK_TOTAL = 1000000;

transient Key [] _toRemove;

private Key[] removeLater(Key ...k){
_toRemove = _toRemove == null?k:ArrayUtils.append(_toRemove,k);
return k;
@Override
protected void cleanUp() {
if (_parms._lambda_search && _parms._is_cv_model)
keepUntilCompletion(_dinfo.getWeightsVec()._key);
super.cleanUp();
}

@Override protected GLMDriver trainModelImpl() { return _driver = new GLMDriver(); }
Expand Down Expand Up @@ -1576,23 +1569,6 @@ public final class GLMDriver extends Driver implements ProgressMonitor {
private transient GLMTask.GLMIterationTask _gramInfluence;
private transient double[][] _cholInvInfluence;

private void doCleanup() {
try {
if (_parms._lambda_search && _parms._is_cv_model)
Scope.untrack(removeLater(_dinfo.getWeightsVec()._key));
if (_parms._HGLM) {
Key[] vecKeys = _toRemove;
for (int index = 0; index < vecKeys.length; index++) {
Vec tempVec = DKV.getGet(vecKeys[index]);
tempVec.remove();
}
}
} catch (Exception e) {
Log.err("Error while cleaning up GLM " + _result);
Log.err(e);
}
}

private transient Cholesky _chol;
private transient L1Solver _lslvr;

Expand Down Expand Up @@ -3564,9 +3540,8 @@ private Vec[] genGLMVectors(DataInfo dinfo, double[] nb) {
sumExp += Math.exp(nb[i * N + P] - maxRow);
}
Vec[] vecs = dinfo._adaptedFrame.anyVec().makeDoubles(2, new double[]{sumExp, maxRow});
if (_parms._lambda_search && _parms._is_cv_model) {
Scope.untrack(vecs[0]._key, vecs[1]._key);
removeLater(vecs[0]._key, vecs[1]._key);
if (_parms._lambda_search) {
track(vecs[0]); track(vecs[1]);
}
return vecs;
}
Expand Down Expand Up @@ -3848,7 +3823,7 @@ private void checkCoeffsBounds() {
* - column 2: zi, intermediate values
* - column 3: eta = X*beta, intermediate values
*/
public void addWdataZiEtaOld2Response() { // attach wdata, zi, eta to response for HGLM
private void addWdataZiEtaOld2Response() { // attach wdata, zi, eta to response for HGLM
int moreColnum = 3 + _parms._random_columns.length;
Vec[] vecs = _dinfo._adaptedFrame.anyVec().makeZeros(moreColnum);
String[] colNames = new String[moreColnum];
Expand All @@ -3861,25 +3836,11 @@ public void addWdataZiEtaOld2Response() { // attach wdata, zi, eta to response f
vecs[index] = _parms.train().vec(randColIndices[index - 3]).makeCopy();
}
_dinfo.addResponse(colNames, vecs);
for (int index = 0; index < moreColnum; index++) {
Scope.untrack(vecs[index]._key);
removeLater(vecs[index]._key);
}
}

@Override
public void onCompletion(CountedCompleter caller) {
doCleanup();
super.onCompletion(caller);
Frame wdataZiEta = new Frame(Key.make("wdataZiEta"+Key.rand()), colNames, vecs);
DKV.put(wdataZiEta);
track(wdataZiEta);
}

@Override
public boolean onExceptionalCompletion(Throwable t, CountedCompleter caller) {
doCleanup();
return super.onExceptionalCompletion(t, caller);
}


@Override
public boolean progress(double[] beta, GradientInfo ginfo) {
_state._iter++;
Expand Down
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/kmeans/KMeans.java
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ else if( user_points.numRows() != _parms._k)
if (expensive && error_count() == 0) checkMemoryFootPrint();
}

public void cv_makeAggregateModelMetrics(ModelMetrics.MetricBuilder[] mbs){
protected void cv_makeAggregateModelMetrics(ModelMetrics.MetricBuilder[] mbs){
super.cv_makeAggregateModelMetrics(mbs);
((ModelMetricsClustering.MetricBuilderClustering) mbs[0])._within_sumsqe = null;
((ModelMetricsClustering.MetricBuilderClustering) mbs[0])._size = null;
Expand Down
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/tree/SharedTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -1197,7 +1197,7 @@ public double initialValue() {
return _parms._parallel_main_model_building;
}

@Override public void cv_computeAndSetOptimalParameters(ModelBuilder<M, P, O>[] cvModelBuilders) {
@Override protected void cv_computeAndSetOptimalParameters(ModelBuilder<M, P, O>[] cvModelBuilders) {
// Extract stopping conditions from each CV model, and compute the best stopping answer
if (!cv_initStoppingParameters())
return; // No exciting changes to stopping conditions
Expand Down
5 changes: 3 additions & 2 deletions h2o-bindings/bin/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ def get_customizations_for(language, algo, property=None, default=None):
tokens = property.split('.')
value = customizations
for token in tokens:
value = value.get(token)
if value is None:
if token in value:
value = value.get(token)
else:
return default
return value
else:
Expand Down
19 changes: 19 additions & 0 deletions h2o-bindings/bin/custom/R/gen_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

extensions = dict(
required_params=[],
frame_params=[],
validate_required_params="",
set_required_params="",
module="""
.h2o.fill_pipeline<- function(model, parameters, allparams) {
if (!is.null(model$estimator)) {
model$estimator_model <- h2o.getModel(model$estimator$name)
} else {
model$estimator_model <- NULL
}
model$transformers <- unlist(lapply(model$transformers, function(dt) new("H2ODataTransformer", id=dt$id, description=dt$description)))
# class(model) <- "H2OPipeline"
return(model)
}
"""
)
56 changes: 56 additions & 0 deletions h2o-bindings/bin/custom/python/gen_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
supervised_learning = None # actually depends on the estimator model in the pipeline, leave it to None for now as it is needed only for training and we don't support pipeline as input yet


# in future update, we'll want to expose parameters applied to each transformer
def module_extensions():
class H2ODataTransformer(H2ODisplay):
@classmethod
def make(cls, kvs):
dt = H2ODataTransformer(**{k: v for k, v in kvs if k not in H2OSchema._ignored_schema_keys_})
dt._json = kvs
return dt

def __init__(self, id=None, description=None):
self._json = None
self.id = id
self.description = description

def _repr_(self):
return repr(self._json)

def _str_(self, verbosity=None):
return repr_def(self)


# self-register transformer class: done as soon as `h2o.estimators` is loaded, which means as soon as h2o.h2o is...
register_schema_handler("DataTransformerV3", H2ODataTransformer)


def class_extensions():
@property
def transformers(self):
return self._model_json['output']['transformers']

@property
def estimator_model(self):
m_json = self._model_json['output']['estimator']
return None if (m_json is None or m_json['name'] is None) else h2o.get_model(m_json['name'])

def transform(self, fr):
"""
Applies all the pipeline transformers to the given input frame.
:return: the transformed frame, as it would be passed to `estimator_model`, if calling `predict` instead.
"""
return H2OFrame._expr(expr=ExprNode("transform", ASTId(self.key), ASTId(fr.key)))._frame(fill_cache=True)


extensions = dict(
__imports__="""
import h2o
from h2o.display import H2ODisplay, repr_def
from h2o.expr import ASTId, ExprNode
from h2o.schemas import H2OSchema, register_schema_handler
""",
__class__=class_extensions,
__module__=module_extensions,
)
4 changes: 2 additions & 2 deletions h2o-bindings/bin/gen_R.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def get_schema_params(pname):
"verbose",
"destination_key"] # destination_key is only for SVD
bulk_params = list(zip(*filter(lambda t: not t[0] in bulk_pnames_skip, zip(sig_pnames, sig_params))))
bulk_pnames = list(bulk_params[0])
sig_bulk_params = list(bulk_params[1])
bulk_pnames = list(bulk_params[0]) if bulk_params else []
sig_bulk_params = list(bulk_params[1]) if bulk_params else []
sig_bulk_params.append("segment_columns = NULL")
sig_bulk_params.append("segment_models_id = NULL")
sig_bulk_params.append("parallelism = 1")
Expand Down
Loading

0 comments on commit 6eebae0

Please sign in to comment.