-
Notifications
You must be signed in to change notification settings - Fork 4.3k
/
Copy pathcustom_cuda_kernel.dp.cpp
93 lines (82 loc) · 3.46 KB
/
custom_cuda_kernel.dp.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <sycl/sycl.hpp>
// Verifies that `dev` supports every aspect listed in `props`.
//
// Aspects are checked in list order. The first aspect the device does not
// report via `dev.has()` causes a std::runtime_error to be thrown whose
// message names the missing feature and the device name; aspects after it
// are not examined. If the device has all listed aspects the function
// returns normally.
//
// @param dev    Device to query.
// @param props  Aspects that must be supported.
// @throws std::runtime_error on the first unsupported aspect.
inline void has_capability_or_fail(const sycl::device& dev,
const std::initializer_list<sycl::aspect>& props)
{
for (const auto& it : props) {
// Supported aspect: move on to the next one.
if (dev.has(it)) continue;
// Unsupported aspect: every branch of this switch throws.
switch (it) {
case sycl::aspect::fp64:
// fp64 is reported using the C++ type name 'double'.
throw std::runtime_error("'double' is not supported in '" +
dev.get_info<sycl::info::device::name>() + "' device");
break;  // unreachable after the throw
case sycl::aspect::fp16:
// fp16 is reported using the type name 'half'.
throw std::runtime_error("'half' is not supported in '" +
dev.get_info<sycl::info::device::name>() + "' device");
break;  // unreachable after the throw
default:
// The macros below turn each entry of the included aspect .def lists into
// a `case sycl::aspect::<name>: return "<name>";` label, so the lambda's
// switch maps an aspect value to its identifier spelled as a string.
// Deprecated aspects expand the same way; deprecated aliases expand to
// nothing. NOTE(review): presumably aliases are skipped because they share
// a value with another aspect and would yield duplicate case labels; confirm
// against the .def files.
#define __SYCL_ASPECT(ASPECT, ID) \
case sycl::aspect::ASPECT: \
return #ASPECT;
#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID)
#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE)
auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string {
switch (AspectNum) {
#include <sycl/info/aspects.def>
#include <sycl/info/aspects_deprecated.def>
// Value not covered by any case generated from the included lists.
default: return "unknown aspect";
}
};
// Remove the helper macros so they do not leak past this function.
#undef __SYCL_ASPECT_DEPRECATED_ALIAS
#undef __SYCL_ASPECT_DEPRECATED
#undef __SYCL_ASPECT
throw std::runtime_error("'" + getAspectNameStr(it) + "' is not supported in '" +
dev.get_info<sycl::info::device::name>() + "' device");
}
// Unreachable: every branch of the switch above throws.
break;
}
}
void param_update_kernel(const float* input, sycl::half* output, int size)
{
auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>();
int id = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
if (id < size) { output[id] = (sycl::half)input[id]; }
}
void launch_param_update(const float* input, sycl::half* output, int size, sycl::queue* stream)
{
int threads = 1024;
sycl::range<3> grid_dim(1, 1, (size - 1) / threads + 1);
sycl::range<3> block_dim(1, 1, threads);
{
has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(grid_dim * block_dim, block_dim),
[=](sycl::nd_item<3> item_ct1) { param_update_kernel(input, output, size); });
}
}
// Per-work-item body: copies one float-sized element of `input` into
// `output`, treating the float's storage as a sycl::half2 and `output` as an
// array of sycl::half2. No numeric conversion is performed; the value is
// reinterpreted, not converted.
//
// The work-item is obtained through the free-function query
// this_nd_item<3>(); only dimension 2 contributes to the element index.
// Work-items whose index is not below `size` do nothing.
//
// @param input   Source elements, read one float at a time.
// @param output  Destination buffer, written one sycl::half2 at a time.
// @param size    Number of float / half2 elements to copy.
void param_update_kernel_half(const float* input, sycl::half* output, int size)
{
    auto work_item = sycl::ext::oneapi::experimental::this_nd_item<3>();
    const auto group_offset = work_item.get_group(2) * work_item.get_local_range(2);
    const int idx = group_offset + work_item.get_local_id(2);

    // Work-items in the padded tail of the last group have nothing to do.
    if (idx >= size) { return; }

    // Local copy so its storage can be viewed as a half2.
    float packed = input[idx];
    auto* packed_out = reinterpret_cast<sycl::half2*>(output);
    packed_out[idx] = *reinterpret_cast<sycl::half2*>(&packed);
}
void launch_param_update_half(const float* input, sycl::half* output, int size, sycl::queue* stream)
{
int threads = 1024;
size /= 2;
sycl::range<3> grid_dim(1, 1, (size - 1) / threads + 1);
sycl::range<3> block_dim(1, 1, threads);
{
has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(grid_dim * block_dim, block_dim),
[=](sycl::nd_item<3> item_ct1) { param_update_kernel_half(input, output, size); });
}
}