Skip to content

Commit 67e0cc0

Browse files
authored
Check if both ground truth and result are NaN in MSEObjectiveTest for RF (#6387)
If both results are NaNs, pass the test rather than attempting to `ASSERT_NEAR` on NaN values. Authors: - William Hicks (https://github.com/wphicks) - Jim Crist-Harif (https://github.com/jcrist) - Simon Adorf (https://github.com/csadorf) Approvers: - Simon Adorf (https://github.com/csadorf) - Dante Gama Dessavre (https://github.com/dantegd) URL: #6387
1 parent cc68aee commit 67e0cc0

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

cpp/tests/sg/rf_test.cu

+6-1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include <gtest/gtest.h>
4545
#include <test_utils.h>
4646

47+
#include <cmath>
4748
#include <cstddef>
4849
#include <memory>
4950
#include <random>
@@ -1208,7 +1209,11 @@ class ObjectiveTest : public ::testing::TestWithParam<ObjectiveTestParameters> {
12081209
NumLeftOfBin(cdf_hist, params.max_n_bins - 1),
12091210
NumLeftOfBin(cdf_hist, split_bin_index));
12101211

1211-
ASSERT_NEAR(ground_truth_gain, hypothesis_gain, params.tolerance);
1212+
// The gain may actually be NaN. If so, a comparison between the result and
1213+
// ground truth would yield false, even if they are both (correctly) NaNs.
1214+
if (!std::isnan(ground_truth_gain) || !std::isnan(hypothesis_gain)) {
1215+
ASSERT_NEAR(ground_truth_gain, hypothesis_gain, params.tolerance);
1216+
}
12121217
}
12131218
};
12141219

0 commit comments

Comments
 (0)