Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit e440303

Browse files
committed
Try #334:
2 parents 88d7ef9 + 0c0573b commit e440303

File tree

5 files changed

+50
-3
lines changed

5 files changed

+50
-3
lines changed

Project.toml

+4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
55
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
66
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
77
CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
8+
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
89
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
910
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1011
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -20,3 +21,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2021

2122
[targets]
2223
test = ["Test", "BenchmarkTools", "SpecialFunctions"]
24+
25+
[compat]
26+
julia = ">= 1.1"

REQUIRE

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
julia 1.0
1+
julia 1.1
22
CUDAdrv 1.0
33
LLVM 0.9.14
44
CUDAapi 0.4.0

src/CUDAnative.jl

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ include(joinpath("device", "cuda_intrinsics.jl"))
3131
include(joinpath("device", "runtime_intrinsics.jl"))
3232

3333
include("compiler.jl")
34+
include("context.jl")
3435
include("execution.jl")
3536
include("reflection.jl")
3637

src/context.jl

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
##
2+
# Implements contextual dispatch through Cassette.jl
3+
# Goals:
4+
# - Rewrite common CPU functions to appropriate GPU intrinsics
5+
#
6+
# TODO:
7+
# - error (erf, ...)
8+
# - pow
9+
# - min, max
10+
# - mod, rem
11+
# - gamma
12+
# - bessel
13+
# - distributions
14+
# - unsorted
15+
16+
using Cassette
17+
18+
Cassette.@context CUDACtx
19+
const cudactx = CUDACtx()
20+
21+
# libdevice.jl
22+
for f in (:cos, :cospi, :sin, :sinpi, :tan,
23+
:acos, :asin, :atan,
24+
:cosh, :sinh, :tanh,
25+
:acosh, :asinh, :atanh,
26+
:log, :log10, :log1p, :log2,
27+
:exp, :exp2, :exp10, :expm1, :ldexp,
28+
:isfinite, :isinf, :isnan,
29+
:signbit, :abs,
30+
:sqrt, :cbrt,
31+
:ceil, :floor,)
32+
@eval function Cassette.overdub(ctx::CUDActx, ::typeof(Base.$f), x::Union{Float32, Float64})
33+
@Base._inline_meta
34+
return CUDAnative.$f(x)
35+
end
36+
end
37+
38+
contextualize(f::F) where F = (args...) -> Cassette.overdub(cudactx, f, args...)
39+
40+

src/execution.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@ kernel to determine the launch configuration:
172172
GC.@preserve args begin
173173
kernel_args = cudaconvert.(args)
174174
kernel_tt = Tuple{Core.Typeof.(kernel_args)...}
175-
kernel = cufunction(f, kernel_tt; compilation_kwargs)
175+
kernel_f = contextualize(f)
176+
kernel = cufunction(kernel_f, kernel_tt; compilation_kwargs)
176177
kernel(kernel_args...; launch_kwargs)
177178
end
178179
"""
@@ -202,7 +203,8 @@ macro cuda(ex...)
202203
GC.@preserve $(vars...) begin
203204
local kernel_args = cudaconvert.(($(var_exprs...),))
204205
local kernel_tt = Tuple{Core.Typeof.(kernel_args)...}
205-
local kernel = cufunction($(esc(f)), kernel_tt; $(map(esc, compiler_kwargs)...))
206+
local kernel_f = contextualize($(esc(f)))
207+
local kernel = cufunction(kernel_f, kernel_tt; $(map(esc, compiler_kwargs)...))
206208
kernel(kernel_args...; $(map(esc, call_kwargs)...))
207209
end
208210
end)

0 commit comments

Comments
 (0)