Skip to content

Commit 9c17b09

Browse files
authored
Fix test_accuracy_score test on cudf.pandas build (#6439)
This fixes `test_accuracy_score` to still work when `cudf.pandas` is active. The failure had gone unnoticed since `cudf.pandas` builds are optional currently and have been flakey long enough that I've stopped inspecting them when they're red :/. More motivation to fix our test issues and make that test run non-optional. Authors: - Jim Crist-Harif (https://github.com/jcrist) Approvers: - Jake Awe (https://github.com/AyodeAwe) - Dante Gama Dessavre (https://github.com/dantegd) URL: #6439
1 parent 67e0cc0 commit 9c17b09

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

ci/accel/scikit-learn-tests/run-tests.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111

1212
set -eu
1313

14-
pytest -p cuml.accel --pyargs sklearn -v $@
14+
pytest -p cuml.accel --pyargs sklearn -v "$@"

python/cuml/cuml/tests/test_metrics.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,12 @@ def test_sklearn_search():
185185
("cupy", "numpy", "int64", "int32"),
186186
("numpy", "cupy", "float16", "float32"),
187187
("cudf", "cudf", "int32", "int64"),
188-
("numpy", "cudf", "object", "object"),
189-
("numpy", "cudf", "object", "category"),
190-
("cudf", "numpy", "category", "object"),
191-
("cudf", "cudf", "object", "object"),
192-
("cudf", "cudf", "object", "category"),
193-
("cudf", "cudf", "category", "object"),
188+
("numpy", "cudf", "str", "str"),
189+
("numpy", "cudf", "str", "category"),
190+
("cudf", "numpy", "category", "str"),
191+
("cudf", "cudf", "str", "str"),
192+
("cudf", "cudf", "str", "category"),
193+
("cudf", "cudf", "category", "str"),
194194
("cudf", "cudf", "category", "category"),
195195
],
196196
)
@@ -213,8 +213,8 @@ def test_accuracy_score(
213213
np_true = rng.randint(0, 3, N)
214214
np_pred = (rng.randint(0, 2, N) + np_true) % 3
215215
np_weight = rng.random(N).astype(weight_dtype) if weight_kind else None
216-
if true_dtype in ("object", "category"):
217-
assert pred_dtype in ("object", "category")
216+
if true_dtype in ("str", "category"):
217+
assert pred_dtype in ("str", "category")
218218
labels = np.array(["a", "b", "c"], dtype="object")
219219
np_true = labels.take(np_true)
220220
np_pred = labels.take(np_pred)

0 commit comments

Comments
 (0)