17
17
import numpy as np
18
18
from sklearn .base import BaseEstimator
19
19
from sklearn .utils import check_array , gen_batches
20
+ from sklearn .utils .validation import _check_sample_weight
20
21
21
22
from daal4py .sklearn ._n_jobs_support import control_n_jobs
22
23
from daal4py .sklearn ._utils import sklearn_check_version
@@ -139,7 +140,7 @@ def _onedal_finalize_fit(self):
139
140
self ._onedal_estimator .finalize_fit ()
140
141
self ._need_to_finalize = False
141
142
142
- def _onedal_partial_fit (self , X , weights , queue ):
143
+ def _onedal_partial_fit (self , X , sample_weight = None , queue = None ):
143
144
first_pass = not hasattr (self , "n_samples_seen_" ) or self .n_samples_seen_ == 0
144
145
145
146
if sklearn_check_version ("1.0" ):
@@ -152,9 +153,11 @@ def _onedal_partial_fit(self, X, weights, queue):
152
153
X = check_array (
153
154
X ,
154
155
dtype = [np .float64 , np .float32 ],
155
- copy = self .copy_X ,
156
156
)
157
157
158
+ if sample_weight is not None :
159
+ sample_weight = _check_sample_weight (sample_weight , X )
160
+
158
161
if first_pass :
159
162
self .n_samples_seen_ = X .shape [0 ]
160
163
self .n_features_in_ = X .shape [1 ]
@@ -168,15 +171,18 @@ def _onedal_partial_fit(self, X, weights, queue):
168
171
self ._onedal_estimator = self ._onedal_incremental_basic_statistics (
169
172
** onedal_params
170
173
)
171
- self ._onedal_estimator .partial_fit (X , weights , queue )
174
+ self ._onedal_estimator .partial_fit (X , sample_weight , queue )
172
175
self ._need_to_finalize = True
173
176
174
- def _onedal_fit (self , X , weights , queue = None ):
177
+ def _onedal_fit (self , X , sample_weight = None , queue = None ):
175
178
if sklearn_check_version ("1.0" ):
176
179
X = self ._validate_data (X , dtype = [np .float64 , np .float32 ])
177
180
else :
178
181
X = check_array (X , dtype = [np .float64 , np .float32 ])
179
182
183
+ if sample_weight is not None :
184
+ sample_weight = _check_sample_weight (sample_weight , X )
185
+
180
186
n_samples , n_features = X .shape
181
187
if self .batch_size is None :
182
188
self .batch_size_ = 5 * n_features
@@ -189,7 +195,7 @@ def _onedal_fit(self, X, weights, queue=None):
189
195
190
196
for batch in gen_batches (X .shape [0 ], self .batch_size_ ):
191
197
X_batch = X [batch ]
192
- weights_batch = weights [batch ] if weights is not None else None
198
+ weights_batch = sample_weight [batch ] if sample_weight is not None else None
193
199
self ._onedal_partial_fit (X_batch , weights_batch , queue = queue )
194
200
195
201
if sklearn_check_version ("1.2" ):
@@ -217,7 +223,7 @@ def __getattr__(self, attr):
217
223
f"'{ self .__class__ .__name__ } ' object has no attribute '{ attr } '"
218
224
)
219
225
220
- def partial_fit (self , X , weights = None ):
226
+ def partial_fit (self , X , sample_weight = None ):
221
227
"""Incremental fit with X. All of X is processed as a single batch.
222
228
223
229
Parameters
@@ -226,7 +232,10 @@ def partial_fit(self, X, weights=None):
226
232
Data for compute, where `n_samples` is the number of samples and
227
233
`n_features` is the number of features.
228
234
229
- weights : array-like of shape (n_samples,)
235
+ y : Ignored
236
+ Not used, present for API consistency by convention.
237
+
238
+ sample_weight : array-like of shape (n_samples,), default=None
230
239
Weights for compute weighted statistics, where `n_samples` is the number of samples.
231
240
232
241
Returns
@@ -242,11 +251,11 @@ def partial_fit(self, X, weights=None):
242
251
"sklearn" : None ,
243
252
},
244
253
X ,
245
- weights ,
254
+ sample_weight ,
246
255
)
247
256
return self
248
257
249
- def fit (self , X , weights = None ):
258
+ def fit (self , X , y = None , sample_weight = None ):
250
259
"""Compute statistics with X, using minibatches of size batch_size.
251
260
252
261
Parameters
@@ -255,7 +264,10 @@ def fit(self, X, weights=None):
255
264
Data for compute, where `n_samples` is the number of samples and
256
265
`n_features` is the number of features.
257
266
258
- weights : array-like of shape (n_samples,)
267
+ y : Ignored
268
+ Not used, present for API consistency by convention.
269
+
270
+ sample_weight : array-like of shape (n_samples,), default=None
259
271
Weights for compute weighted statistics, where `n_samples` is the number of samples.
260
272
261
273
Returns
@@ -271,6 +283,6 @@ def fit(self, X, weights=None):
271
283
"sklearn" : None ,
272
284
},
273
285
X ,
274
- weights ,
286
+ sample_weight ,
275
287
)
276
288
return self
0 commit comments