Skip to content

Commit 29fe5de

Browse files
authored
Add capability to set a default device for all threads (#699)
1 parent 0c33c74 commit 29fe5de

File tree

17 files changed

+245
-55
lines changed

17 files changed

+245
-55
lines changed

docs/docs/icicle/programmers_guide/cpp.md

+15
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,21 @@ eIcicleError result = icicle_set_device(device);
3232
eIcicleError result = icicle_get_active_device(device);
3333
```
3434
35+
### Setting and Getting the Default Device
36+
37+
You can set the default device for all threads:
38+
39+
```cpp
40+
icicle::Device device = {"CUDA", 0}; // or other
41+
eIcicleError result = icicle_set_default_device(device);
42+
```
43+
44+
:::caution
45+
46+
Setting a default device should be done **once** from the main thread of the application. If another device or backend is needed for a specific thread [icicle_set_device](#setting-and-getting-active-device) should be used instead.
47+
48+
:::
49+
3550
### Querying Device Information
3651

3752
Retrieve the number of available devices and check if a pointer is allocated on the host or on the active device:

docs/docs/icicle/programmers_guide/general.md

+1
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ ICICLE provides a device abstraction layer that allows you to interact with diff
8585

8686
- **Loading Backends**: Backends are loaded dynamically based on the environment configuration or a specified path.
8787
- **Setting Active Device**: The active device for a thread can be set, allowing for targeted computation on a specific device.
88+
- **Setting Default Device**: The default device for any thread without an active device can be set, removing the need to specify an alternative device on each thread. This is especially useful when running on a backend that is not the built-in CPU backend which is the default device to start.
8889

8990
## Streams
9091

docs/docs/icicle/programmers_guide/go.md

+16-1
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,27 @@ result := runtime.LoadBackend("/path/to/backend/installdir", true)
2727
You can set the active device for the current thread and retrieve it when needed:
2828

2929
```go
30-
device = runtime.CreateDevice("CUDA", 0) // or other
30+
device := runtime.CreateDevice("CUDA", 0) // or other
3131
result := runtime.SetDevice(device)
3232
// or query current (thread) device
3333
activeDevice := runtime.GetActiveDevice()
3434
```
3535

36+
### Setting and Getting the Default Device
37+
38+
You can set the default device for all threads:
39+
40+
```go
41+
device := runtime.CreateDevice("CUDA", 0) // or other
42+
defaultDevice := runtime.SetDefaultDevice(device);
43+
```
44+
45+
:::caution
46+
47+
Setting a default device should be done **once** from the main thread of the application. If another device or backend is needed for a specific thread [runtime.SetDevice](#setting-and-getting-active-device) should be used instead.
48+
49+
:::
50+
3651
### Querying Device Information
3752

3853
Retrieve the number of available devices and check if a pointer is allocated on the host or on the active device:

docs/docs/icicle/programmers_guide/rust.md

+15
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,21 @@ icicle_runtime::set_device(&device).unwrap();
5555
let active_device = icicle_runtime::get_active_device().unwrap();
5656
```
5757
58+
### Setting and Getting the Default Device
59+
60+
You can set the default device for all threads:
61+
62+
```caution
63+
let device = Device::new("CUDA", 0); // or other
64+
let default_device = icicle_runtime::set_default_device(device);
65+
```
66+
67+
:::note
68+
69+
Setting a default device should be done **once** from the main thread of the application. If another device or backend is needed for a specific thread [icicle_runtime::set_device](#setting-and-getting-active-device) should be used instead.
70+
71+
:::
72+
5873
### Querying Device Information
5974
6075
Retrieve the number of available devices and check if a pointer is allocated on the host or on the active device:

icicle/include/icicle/device_api.h

+1
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ namespace icicle {
188188

189189
public:
190190
static eIcicleError set_thread_local_device(const Device& device);
191+
static eIcicleError set_default_device(const Device& device);
191192
static const Device& get_thread_local_device();
192193
static const DeviceAPI* get_thread_local_deviceAPI();
193194
static DeviceTracker& get_global_memory_tracker() { return sMemTracker; }

icicle/include/icicle/runtime.h

+8
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@ extern "C" eIcicleError icicle_load_backend_from_env_or_default();
3636
*/
3737
extern "C" eIcicleError icicle_set_device(const icicle::Device& device);
3838

39+
/**
40+
* @brief Set default device for all threads
41+
*
42+
43+
* @return eIcicleError::SUCCESS if successful, otherwise throws INVALID_DEVICE
44+
*/
45+
extern "C" eIcicleError icicle_set_default_device(const icicle::Device& device);
46+
3947
/**
4048
* @brief Get active device for thread
4149
*

icicle/src/device_api.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,17 @@ namespace icicle {
5858

5959
const Device& get_default_device() { return m_default_device; }
6060

61+
eIcicleError set_default_device(const Device& dev)
62+
{
63+
if (!is_device_registered(dev.type)) {
64+
ICICLE_LOG_ERROR << "Device type " + std::string(dev.type) + " is not valid as it has not been registered";
65+
return eIcicleError::INVALID_DEVICE;
66+
}
67+
68+
m_default_device = dev;
69+
return eIcicleError::SUCCESS;
70+
}
71+
6172
std::vector<std::string> get_registered_devices_list()
6273
{
6374
std::vector<std::string> registered_devices;
@@ -116,6 +127,11 @@ namespace icicle {
116127
return default_deviceAPI.get();
117128
}
118129

130+
eIcicleError DeviceAPI::set_default_device(const Device& dev)
131+
{
132+
return DeviceAPIRegistry::Global().set_default_device(dev);
133+
}
134+
119135
/********************************************************************************** */
120136

121137
DeviceAPI* get_deviceAPI(const Device& device) { return DeviceAPIRegistry::Global().get_deviceAPI(device).get(); }

icicle/src/runtime.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ using namespace icicle;
1414

1515
extern "C" eIcicleError icicle_set_device(const Device& device) { return DeviceAPI::set_thread_local_device(device); }
1616

17+
extern "C" eIcicleError icicle_set_default_device(const Device& device)
18+
{
19+
return DeviceAPI::set_default_device(device);
20+
}
21+
1722
extern "C" eIcicleError icicle_get_active_device(icicle::Device& device)
1823
{
1924
const Device& active_device = DeviceAPI::get_thread_local_device();

icicle/tests/test_device_api.cpp

+31
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
#include <gtest/gtest.h>
3+
#include <thread>
34
#include <iostream>
45

56
#include "icicle/runtime.h"
@@ -19,6 +20,36 @@ TEST_F(DeviceApiTest, UnregisteredDeviceError)
1920
EXPECT_ANY_THROW(get_deviceAPI(dev));
2021
}
2122

23+
TEST_F(DeviceApiTest, SetDefaultDevice)
24+
{
25+
icicle::Device active_dev = {UNKOWN_DEVICE, -1};
26+
27+
icicle::Device cpu_dev = {s_ref_device, 0};
28+
EXPECT_NO_THROW(icicle_set_device(cpu_dev));
29+
EXPECT_NO_THROW(icicle_get_active_device(active_dev));
30+
31+
ASSERT_EQ(cpu_dev, active_dev);
32+
33+
active_dev = {UNKOWN_DEVICE, -1};
34+
35+
icicle::Device gpu_dev = {s_main_device, 0};
36+
EXPECT_NO_THROW(icicle_set_default_device(gpu_dev));
37+
38+
// setting a new default device doesn't override already set local thread devices
39+
EXPECT_NO_THROW(icicle_get_active_device(active_dev));
40+
ASSERT_EQ(cpu_dev, active_dev);
41+
42+
active_dev = {UNKOWN_DEVICE, -1};
43+
auto thread_func = [&active_dev, &gpu_dev]() {
44+
EXPECT_NO_THROW(icicle_get_active_device(active_dev));
45+
ASSERT_EQ(gpu_dev, active_dev);
46+
};
47+
48+
std::thread worker_thread(thread_func);
49+
50+
worker_thread.join();
51+
}
52+
2253
TEST_F(DeviceApiTest, MemoryCopySync)
2354
{
2455
int input[2] = {1, 2};

wrappers/golang/runtime/device.go

+6
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ func SetDevice(device *Device) EIcicleError {
5050
return EIcicleError(cErr)
5151
}
5252

53+
func SetDefaultDevice(device *Device) EIcicleError {
54+
cDevice := (*C.Device)(unsafe.Pointer(device))
55+
cErr := C.icicle_set_default_device(cDevice)
56+
return EIcicleError(cErr)
57+
}
58+
5359
func GetActiveDevice() (*Device, EIcicleError) {
5460
device := CreateDevice("invalid", -1)
5561
cDevice := (*C.Device)(unsafe.Pointer(&device))

wrappers/golang/runtime/include/runtime.h

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ typedef struct DeviceProperties DeviceProperties;
1313
int icicle_load_backend(const char* path, bool is_recursive);
1414
int icicle_load_backend_from_env_or_default();
1515
int icicle_set_device(const Device* device);
16+
int icicle_set_default_device(const Device* device);
1617
int icicle_get_active_device(Device* device);
1718
int icicle_is_host_memory(const void* ptr);
1819
int icicle_is_active_device_memory(const void* ptr);
+85-33
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,122 @@
11
package tests
22

33
import (
4+
"fmt"
45
"os/exec"
6+
"runtime"
7+
"strconv"
8+
"strings"
9+
"syscall"
510
"testing"
611

7-
"github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime"
12+
icicle_runtime "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime"
813

914
"github.com/stretchr/testify/assert"
1015
)
1116

1217
func TestGetDeviceType(t *testing.T) {
1318
expectedDeviceName := "test"
14-
config := runtime.CreateDevice(expectedDeviceName, 0)
19+
config := icicle_runtime.CreateDevice(expectedDeviceName, 0)
1520
assert.Equal(t, expectedDeviceName, config.GetDeviceType())
1621

1722
expectedDeviceNameLong := "testtesttesttesttesttesttesttesttesttesttesttesttesttesttesttest"
18-
configLargeName := runtime.CreateDevice(expectedDeviceNameLong, 1)
23+
configLargeName := icicle_runtime.CreateDevice(expectedDeviceNameLong, 1)
1924
assert.NotEqual(t, expectedDeviceNameLong, configLargeName.GetDeviceType())
2025
}
2126

2227
func TestIsDeviceAvailable(t *testing.T) {
23-
runtime.LoadBackendFromEnvOrDefault()
24-
dev := runtime.CreateDevice("CUDA", 0)
25-
_ = runtime.SetDevice(&dev)
26-
res, err := runtime.GetDeviceCount()
27-
28-
expectedNumDevices, error := exec.Command("nvidia-smi", "-L", "|", "wc", "-l").Output()
29-
if error != nil {
30-
t.Skip("Failed to get number of devices")
28+
dev := icicle_runtime.CreateDevice("CUDA", 0)
29+
_ = icicle_runtime.SetDevice(&dev)
30+
res, err := icicle_runtime.GetDeviceCount()
31+
32+
smiCommand := exec.Command("nvidia-smi", "-L")
33+
smiCommandStdout, _ := smiCommand.StdoutPipe()
34+
wcCommand := exec.Command("wc", "-l")
35+
wcCommand.Stdin = smiCommandStdout
36+
37+
smiCommand.Start()
38+
39+
expectedNumDevicesRaw, wcErr := wcCommand.Output()
40+
smiCommand.Wait()
41+
42+
expectedNumDevicesAsString := strings.TrimRight(string(expectedNumDevicesRaw), " \n\r\t")
43+
expectedNumDevices, _ := strconv.Atoi(expectedNumDevicesAsString)
44+
if wcErr != nil {
45+
t.Skip("Failed to get number of devices:", wcErr)
3146
}
3247

33-
assert.Equal(t, runtime.Success, err)
48+
assert.Equal(t, icicle_runtime.Success, err)
3449
assert.Equal(t, expectedNumDevices, res)
3550

36-
err = runtime.LoadBackendFromEnvOrDefault()
37-
assert.Equal(t, runtime.Success, err)
38-
devCuda := runtime.CreateDevice("CUDA", 0)
39-
assert.True(t, runtime.IsDeviceAvailable(&devCuda))
40-
devCpu := runtime.CreateDevice("CPU", 0)
41-
assert.True(t, runtime.IsDeviceAvailable(&devCpu))
42-
devInvalid := runtime.CreateDevice("invalid", 0)
43-
assert.False(t, runtime.IsDeviceAvailable(&devInvalid))
51+
assert.Equal(t, icicle_runtime.Success, err)
52+
devCuda := icicle_runtime.CreateDevice("CUDA", 0)
53+
assert.True(t, icicle_runtime.IsDeviceAvailable(&devCuda))
54+
devCpu := icicle_runtime.CreateDevice("CPU", 0)
55+
assert.True(t, icicle_runtime.IsDeviceAvailable(&devCpu))
56+
devInvalid := icicle_runtime.CreateDevice("invalid", 0)
57+
assert.False(t, icicle_runtime.IsDeviceAvailable(&devInvalid))
58+
}
59+
60+
func TestSetDefaultDevice(t *testing.T) {
61+
runtime.LockOSThread()
62+
defer runtime.UnlockOSThread()
63+
tidOuter := syscall.Gettid()
64+
65+
gpuDevice := icicle_runtime.CreateDevice("CUDA", 0)
66+
icicle_runtime.SetDefaultDevice(&gpuDevice)
67+
68+
activeDevice, err := icicle_runtime.GetActiveDevice()
69+
assert.Equal(t, icicle_runtime.Success, err)
70+
assert.Equal(t, gpuDevice, *activeDevice)
71+
72+
done := make(chan struct{}, 1)
73+
go func() {
74+
runtime.LockOSThread()
75+
defer runtime.UnlockOSThread()
76+
77+
// Ensure we are operating on an OS thread other than the original one
78+
tidInner := syscall.Gettid()
79+
for tidInner == tidOuter {
80+
fmt.Println("Locked thread is the same as original, getting new locked thread")
81+
runtime.UnlockOSThread()
82+
runtime.LockOSThread()
83+
tidInner = syscall.Gettid()
84+
}
85+
86+
activeDevice, err := icicle_runtime.GetActiveDevice()
87+
assert.Equal(t, icicle_runtime.Success, err)
88+
assert.Equal(t, gpuDevice, *activeDevice)
89+
90+
close(done)
91+
}()
92+
93+
<-done
94+
95+
cpuDevice := icicle_runtime.CreateDevice("CPU", 0)
96+
icicle_runtime.SetDefaultDevice(&cpuDevice)
4497
}
4598

4699
func TestRegisteredDevices(t *testing.T) {
47-
err := runtime.LoadBackendFromEnvOrDefault()
48-
assert.Equal(t, runtime.Success, err)
49-
devices, _ := runtime.GetRegisteredDevices()
100+
devices, _ := icicle_runtime.GetRegisteredDevices()
50101
assert.Equal(t, []string{"CUDA", "CPU"}, devices)
51102
}
52103

53104
func TestDeviceProperties(t *testing.T) {
54-
_, err := runtime.GetDeviceProperties()
55-
assert.Equal(t, runtime.Success, err)
105+
_, err := icicle_runtime.GetDeviceProperties()
106+
assert.Equal(t, icicle_runtime.Success, err)
56107
}
57108

58109
func TestActiveDevice(t *testing.T) {
59-
runtime.SetDevice(&DEVICE)
60-
activeDevice, err := runtime.GetActiveDevice()
61-
assert.Equal(t, runtime.Success, err)
62-
assert.Equal(t, DEVICE, *activeDevice)
63-
memory1, err := runtime.GetAvailableMemory()
64-
if err == runtime.ApiNotImplemented {
65-
t.Skipf("GetAvailableMemory() function is not implemented on %s device", DEVICE.GetDeviceType())
110+
devCpu := icicle_runtime.CreateDevice("CUDA", 0)
111+
icicle_runtime.SetDevice(&devCpu)
112+
activeDevice, err := icicle_runtime.GetActiveDevice()
113+
assert.Equal(t, icicle_runtime.Success, err)
114+
assert.Equal(t, devCpu, *activeDevice)
115+
memory1, err := icicle_runtime.GetAvailableMemory()
116+
if err == icicle_runtime.ApiNotImplemented {
117+
t.Skipf("GetAvailableMemory() function is not implemented on %s device", devCpu.GetDeviceType())
66118
}
67-
assert.Equal(t, runtime.Success, err)
119+
assert.Equal(t, icicle_runtime.Success, err)
68120
assert.Greater(t, memory1.Total, uint(0))
69121
assert.Greater(t, memory1.Free, uint(0))
70122
}

wrappers/golang/runtime/tests/main_test.go

+1-13
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,7 @@ import (
66
"github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime"
77
)
88

9-
var DEVICE runtime.Device
10-
119
func TestMain(m *testing.M) {
1210
runtime.LoadBackendFromEnvOrDefault()
13-
devices, e := runtime.GetRegisteredDevices()
14-
if e != runtime.Success {
15-
panic("Failed to load registered devices")
16-
}
17-
for _, deviceType := range devices {
18-
DEVICE = runtime.CreateDevice(deviceType, 0)
19-
runtime.SetDevice(&DEVICE)
20-
21-
// execute tests
22-
m.Run()
23-
}
11+
m.Run()
2412
}

0 commit comments

Comments
 (0)