Skip to content

Commit b3d35ec

Browse files
committed
Add HIP test host-min-max
1 parent e597556 commit b3d35ec

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

Diff for: External/HIP/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ macro(create_local_hip_tests VariantSuffix)
1212
list(APPEND HIP_LOCAL_TESTS empty)
1313
list(APPEND HIP_LOCAL_TESTS with-fopenmp)
1414
list(APPEND HIP_LOCAL_TESTS saxpy)
15+
list(APPEND HIP_LOCAL_TESTS host-min-max)
1516
list(APPEND HIP_LOCAL_TESTS InOneWeekend)
1617
list(APPEND HIP_LOCAL_TESTS TheNextWeek)
1718

Diff for: External/HIP/host-min-max.hip

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#include <hip/hip_runtime.h>
2+
#include <iostream>
3+
#include <algorithm>
4+
#include <type_traits>
5+
#include <cstdlib> // For std::abort
6+
#include <typeinfo> // For typeid
7+
8+
std::string demangle(const char* name) {
9+
if (std::string(name) == "i") return "int";
10+
else if (std::string(name) == "f") return "float";
11+
else if (std::string(name) == "d") return "double";
12+
else if (std::string(name) == "j") return "unsigned int";
13+
else if (std::string(name) == "l") return "long";
14+
else if (std::string(name) == "m") return "unsigned long";
15+
else if (std::string(name) == "x") return "long long";
16+
else if (std::string(name) == "y") return "unsigned long long";
17+
else return std::string(name);
18+
}
19+
20+
void checkHipCall(hipError_t status, const char* msg) {
21+
if (status != hipSuccess) {
22+
std::cerr << "HIP Error: " << msg << " - " << hipGetErrorString(status) << std::endl;
23+
std::abort();
24+
}
25+
}
26+
27+
template<typename T1, typename T2>
28+
void compareResults(T1 hipResult, T2 stdResult, const std::string& testName) {
29+
using CommonType = typename std::common_type<T1, T2>::type;
30+
if (static_cast<CommonType>(hipResult) != static_cast<CommonType>(stdResult)) {
31+
std::cerr << testName << " mismatch: HIP result " << hipResult << " (" << demangle(typeid(hipResult).name()) << "), std result " << stdResult << " (" << demangle(typeid(stdResult).name()) << ")" << std::endl;
32+
std::abort();
33+
}
34+
}
35+
36+
template<typename T1, typename T2>
37+
void runTest(T1 a, T2 b) {
38+
std::cout << "\nTesting with values: " << a << " (" << demangle(typeid(a).name()) << ") and " << b << " (" << demangle(typeid(b).name()) << ")" << std::endl;
39+
40+
// Using std::min and std::max explicitly for host code to ensure clarity and correctness
41+
using CommonType = typename std::common_type<T1, T2>::type;
42+
CommonType stdMinResult = std::min<CommonType>(a, b);
43+
CommonType stdMaxResult = std::max<CommonType>(a, b);
44+
std::cout << "Host std::min result: " << stdMinResult << " (Type: " << demangle(typeid(stdMinResult).name()) << ")" << std::endl;
45+
std::cout << "Host std::max result: " << stdMaxResult << " (Type: " << demangle(typeid(stdMaxResult).name()) << ")" << std::endl;
46+
47+
// Using HIP's global min/max functions
48+
CommonType hipMinResult = min(a, b); // Note: This directly uses HIP's min, assuming it's correctly overloaded for host code
49+
CommonType hipMaxResult = max(a, b); // Note: This directly uses HIP's max, assuming it's correctly overloaded for host code
50+
std::cout << "Host HIP min result: " << hipMinResult << " (Type: " << demangle(typeid(hipMinResult).name()) << ")" << std::endl;
51+
std::cout << "Host HIP max result: " << hipMaxResult << " (Type: " << demangle(typeid(hipMaxResult).name()) << ")" << std::endl;
52+
53+
// Ensure the host HIP and std results match
54+
compareResults(hipMinResult, stdMinResult, "HIP vs std min");
55+
compareResults(hipMaxResult, stdMaxResult, "HIP vs std max");
56+
}
57+
58+
int main() {
59+
checkHipCall(hipSetDevice(0), "hipSetDevice failed");
60+
61+
runTest(10uLL, -5LL); // Testing with unsigned int and long long
62+
runTest(-15, 20u); // Testing with int and unsigned int
63+
runTest(2147483647, 2147483648u); // Testing with int and unsigned int
64+
runTest(-922337203685477580LL, 922337203685477580uLL); // Testing with long long and unsigned long long
65+
runTest(2.5f, 3.14159); // Testing with float and double
66+
67+
std::cout << "\nPass\n"; // Output "Pass" at the end if all tests pass without aborting
68+
return 0;
69+
}

0 commit comments

Comments
 (0)