Skip to content

Commit 68ee7ab

Browse files
authored
[bug] fix sample_weight check for IncrementalBasicStatistics (#1799)
* Update incremental_basic_statistics.py * formatting * Update incremental_basic_statistics.py * Update test_incremental_basic_statistics.py * Update incremental_basic_statistics.py * Update incremental_basic_statistics.py * Update incremental_basic_statistics.py * Update incremental_basic_statistics.py * Update incremental_basic_statistics.py * Update test_incremental_basic_statistics.py
1 parent 3925eef commit 68ee7ab

File tree

2 files changed

+33
-21
lines changed

2 files changed

+33
-21
lines changed

sklearnex/basic_statistics/incremental_basic_statistics.py

+23-11
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy as np
1818
from sklearn.base import BaseEstimator
1919
from sklearn.utils import check_array, gen_batches
20+
from sklearn.utils.validation import _check_sample_weight
2021

2122
from daal4py.sklearn._n_jobs_support import control_n_jobs
2223
from daal4py.sklearn._utils import sklearn_check_version
@@ -139,7 +140,7 @@ def _onedal_finalize_fit(self):
139140
self._onedal_estimator.finalize_fit()
140141
self._need_to_finalize = False
141142

142-
def _onedal_partial_fit(self, X, weights, queue):
143+
def _onedal_partial_fit(self, X, sample_weight=None, queue=None):
143144
first_pass = not hasattr(self, "n_samples_seen_") or self.n_samples_seen_ == 0
144145

145146
if sklearn_check_version("1.0"):
@@ -152,9 +153,11 @@ def _onedal_partial_fit(self, X, weights, queue):
152153
X = check_array(
153154
X,
154155
dtype=[np.float64, np.float32],
155-
copy=self.copy_X,
156156
)
157157

158+
if sample_weight is not None:
159+
sample_weight = _check_sample_weight(sample_weight, X)
160+
158161
if first_pass:
159162
self.n_samples_seen_ = X.shape[0]
160163
self.n_features_in_ = X.shape[1]
@@ -168,15 +171,18 @@ def _onedal_partial_fit(self, X, weights, queue):
168171
self._onedal_estimator = self._onedal_incremental_basic_statistics(
169172
**onedal_params
170173
)
171-
self._onedal_estimator.partial_fit(X, weights, queue)
174+
self._onedal_estimator.partial_fit(X, sample_weight, queue)
172175
self._need_to_finalize = True
173176

174-
def _onedal_fit(self, X, weights, queue=None):
177+
def _onedal_fit(self, X, sample_weight=None, queue=None):
175178
if sklearn_check_version("1.0"):
176179
X = self._validate_data(X, dtype=[np.float64, np.float32])
177180
else:
178181
X = check_array(X, dtype=[np.float64, np.float32])
179182

183+
if sample_weight is not None:
184+
sample_weight = _check_sample_weight(sample_weight, X)
185+
180186
n_samples, n_features = X.shape
181187
if self.batch_size is None:
182188
self.batch_size_ = 5 * n_features
@@ -189,7 +195,7 @@ def _onedal_fit(self, X, weights, queue=None):
189195

190196
for batch in gen_batches(X.shape[0], self.batch_size_):
191197
X_batch = X[batch]
192-
weights_batch = weights[batch] if weights is not None else None
198+
weights_batch = sample_weight[batch] if sample_weight is not None else None
193199
self._onedal_partial_fit(X_batch, weights_batch, queue=queue)
194200

195201
if sklearn_check_version("1.2"):
@@ -217,7 +223,7 @@ def __getattr__(self, attr):
217223
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
218224
)
219225

220-
def partial_fit(self, X, weights=None):
226+
def partial_fit(self, X, sample_weight=None):
221227
"""Incremental fit with X. All of X is processed as a single batch.
222228
223229
Parameters
@@ -226,7 +232,10 @@ def partial_fit(self, X, weights=None):
226232
Data for compute, where `n_samples` is the number of samples and
227233
`n_features` is the number of features.
228234
229-
weights : array-like of shape (n_samples,)
235+
y : Ignored
236+
Not used, present for API consistency by convention.
237+
238+
sample_weight : array-like of shape (n_samples,), default=None
230239
Weights used to compute weighted statistics, where `n_samples` is the number of samples.
231240
232241
Returns
@@ -242,11 +251,11 @@ def partial_fit(self, X, weights=None):
242251
"sklearn": None,
243252
},
244253
X,
245-
weights,
254+
sample_weight,
246255
)
247256
return self
248257

249-
def fit(self, X, weights=None):
258+
def fit(self, X, y=None, sample_weight=None):
250259
"""Compute statistics with X, using minibatches of size batch_size.
251260
252261
Parameters
@@ -255,7 +264,10 @@ def fit(self, X, weights=None):
255264
Data for compute, where `n_samples` is the number of samples and
256265
`n_features` is the number of features.
257266
258-
weights : array-like of shape (n_samples,)
267+
y : Ignored
268+
Not used, present for API consistency by convention.
269+
270+
sample_weight : array-like of shape (n_samples,), default=None
259271
Weights used to compute weighted statistics, where `n_samples` is the number of samples.
260272
261273
Returns
@@ -271,6 +283,6 @@ def fit(self, X, weights=None):
271283
"sklearn": None,
272284
},
273285
X,
274-
weights,
286+
sample_weight,
275287
)
276288
return self

sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_partial_fit_multiple_options_on_gold_data(dataframe, queue, weighted, d
5252
weights_split_df = _convert_to_dataframe(
5353
weights_split[i], sycl_queue=queue, target_df=dataframe
5454
)
55-
result = incbs.partial_fit(X_split_df, weights_split_df)
55+
result = incbs.partial_fit(X_split_df, sample_weight=weights_split_df)
5656
else:
5757
result = incbs.partial_fit(X_split_df)
5858

@@ -103,7 +103,7 @@ def test_partial_fit_single_option_on_random_data(
103103
weights_split_df = _convert_to_dataframe(
104104
weights_split[i], sycl_queue=queue, target_df=dataframe
105105
)
106-
result = incbs.partial_fit(X_split_df, weights_split_df)
106+
result = incbs.partial_fit(X_split_df, sample_weight=weights_split_df)
107107
else:
108108
result = incbs.partial_fit(X_split_df)
109109

@@ -146,7 +146,7 @@ def test_partial_fit_multiple_options_on_random_data(
146146
weights_split_df = _convert_to_dataframe(
147147
weights_split[i], sycl_queue=queue, target_df=dataframe
148148
)
149-
result = incbs.partial_fit(X_split_df, weights_split_df)
149+
result = incbs.partial_fit(X_split_df, sample_weight=weights_split_df)
150150
else:
151151
result = incbs.partial_fit(X_split_df)
152152

@@ -199,7 +199,7 @@ def test_partial_fit_all_option_on_random_data(
199199
weights_split_df = _convert_to_dataframe(
200200
weights_split[i], sycl_queue=queue, target_df=dataframe
201201
)
202-
result = incbs.partial_fit(X_split_df, weights_split_df)
202+
result = incbs.partial_fit(X_split_df, sample_weight=weights_split_df)
203203
else:
204204
result = incbs.partial_fit(X_split_df)
205205

@@ -233,7 +233,7 @@ def test_fit_multiple_options_on_gold_data(dataframe, queue, weighted, dtype):
233233
incbs = IncrementalBasicStatistics(batch_size=1)
234234

235235
if weighted:
236-
result = incbs.fit(X_df, weights_df)
236+
result = incbs.fit(X_df, sample_weight=weights_df)
237237
else:
238238
result = incbs.fit(X_df)
239239

@@ -272,15 +272,15 @@ def test_fit_single_option_on_random_data(
272272
X = X.astype(dtype=dtype)
273273
X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
274274
if weighted:
275-
weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
275+
weights = gen.uniform(low=-0.5, high=1.0, size=row_count)
276276
weights = weights.astype(dtype=dtype)
277277
weights_df = _convert_to_dataframe(weights, sycl_queue=queue, target_df=dataframe)
278278
incbs = IncrementalBasicStatistics(
279279
result_options=result_option, batch_size=batch_size
280280
)
281281

282282
if weighted:
283-
result = incbs.fit(X_df, weights_df)
283+
result = incbs.fit(X_df, sample_weight=weights_df)
284284
else:
285285
result = incbs.fit(X_df)
286286

@@ -311,15 +311,15 @@ def test_partial_fit_multiple_options_on_random_data(
311311
X = X.astype(dtype=dtype)
312312
X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
313313
if weighted:
314-
weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
314+
weights = gen.uniform(low=-0.5, high=1.0, size=row_count)
315315
weights = weights.astype(dtype=dtype)
316316
weights_df = _convert_to_dataframe(weights, sycl_queue=queue, target_df=dataframe)
317317
incbs = IncrementalBasicStatistics(
318318
result_options=["mean", "max", "sum"], batch_size=batch_size
319319
)
320320

321321
if weighted:
322-
result = incbs.fit(X_df, weights_df)
322+
result = incbs.fit(X_df, sample_weight=weights_df)
323323
else:
324324
result = incbs.fit(X_df)
325325

@@ -366,7 +366,7 @@ def test_fit_all_option_on_random_data(
366366
incbs = IncrementalBasicStatistics(result_options="all", batch_size=batch_size)
367367

368368
if weighted:
369-
result = incbs.fit(X_df, weights_df)
369+
result = incbs.fit(X_df, sample_weight=weights_df)
370370
else:
371371
result = incbs.fit(X_df)
372372

0 commit comments

Comments
 (0)