
Commit 29871eb

Improve accuracy_score compatibility with sklearn. (#6406)
This adds support for the `sample_weight` and `normalize` arguments to `accuracy_score`. It also adds support for non-numeric dtypes (like strings), which are supported by sklearn in this API.

As with the recent changes to `r2_score`, this is done by moving the implementation to use `cupy` instead of calling a C++ function. `accuracy_score` is not a performance-critical algorithm, and the `cupy` implementation is both easier to maintain and should be good enough.

Also adds support for `sample_weight` in `ClassifierMixin.score`, fixing a sklearn compatibility bug that was affecting the 0cc layer.

Authors:
- Jim Crist-Harif (https://github.com/jcrist)

Approvers:
- Simon Adorf (https://github.com/csadorf)
- Bradley Dice (https://github.com/bdice)

URL: #6406
1 parent c6a03b7 commit 29871eb
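
For orientation, a quick usage sketch of the API after this change. The values below are illustrative and not taken from the PR; only the signature (keyword-only `sample_weight` and `normalize`, string label support, and the `cuml.metrics` import path) comes from the diff that follows.

import numpy as np
from cuml.metrics import accuracy_score  # new canonical import location

y_true = ["cat", "dog", "dog", "cat"]   # non-numeric labels are now accepted
y_pred = ["cat", "dog", "cat", "cat"]
w = np.array([1.0, 1.0, 0.5, 1.0])

accuracy_score(y_true, y_pred)                   # 0.75
accuracy_score(y_true, y_pred, normalize=False)  # 3.0, count of correct samples
accuracy_score(y_true, y_pred, sample_weight=w)  # 3.0 / 3.5 ~= 0.857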

File tree

11 files changed: +222 / -162 lines


docs/source/api.rst

+1 -2

@@ -246,8 +246,7 @@ Metrics (regression, classification, and distance)
 .. automodule:: cuml.metrics.regression
    :members:

-.. automodule:: cuml.metrics.accuracy
-   :members:
+.. autofunction:: cuml.metrics.accuracy_score

 .. autofunction:: cuml.metrics.confusion_matrix

python/cuml/cuml/__init__.py

+1 -3

@@ -62,9 +62,7 @@
 from cuml.linear_model.mbsgd_regressor import MBSGDRegressor

 from cuml.manifold.t_sne import TSNE
-from cuml.metrics.accuracy import accuracy_score
-from cuml.metrics.cluster.adjusted_rand_index import adjusted_rand_score
-from cuml.metrics.regression import r2_score
+from cuml.metrics import accuracy_score, r2_score, adjusted_rand_score
 from cuml.model_selection import train_test_split

 from cuml.naive_bayes.naive_bayes import MultinomialNB

python/cuml/cuml/experimental/hyperparams/HPO_demo.ipynb

+4 -4

@@ -87,7 +87,7 @@
 "\n",
 "from cuml.neighbors import KNeighborsClassifier\n",
 "from cuml.preprocessing.model_selection import train_test_split\n",
-"from cuml.metrics.accuracy import accuracy_score\n",
+"from cuml.metrics import accuracy_score\n",
 "\n",
 "import os\n",
 "from urllib.request import urlretrieve\n",
@@ -375,10 +375,10 @@
 " - y_hat: The predictions made by the model\n",
 " \"\"\"\n",
 " y = y.astype(\"float32\") # cuML RandomForest needs the y labels to be float32\n",
-" return accuracy_score(y, y_hat, convert_dtype=True)\n",
+" return accuracy_score(y, y_hat)\n",
 "\n",
 "accuracy_wrapper_scorer = make_scorer(accuracy_score_wrapper)\n",
-"cuml_accuracy_scorer = make_scorer(accuracy_score, convert_dtype=True)"
+"cuml_accuracy_scorer = make_scorer(accuracy_score)"
 ]
 },
 {
@@ -447,7 +447,7 @@
 " mode_str: User specifies what model it is to print the value\n",
 " \"\"\"\n",
 " y_pred = model.fit(X_train, y_train).predict(X_test)\n",
-" score = accuracy_score(y_pred, y_test.astype('float32'), convert_dtype=True)\n",
+" score = accuracy_score(y_pred, y_test.astype('float32'))\n",
 " \n",
 " print(\"{} model accuracy: {}\".format(mode_str, score))\n",
 " "

python/cuml/cuml/internals/mixins.py

+3 -8

@@ -236,20 +236,15 @@ class ClassifierMixin:
     )
     @api_base_return_any_skipall
     @enable_device_interop
-    def score(self, X, y, **kwargs):
+    def score(self, X, y, sample_weight=None, **kwargs):
         """
         Scoring function for classifier estimators based on mean accuracy.

         """
-        from cuml.metrics.accuracy import accuracy_score
-
-        if hasattr(self, "handle"):
-            handle = self.handle
-        else:
-            handle = None
+        from cuml.metrics import accuracy_score

         preds = self.predict(X, **kwargs)
-        return accuracy_score(y, preds, handle=handle)
+        return accuracy_score(y, preds, sample_weight=sample_weight)

     @staticmethod
     def _more_static_tags():
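
As a hedged illustration of the `ClassifierMixin.score` change above: the sketch below assumes a cuML classifier that uses the mixin (KNeighborsClassifier is one such estimator, as exercised in the tests further down) and uses synthetic data.

import numpy as np
from cuml.neighbors import KNeighborsClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4)).astype("float32")
y = (X[:, 0] > 0).astype("int32")
w = rng.uniform(0.5, 1.0, size=len(y))

clf = KNeighborsClassifier(n_neighbors=5).fit(X, y)
clf.score(X, y, sample_weight=w)  # sample_weight is now forwarded to accuracy_score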

python/cuml/cuml/metrics/CMakeLists.txt

-1

@@ -15,7 +15,6 @@


 set(cython_sources "")
-add_module_gpu_default("accuracy.pyx" ${accuracy_algo} ${metrics_algo})
 add_module_gpu_default("hinge_loss.pyx" ${hinge_loss_algo} ${metrics_algo})
 add_module_gpu_default("kl_divergence.pyx" ${kl_divergence_algo} ${metrics_algo})
 add_module_gpu_default("pairwise_distances.pyx" ${pairwise_distances_algo} ${metrics_algo})

python/cuml/cuml/metrics/__init__.py

+2 -3

@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019-2023, NVIDIA CORPORATION.
+# Copyright (c) 2019-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,11 +19,10 @@
 from cuml.metrics.regression import mean_squared_error
 from cuml.metrics.regression import mean_squared_log_error
 from cuml.metrics.regression import mean_absolute_error
-from cuml.metrics.accuracy import accuracy_score
 from cuml.metrics.cluster.adjusted_rand_index import adjusted_rand_score
 from cuml.metrics._ranking import roc_auc_score
 from cuml.metrics._ranking import precision_recall_curve
-from cuml.metrics._classification import log_loss
+from cuml.metrics._classification import log_loss, accuracy_score
 from cuml.metrics.cluster.homogeneity_score import (
     cython_homogeneity_score as homogeneity_score,
 )

python/cuml/cuml/metrics/_classification.py

+107 -2

@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,16 +13,121 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import warnings

-from cuml.internals.input_utils import input_to_cupy_array
 import cuml.internals
+from cuml.internals.input_utils import input_to_cupy_array
 from cuml.internals.safe_imports import cpu_only_import
 from cuml.internals.safe_imports import gpu_only_import

 cp = gpu_only_import("cupy")
+cudf = gpu_only_import("cudf")
 np = cpu_only_import("numpy")


+def _input_to_cupy_or_cudf_series(x, check_rows=None):
+    """Coerce the input to a 1D cupy array or cudf Series.
+
+    For classification problems we need to support the full range
+    of supported input dtypes. cupy cannot support string labels,
+    and cudf cannot support float16. To handle this, we prefer cudf
+    if the input is cudf, otherwise try to coerce to cupy, falling
+    back to cudf if the dtype isn't supported.
+    """
+    if isinstance(x, cudf.Series):
+        # Drop the index so comparisons don't try to align on index
+        out = x.reset_index(drop=True)
+        n_cols = 1
+    else:
+        try:
+            out, _, n_cols, _ = input_to_cupy_array(x)
+            out = out.squeeze()  # ensure 1D
+        except ValueError:
+            # Unsupported dtype, use cudf instead
+            # Drop the index so comparisons don't try to align on index
+            out = cudf.Series(x, nan_as_null=False, copy=False).reset_index(
+                drop=True
+            )
+            n_cols = 1
+
+    n_rows = len(out)
+
+    if n_cols > 1:
+        raise ValueError(f"Expected 1 column but got {n_cols} columns.")
+    if check_rows is not None and n_rows != check_rows:
+        raise ValueError(f"Expected {check_rows} rows but got {n_rows} rows.")
+
+    return out
+
+
+@cuml.internals.api_return_any()
+def accuracy_score(
+    y_true, y_pred, *, sample_weight=None, normalize=True, **kwargs
+):
+    """
+    Accuracy classification score.
+
+    Parameters
+    ----------
+    y_true : array-like of shape (n_samples,)
+        Ground truth (correct) labels.
+    y_pred : array-like of shape (n_samples,)
+        Predicted labels.
+    sample_weight : array-like of shape (n_samples,)
+        Sample weights.
+    normalize : bool
+        If ``False``, return the number of correctly classified samples.
+        Otherwise, return the fraction of correctly classified samples.
+
+    Returns
+    -------
+    score : float
+        The fraction of correctly classified samples, or the number of correctly
+        classified samples if ``normalize == False``.
+    """
+
+    if kwargs:
+        warnings.warn(
+            "`convert_dtype` and `handle` were deprecated from `accuracy_score` "
+            "in version 25.04 and will be removed in 25.06.",
+            FutureWarning,
+        )
+
+    y_true = _input_to_cupy_or_cudf_series(y_true)
+    y_pred = _input_to_cupy_or_cudf_series(y_pred, check_rows=len(y_true))
+
+    # Categorical dtypes in cudf currently don't coerce nicely on equality,
+    # we need to manually cast to cudf.Series and align dtypes.
+    # This whole code block can be removed once
+    # https://github.com/rapidsai/cudf/issues/18196 is resolved.
+    if y_true.dtype == "category":
+        if y_pred.dtype != y_true.dtype:
+            y_pred = cudf.Series(y_pred, copy=False, nan_as_null=False).astype(
+                y_true.dtype
+            )
+    elif y_pred.dtype == "category":
+        y_true = cudf.Series(y_true, copy=False, nan_as_null=False).astype(
+            y_pred.dtype
+        )
+
+    if sample_weight is not None:
+        sample_weight = input_to_cupy_array(
+            sample_weight,
+            check_dtype=[np.float32, np.float64, np.int32, np.int64],
+            check_cols=1,
+            check_rows=len(y_true),
+        ).array.squeeze()  # ensure 1D
+
+    correct = y_true == y_pred
+
+    if normalize:
+        return float(cp.average(correct, weights=sample_weight))
+    elif sample_weight is not None:
+        return float(cp.dot(correct, sample_weight))
+    else:
+        return float(cp.count_nonzero(correct))
+
+
 @cuml.internals.api_return_any()
 def log_loss(
     y_true, y_pred, eps=1e-15, normalize=True, sample_weight=None
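
To make the `normalize`/`sample_weight` branches at the end of the new implementation concrete, here is a small worked check of the same cupy expressions (requires a GPU with cupy available; the arrays are made up for illustration):

import cupy as cp

correct = cp.asarray([True, True, False, True])
w = cp.asarray([1.0, 1.0, 0.5, 1.0])

float(cp.average(correct, weights=w))  # normalize=True              -> 3.0 / 3.5 ~= 0.857
float(cp.dot(correct, w))              # normalize=False, weighted   -> 3.0
float(cp.count_nonzero(correct))       # normalize=False, unweighted -> 3.0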

python/cuml/cuml/metrics/accuracy.pyx

-83
This file was deleted.

python/cuml/cuml/tests/test_kneighbors_classifier.py

+11 -2

@@ -99,7 +99,8 @@ def test_neighborhood_predictions(
 @pytest.mark.parametrize("ncols", [50, 100])
 @pytest.mark.parametrize("n_neighbors", [2, 5, 10])
 @pytest.mark.parametrize("n_clusters", [2, 5, 10])
-def test_score(nrows, ncols, n_neighbors, n_clusters, datatype):
+@pytest.mark.parametrize("weighted", [False, True])
+def test_score(nrows, ncols, n_neighbors, n_clusters, datatype, weighted):

     X, y = make_blobs(
         n_samples=nrows,
@@ -112,10 +113,18 @@
     X = X.astype(np.float32)
     X_train, X_test, y_train, y_test = _build_train_test_data(X, y, datatype)

+    if weighted:
+        sample_weight = np.random.default_rng(42).uniform(
+            0.5, 1, size=len(X_test)
+        )
+    else:
+        sample_weight = None
+
     knn_cu = cuKNN(n_neighbors=n_neighbors)
     knn_cu.fit(X_train, y_train)

-    assert knn_cu.score(X_test, y_test) >= (1.0 - 0.004)
+    score = knn_cu.score(X_test, y_test, sample_weight=sample_weight)
+    assert score >= (1.0 - 0.004)


 @pytest.mark.parametrize("datatype", ["dataframe", "numpy"])

python/cuml/cuml/tests/test_meta_estimators.py

+2 -2

@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2021-2023, NVIDIA CORPORATION.
+# Copyright (c) 2021-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -35,7 +35,7 @@ def test_pipeline():
     pipe = Pipeline(steps=[("scaler", StandardScaler()), ("svc", SVC())])
     pipe.fit(X_train, y_train)
     score = pipe.score(X_test, y_test)
-    assert score > 0.8
+    assert score > 0.75


 def test_gridsearchCV():
