Skip to content

Commit 2b07513

Browse files
authored
[FEAT]: Golang Bindings for pinned host memory (#519)
## Describe the changes This PR adds the capability to pin host memory in the Golang bindings, allowing quicker data transfers. Memory can be pinned once for multiple devices by passing the flag `cuda_runtime.CudaHostRegisterPortable` or `cuda_runtime.CudaHostAllocPortable`, depending on whether the memory is registered in place or allocated as pinned.
1 parent 7831f7b commit 2b07513

File tree

25 files changed

+868
-48
lines changed

25 files changed

+868
-48
lines changed

wrappers/golang/core/slice.go

+29
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,35 @@ func (h HostSlice[T]) AsUnsafePointer() unsafe.Pointer {
197197
return unsafe.Pointer(&h[0])
198198
}
199199

200+
// Registers host memory as pinned, allowing the GPU to read data from the host quicker and save GPU memory space.
201+
// Memory pinned using this function should be unpinned using [Unpin]
202+
func (h HostSlice[T]) Pin(flags cr.RegisterPinnedFlags) cr.CudaError {
203+
_, err := cr.RegisterPinned(h.AsUnsafePointer(), h.SizeOfElement()*h.Len(), flags)
204+
return err
205+
}
206+
207+
// Unregisters host memory as pinned
208+
func (h HostSlice[T]) Unpin() cr.CudaError {
209+
return cr.FreeRegisteredPinned(h.AsUnsafePointer())
210+
}
211+
212+
// Allocates new host memory as pinned and copies the HostSlice data to the newly allocated area
213+
// Memory pinned using this function should be unpinned using [FreePinned]
214+
func (h HostSlice[T]) AllocPinned(flags cr.AllocPinnedFlags) (HostSlice[T], cr.CudaError) {
215+
pinnedMemPointer, err := cr.AllocPinned(h.SizeOfElement()*h.Len(), flags)
216+
if err != cr.CudaSuccess {
217+
return nil, err
218+
}
219+
pinnedMem := unsafe.Slice((*T)(pinnedMemPointer), h.Len())
220+
copy(pinnedMem, h)
221+
return pinnedMem, cr.CudaSuccess
222+
}
223+
224+
// Unpins host memory that was pinned using [AllocPinned]
225+
func (h HostSlice[T]) FreePinned() cr.CudaError {
226+
return cr.FreeAllocPinned(h.AsUnsafePointer())
227+
}
228+
200229
func (h HostSlice[T]) CopyToDevice(dst *DeviceSlice, shouldAllocate bool) *DeviceSlice {
201230
size := h.Len() * h.SizeOfElement()
202231
if shouldAllocate {

wrappers/golang/core/slice_test.go

+24
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"unsafe"
77

88
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/core/internal"
9+
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/cuda_runtime"
910
"github.com/stretchr/testify/assert"
1011
)
1112

@@ -222,3 +223,26 @@ func TestSliceRanges(t *testing.T) {
222223
hostSliceRange.CopyFromDevice(&deviceSliceRange)
223224
assert.Equal(t, hostSlice[2:6], hostSliceRange)
224225
}
226+
227+
func TestHostSlicePinning(t *testing.T) {
228+
data := []int{1, 2, 3, 4, 5, 7, 8, 9}
229+
dataHostSlice := HostSliceFromElements(data)
230+
err := dataHostSlice.Pin(cuda_runtime.CudaHostRegisterDefault)
231+
assert.Equal(t, cuda_runtime.CudaSuccess, err)
232+
err = dataHostSlice.Pin(cuda_runtime.CudaHostRegisterDefault)
233+
assert.Equal(t, cuda_runtime.CudaErrorHostMemoryAlreadyRegistered, err)
234+
235+
err = dataHostSlice.Unpin()
236+
assert.Equal(t, cuda_runtime.CudaSuccess, err)
237+
err = dataHostSlice.Unpin()
238+
assert.Equal(t, cuda_runtime.CudaErrorHostMemoryNotRegistered, err)
239+
240+
pinnedMem, err := dataHostSlice.AllocPinned(cuda_runtime.CudaHostAllocDefault)
241+
assert.Equal(t, cuda_runtime.CudaSuccess, err)
242+
assert.ElementsMatch(t, dataHostSlice, pinnedMem)
243+
244+
err = pinnedMem.FreePinned()
245+
assert.Equal(t, cuda_runtime.CudaSuccess, err)
246+
err = pinnedMem.FreePinned()
247+
assert.Equal(t, cuda_runtime.CudaErrorInvalidValue, err)
248+
}

wrappers/golang/cuda_runtime/const.go

+157
Large diffs are not rendered by default.

wrappers/golang/cuda_runtime/device_context.go

+48-40
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,14 @@ func GetDeviceFromPointer(ptr unsafe.Pointer) int {
7474
return int(cCudaPointerAttributes.device)
7575
}
7676

77+
func GetDeviceAttribute(attr DeviceAttribute, device int) int {
78+
var res int
79+
cRes := (*C.int)(unsafe.Pointer(&res))
80+
cDevice := (C.int)(device)
81+
C.cudaDeviceGetAttribute(cRes, attr, cDevice)
82+
return res
83+
}
84+
7785
// RunOnDevice forces the provided function to run all GPU related calls within it
7886
// on the same host thread and therefore the same GPU device.
7987
//
@@ -84,46 +92,46 @@ func GetDeviceFromPointer(ptr unsafe.Pointer) int {
8492
//
8593
// As an example:
8694
//
87-
// cr.RunOnDevice(i, func(args ...any) {
88-
// defer wg.Done()
89-
// cfg := GetDefaultMSMConfig()
90-
// stream, _ := cr.CreateStream()
91-
// for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
92-
// size := 1 << power
93-
//
94-
// // This will always print "Inner goroutine device: 0"
95-
// // go func () {
96-
// // device, _ := cr.GetDevice()
97-
// // fmt.Println("Inner goroutine device: ", device)
98-
// // }()
99-
// // To force the above goroutine to same device as the wrapping function:
100-
// // RunOnDevice(i, func(arg ...any) {
101-
// // device, _ := cr.GetDevice()
102-
// // fmt.Println("Inner goroutine device: ", device)
103-
// // })
104-
//
105-
// scalars := GenerateScalars(size)
106-
// points := GenerateAffinePoints(size)
107-
//
108-
// var p Projective
109-
// var out core.DeviceSlice
110-
// _, e := out.MallocAsync(p.Size(), p.Size(), stream)
111-
// assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
112-
// cfg.Ctx.Stream = &stream
113-
// cfg.IsAsync = true
114-
//
115-
// e = Msm(scalars, points, &cfg, out)
116-
// assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
117-
//
118-
// outHost := make(core.HostSlice[Projective], 1)
119-
//
120-
// cr.SynchronizeStream(&stream)
121-
// outHost.CopyFromDevice(&out)
122-
// out.Free()
123-
// // Check with gnark-crypto
124-
// assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
125-
// }
126-
// }, i)
95+
// cr.RunOnDevice(i, func(args ...any) {
96+
// defer wg.Done()
97+
// cfg := GetDefaultMSMConfig()
98+
// stream, _ := cr.CreateStream()
99+
// for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
100+
// size := 1 << power
101+
102+
// // This will always print "Inner goroutine device: 0"
103+
// // go func () {
104+
// // device, _ := cr.GetDevice()
105+
// // fmt.Println("Inner goroutine device: ", device)
106+
// // }()
107+
// // To force the above goroutine to same device as the wrapping function:
108+
// // RunOnDevice(i, func(arg ...any) {
109+
// // device, _ := cr.GetDevice()
110+
// // fmt.Println("Inner goroutine device: ", device)
111+
// // })
112+
113+
// scalars := GenerateScalars(size)
114+
// points := GenerateAffinePoints(size)
115+
116+
// var p Projective
117+
// var out core.DeviceSlice
118+
// _, e := out.MallocAsync(p.Size(), p.Size(), stream)
119+
// assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
120+
// cfg.Ctx.Stream = &stream
121+
// cfg.IsAsync = true
122+
123+
// e = Msm(scalars, points, &cfg, out)
124+
// assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
125+
126+
// outHost := make(core.HostSlice[Projective], 1)
127+
128+
// cr.SynchronizeStream(&stream)
129+
// outHost.CopyFromDevice(&out)
130+
// out.Free()
131+
// // Check with gnark-crypto
132+
// assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
133+
// }
134+
// }, i)
127135
func RunOnDevice(deviceId int, funcToRun func(args ...any), args ...any) {
128136
go func(id int) {
129137
defer runtime.UnlockOSThread()

wrappers/golang/cuda_runtime/memory.go

+40
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package cuda_runtime
99
import "C"
1010

1111
import (
12+
// "runtime"
1213
"unsafe"
1314
)
1415

@@ -58,6 +59,45 @@ func FreeAsync(devicePtr unsafe.Pointer, stream Stream) CudaError {
5859
return err
5960
}
6061

62+
func AllocPinned(size int, flags AllocPinnedFlags) (unsafe.Pointer, CudaError) {
63+
cSize := (C.size_t)(size)
64+
var hostPtr unsafe.Pointer
65+
ret := C.cudaHostAlloc(&hostPtr, cSize, flags)
66+
err := (CudaError)(ret)
67+
if err != CudaSuccess {
68+
return nil, err
69+
}
70+
71+
return hostPtr, CudaSuccess
72+
}
73+
74+
func GetHostFlags(ptr unsafe.Pointer) (flag uint) {
75+
cFlag := (C.uint)(flag)
76+
C.cudaHostGetFlags(&cFlag, ptr)
77+
return
78+
}
79+
80+
func FreeAllocPinned(hostPtr unsafe.Pointer) CudaError {
81+
return (CudaError)(C.cudaFreeHost(hostPtr))
82+
}
83+
84+
func RegisterPinned(hostPtr unsafe.Pointer, size int, flags RegisterPinnedFlags) (unsafe.Pointer, CudaError) {
85+
cSize := (C.size_t)(size)
86+
// This is required since there are greater values of RegisterPinnedFlags which we do not support currently
87+
flags = flags & 3
88+
ret := C.cudaHostRegister(hostPtr, cSize, flags)
89+
err := (CudaError)(ret)
90+
if err != CudaSuccess {
91+
return nil, err
92+
}
93+
94+
return hostPtr, CudaSuccess
95+
}
96+
97+
func FreeRegisteredPinned(hostPtr unsafe.Pointer) CudaError {
98+
return (CudaError)(C.cudaHostUnregister(hostPtr))
99+
}
100+
61101
func CopyFromDevice(hostDst, deviceSrc unsafe.Pointer, size uint) (unsafe.Pointer, CudaError) {
62102
cCount := (C.size_t)(size)
63103
ret := C.cudaMemcpy(hostDst, deviceSrc, cCount, uint32(CudaMemcpyDeviceToHost))

wrappers/golang/cuda_runtime/memory_test.go

+24
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,27 @@ func TestCopyFromToHost(t *testing.T) {
4040
assert.Equal(t, CudaSuccess, err, "Couldn't copy to device due to %v", err)
4141
assert.Equal(t, someInts, someInts2, "Elements of host slices do not match. Copying from/to host failed")
4242
}
43+
44+
func TestRegisterUnregisterPinned(t *testing.T) {
45+
data := []int{1, 2, 3, 4, 5, 7, 8, 9}
46+
dataUnsafe := unsafe.Pointer(&data[0])
47+
_, err := RegisterPinned(dataUnsafe, int(unsafe.Sizeof(data[0])*9), CudaHostRegisterDefault)
48+
assert.Equal(t, CudaSuccess, err)
49+
_, err = RegisterPinned(dataUnsafe, int(unsafe.Sizeof(data[0])*9), CudaHostRegisterDefault)
50+
assert.Equal(t, CudaErrorHostMemoryAlreadyRegistered, err)
51+
52+
err = FreeRegisteredPinned(dataUnsafe)
53+
assert.Equal(t, CudaSuccess, err)
54+
err = FreeRegisteredPinned(dataUnsafe)
55+
assert.Equal(t, CudaErrorHostMemoryNotRegistered, err)
56+
}
57+
58+
func TestAllocFreePinned(t *testing.T) {
59+
pinnedMemPointer, err := AllocPinned(int(unsafe.Sizeof(1)*9), CudaHostAllocDefault)
60+
assert.Equal(t, CudaSuccess, err)
61+
62+
err = FreeAllocPinned(pinnedMemPointer)
63+
assert.Equal(t, CudaSuccess, err)
64+
err = FreeAllocPinned(pinnedMemPointer)
65+
assert.Equal(t, CudaErrorInvalidValue, err)
66+
}

wrappers/golang/curves/bls12377/tests/g2_msm_test.go

+52
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,58 @@ func TestMSMG2(t *testing.T) {
146146

147147
}
148148
}
149+
150+
func TestMSMG2PinnedHostMemory(t *testing.T) {
151+
cfg := g2.G2GetDefaultMSMConfig()
152+
for _, power := range []int{10} {
153+
size := 1 << power
154+
155+
scalars := icicleBls12_377.GenerateScalars(size)
156+
points := g2.G2GenerateAffinePoints(size)
157+
158+
pinnable := cr.GetDeviceAttribute(cr.CudaDevAttrHostRegisterSupported, 0)
159+
lockable := cr.GetDeviceAttribute(cr.CudaDevAttrPageableMemoryAccessUsesHostPageTables, 0)
160+
161+
pinnableAndLockable := pinnable == 1 && lockable == 0
162+
163+
var pinnedPoints core.HostSlice[g2.G2Affine]
164+
if pinnableAndLockable {
165+
points.Pin(cr.CudaHostRegisterDefault)
166+
pinnedPoints, _ = points.AllocPinned(cr.CudaHostAllocDefault)
167+
assert.Equal(t, points, pinnedPoints, "Allocating newly pinned memory resulted in bad points")
168+
}
169+
170+
var p g2.G2Projective
171+
var out core.DeviceSlice
172+
_, e := out.Malloc(p.Size(), p.Size())
173+
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
174+
outHost := make(core.HostSlice[g2.G2Projective], 1)
175+
176+
e = g2.G2Msm(scalars, points, &cfg, out)
177+
assert.Equal(t, e, cr.CudaSuccess, "Msm allocated pinned host mem failed")
178+
179+
outHost.CopyFromDevice(&out)
180+
// Check with gnark-crypto
181+
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
182+
183+
if pinnableAndLockable {
184+
e = g2.G2Msm(scalars, pinnedPoints, &cfg, out)
185+
assert.Equal(t, e, cr.CudaSuccess, "Msm registered pinned host mem failed")
186+
187+
outHost.CopyFromDevice(&out)
188+
// Check with gnark-crypto
189+
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, pinnedPoints, outHost[0]))
190+
191+
}
192+
193+
out.Free()
194+
195+
if pinnableAndLockable {
196+
points.Unpin()
197+
pinnedPoints.FreePinned()
198+
}
199+
}
200+
}
149201
func TestMSMG2GnarkCryptoTypes(t *testing.T) {
150202
cfg := g2.G2GetDefaultMSMConfig()
151203
for _, power := range []int{3} {

wrappers/golang/curves/bls12377/tests/msm_test.go

+52
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,58 @@ func TestMSM(t *testing.T) {
106106

107107
}
108108
}
109+
110+
func TestMSMPinnedHostMemory(t *testing.T) {
111+
cfg := msm.GetDefaultMSMConfig()
112+
for _, power := range []int{10} {
113+
size := 1 << power
114+
115+
scalars := icicleBls12_377.GenerateScalars(size)
116+
points := icicleBls12_377.GenerateAffinePoints(size)
117+
118+
pinnable := cr.GetDeviceAttribute(cr.CudaDevAttrHostRegisterSupported, 0)
119+
lockable := cr.GetDeviceAttribute(cr.CudaDevAttrPageableMemoryAccessUsesHostPageTables, 0)
120+
121+
pinnableAndLockable := pinnable == 1 && lockable == 0
122+
123+
var pinnedPoints core.HostSlice[icicleBls12_377.Affine]
124+
if pinnableAndLockable {
125+
points.Pin(cr.CudaHostRegisterDefault)
126+
pinnedPoints, _ = points.AllocPinned(cr.CudaHostAllocDefault)
127+
assert.Equal(t, points, pinnedPoints, "Allocating newly pinned memory resulted in bad points")
128+
}
129+
130+
var p icicleBls12_377.Projective
131+
var out core.DeviceSlice
132+
_, e := out.Malloc(p.Size(), p.Size())
133+
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
134+
outHost := make(core.HostSlice[icicleBls12_377.Projective], 1)
135+
136+
e = msm.Msm(scalars, points, &cfg, out)
137+
assert.Equal(t, e, cr.CudaSuccess, "Msm allocated pinned host mem failed")
138+
139+
outHost.CopyFromDevice(&out)
140+
// Check with gnark-crypto
141+
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
142+
143+
if pinnableAndLockable {
144+
e = msm.Msm(scalars, pinnedPoints, &cfg, out)
145+
assert.Equal(t, e, cr.CudaSuccess, "Msm registered pinned host mem failed")
146+
147+
outHost.CopyFromDevice(&out)
148+
// Check with gnark-crypto
149+
assert.True(t, testAgainstGnarkCryptoMsm(scalars, pinnedPoints, outHost[0]))
150+
151+
}
152+
153+
out.Free()
154+
155+
if pinnableAndLockable {
156+
points.Unpin()
157+
pinnedPoints.FreePinned()
158+
}
159+
}
160+
}
109161
func TestMSMGnarkCryptoTypes(t *testing.T) {
110162
cfg := msm.GetDefaultMSMConfig()
111163
for _, power := range []int{3} {

wrappers/golang/curves/bls12377/tests/ntt_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,12 @@ func TestNttDeviceAsync(t *testing.T) {
151151

152152
func TestNttBatch(t *testing.T) {
153153
cfg := ntt.GetDefaultNttConfig()
154+
largestTestSize := 12
154155
largestBatchSize := 100
155156
scalars := bls12_377.GenerateScalars(1 << largestTestSize * largestBatchSize)
156157

157158
for _, size := range []int{4, largestTestSize} {
158-
for _, batchSize := range []int{1, 16, largestBatchSize} {
159+
for _, batchSize := range []int{2, 16, largestBatchSize} {
159160
testSize := 1 << size
160161
totalSize := testSize * batchSize
161162

0 commit comments

Comments
 (0)