Skip to content

Commit 14279a8

Browse files
committed
Release
1 parent afef276 commit 14279a8

File tree

4 files changed

+26
-20
lines changed

4 files changed

+26
-20
lines changed

setup.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33

44
HERE = pathlib.Path(__file__).parent
55

6-
VERSION = '0.2.4'
6+
VERSION = '0.2.5'
77
PACKAGE_NAME = 'shap-hypetune'
88
AUTHOR = 'Marco Cerliani'
99
AUTHOR_EMAIL = '[email protected]'

shaphypetune/_classes.py

+10-10
Original file line number | Diff line number | Diff line change
@@ -181,12 +181,12 @@ def fit(self, X, y, trials=None, **fit_params):
181181
fn=lambda p: self._fit(
182182
params=p, X=X, y=y, fit_params=fit_params
183183
),
184-
space=self.param_grid, algo=tpe.suggest,
184+
space=self._param_combi, algo=tpe.suggest,
185185
max_evals=self.n_iter, trials=trials,
186186
rstate=np.random.RandomState(self.sampling_seed),
187187
show_progressbar=False, verbose=0
188188
)
189-
all_results = sorted(trials.results, key=lambda x: x['loss'])
189+
all_results = trials.results
190190

191191
else:
192192
all_results = Parallel(
@@ -200,10 +200,10 @@ def fit(self, X, y, trials=None, **fit_params):
200200
self.trials_.append(job_res['params'])
201201
self.iterations_.append(job_res['iterations'])
202202
self.scores_.append(self._score_sign * job_res['loss'])
203-
if job_res['model'] is None:
204-
models.append(job_res['booster'])
205-
else:
203+
if isinstance(job_res['model'], _BoostSelector):
206204
models.append(job_res['model'])
205+
else:
206+
models.append(job_res['booster'])
207207

208208
# get the best
209209
id_best = self._eval_score(self.scores_)
@@ -401,12 +401,12 @@ def _check_fit_params(self, fit_params, feat_id_real=None):
401401
self.support_, self._cat_support, _fit_params, duplicate=True)
402402

403403
if feat_id_real is None: # final model fit
404-
if 'eval_set' in fit_params:
404+
if 'eval_set' in _fit_params:
405405
_fit_params['eval_set'] = list(map(lambda x: (
406406
self.transform(x[0]), x[1]
407407
), _fit_params['eval_set']))
408408
else:
409-
if 'eval_set' in fit_params: # iterative model fit
409+
if 'eval_set' in _fit_params: # iterative model fit
410410
_fit_params['eval_set'] = list(map(lambda x: (
411411
self._create_X(x[0], feat_id_real), x[1]
412412
), _fit_params['eval_set']))
@@ -627,7 +627,7 @@ def _check_fit_params(self, fit_params):
627627
_fit_params = _set_categorical_indexes(
628628
self.support_, self._cat_support, _fit_params)
629629

630-
if 'eval_set' in fit_params:
630+
if 'eval_set' in _fit_params:
631631
_fit_params['eval_set'] = list(map(lambda x: (
632632
self.transform(x[0]), x[1]
633633
), _fit_params['eval_set']))
@@ -809,7 +809,7 @@ def _check_fit_params(self, fit_params, inverse=False):
809809
_fit_params = _set_categorical_indexes(
810810
self.support_, self._cat_support, _fit_params)
811811

812-
if 'eval_set' in fit_params:
812+
if 'eval_set' in _fit_params:
813813
_fit_params['eval_set'] = list(map(lambda x: (
814814
self._transform(x[0], inverse), x[1]
815815
), _fit_params['eval_set']))
@@ -956,7 +956,7 @@ def fit(self, X, y, **fit_params):
956956
with contextlib.redirect_stdout(io.StringIO()):
957957
self.estimator_.fit(self._transform(X, inverse=False), y, **_fit_params)
958958

959-
# compute step score when only min_features_to_select features left
959+
# compute step score when only min_features_to_select features left
960960
if scoring:
961961
score = self._step_score(self.estimator_)
962962
self.score_history_.append(score)

shaphypetune/shaphypetune.py

+9-9
Original file line number | Diff line number | Diff line change
@@ -295,9 +295,11 @@ def __init__(self,
295295
def _build_model(self, params=None):
296296
"""Private method to build model."""
297297

298+
estimator = clone(self.estimator)
299+
298300
if params is None:
299301
model = _Boruta(
300-
estimator=self.estimator,
302+
estimator=estimator,
301303
perc=self.perc,
302304
alpha=self.alpha,
303305
max_iter=self.max_iter,
@@ -308,9 +310,7 @@ def _build_model(self, params=None):
308310
)
309311

310312
else:
311-
estimator = clone(self.estimator)
312313
estimator.set_params(**params)
313-
314314
model = _Boruta(
315315
estimator=estimator,
316316
perc=self.perc,
@@ -499,9 +499,11 @@ def __init__(self,
499499
def _build_model(self, params=None):
500500
"""Private method to build model."""
501501

502+
estimator = clone(self.estimator)
503+
502504
if params is None:
503505
model = _RFE(
504-
estimator=self.estimator,
506+
estimator=estimator,
505507
min_features_to_select=self.min_features_to_select,
506508
step=self.step,
507509
greater_is_better=self.greater_is_better,
@@ -511,9 +513,7 @@ def _build_model(self, params=None):
511513
)
512514

513515
else:
514-
estimator = clone(self.estimator)
515516
estimator.set_params(**params)
516-
517517
model = _RFE(
518518
estimator=estimator,
519519
min_features_to_select=self.min_features_to_select,
@@ -706,9 +706,11 @@ def __init__(self,
706706
def _build_model(self, params=None):
707707
"""Private method to build model."""
708708

709+
estimator = clone(self.estimator)
710+
709711
if params is None:
710712
model = _RFA(
711-
estimator=self.estimator,
713+
estimator=estimator,
712714
min_features_to_select=self.min_features_to_select,
713715
step=self.step,
714716
greater_is_better=self.greater_is_better,
@@ -718,9 +720,7 @@ def _build_model(self, params=None):
718720
)
719721

720722
else:
721-
estimator = clone(self.estimator)
722723
estimator.set_params(**params)
723-
724724
model = _RFA(
725725
estimator=estimator,
726726
min_features_to_select=self.min_features_to_select,

shaphypetune/utils.py

+6
Original file line number | Diff line number | Diff line change
@@ -174,6 +174,7 @@ def sample(self):
174174
is_random = all(isinstance(p, list) or 'scipy' in str(type(p)).lower()
175175
for p in param_distributions.values())
176176
is_hyperopt = all('hyperopt' in str(type(p)).lower()
177+
or (len(p) < 2 if isinstance(p, list) else False)
177178
for p in param_distributions.values())
178179

179180
if is_grid:
@@ -219,6 +220,11 @@ def sample(self):
219220
"n_iter must be an integer >0 when hyperopt "
220221
"search spaces are provided. Get None."
221222
)
223+
param_distributions = {
224+
k: p[0] if isinstance(p, list) else p
225+
for k, p in param_distributions.items()
226+
}
227+
222228
return param_distributions, 'hyperopt'
223229

224230
else:

0 commit comments

Comments (0)