Skip to content

Commit a456f2d

Browse files
committed
generalise LUID for GPU and NPU
1 parent cfd6932 commit a456f2d

File tree

1 file changed

+31
-23
lines changed

1 file changed

+31
-23
lines changed

onnxruntime/core/providers/openvino/openvino_provider_factory.cc

+31-23
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@ std::string ParseDeviceType(std::shared_ptr<OVCore> ov_core, const ProviderOptio
6464
std::string selected_device;
6565
std::vector<std::string> luid_list;
6666
std::string device_mode = "";
67+
std::map<std::string, std::string> ov_luid_map;
6768

6869
if (provider_options.contains("device_type")) {
6970
selected_device = provider_options.at("device_type");
71+
std::erase(selected_device, ' ');
7072
if (selected_device == "AUTO") return selected_device;
7173

7274
if (auto delimit = selected_device.find(":"); delimit != std::string::npos) {
@@ -105,13 +107,16 @@ std::string ParseDeviceType(std::shared_ptr<OVCore> ov_core, const ProviderOptio
105107
#endif
106108
}
107109

110+
// Get the LUID passed from the provider option in a comma seperated string list
111+
// Compare each of the LUID's against the LUID obtained using ov property and map with the right device
108112
if (provider_options.contains("device_luid")) {
109113
std::string luid_str = provider_options.at("device_luid");
110-
ORT_ENFORCE(selected_device.find("GPU") != std::string::npos, "LUID is supported only for GPU");
111114
std::erase(luid_str, ' ');
112115
luid_list = split(luid_str, ',');
113116
}
117+
114118
bool all_devices_found = true;
119+
115120
for (auto device : devices_to_check) {
116121
bool device_found = false;
117122
// Check deprecated device format (CPU_FP32, GPU.0_FP16, etc.) and remove the suffix in place
@@ -124,36 +129,18 @@ std::string ParseDeviceType(std::shared_ptr<OVCore> ov_core, const ProviderOptio
124129
// Check if device index is appended (.0, .1, etc.), if so, remove it
125130
if (auto delimit = device_prefix.find("."); delimit != std::string::npos)
126131
device_prefix = device_prefix.substr(0, delimit);
127-
128132
if (supported_device_types.contains(device_prefix)) {
129133
try {
130134
std::vector<std::string> available_devices = ov_core->GetAvailableDevices(device_prefix);
131135
// Here we need to find the full device name (with .idx, but without _precision)
132136
if (std::find(std::begin(available_devices), std::end(available_devices), device) != std::end(available_devices))
133137
device_found = true;
134-
if (device_prefix == "GPU" && luid_list.size() > 0) {
135-
std::map<std::string, std::string> ov_luid_map;
136-
for (auto gpu_dev : available_devices) {
137-
ov::device::LUID ov_luid = OVCore::Get()->core.get_property(gpu_dev, ov::device::luid);
138+
if (device_prefix != "CPU" && luid_list.size() > 0) {
139+
for (auto dev : available_devices) {
140+
ov::device::LUID ov_luid = OVCore::Get()->core.get_property(dev, ov::device::luid);
138141
std::stringstream ov_luid_str;
139142
ov_luid_str << ov_luid;
140-
ov_luid_map.emplace(ov_luid_str.str(), gpu_dev);
141-
}
142-
for (auto luid_str : luid_list) {
143-
if (ov_luid_map.contains(luid_str)) {
144-
auto ov_dev = ov_luid_map.at(luid_str);
145-
if (!device_mode.empty()) {
146-
selected_device = device_mode + ":" + ov_dev;
147-
for (auto dev_str : devices_to_check) {
148-
if (dev_str.find("GPU") != std::string::npos)
149-
selected_device = selected_device + "," + dev_str;
150-
}
151-
} else {
152-
selected_device = ov_dev;
153-
}
154-
} else {
155-
ORT_THROW("Invalid device_luid is set");
156-
}
143+
ov_luid_map.emplace(ov_luid_str.str(), dev);
157144
}
158145
}
159146
} catch (const char* msg) {
@@ -162,6 +149,27 @@ std::string ParseDeviceType(std::shared_ptr<OVCore> ov_core, const ProviderOptio
162149
}
163150
all_devices_found = all_devices_found && device_found;
164151
}
152+
if (luid_list.size() > 0) {
153+
std::string ov_luid_devices;
154+
for (auto luid_str : luid_list) {
155+
if (ov_luid_map.contains(luid_str)) {
156+
if (!ov_luid_devices.empty()) ov_luid_devices = ov_luid_devices + ",";
157+
ov_luid_devices = ov_luid_devices + ov_luid_map.at(luid_str);
158+
} else {
159+
ORT_THROW("Invalid device_luid is set");
160+
}
161+
}
162+
if (!device_mode.empty()) {
163+
selected_device = device_mode + ":" + ov_luid_devices;
164+
for (auto dev_str : devices_to_check) {
165+
auto default_dev = split(dev_str, '.')[0];
166+
if (ov_luid_devices.find(default_dev) == std::string::npos)
167+
selected_device = selected_device + "," + dev_str;
168+
}
169+
} else {
170+
selected_device = ov_luid_devices;
171+
}
172+
}
165173
// If invalid device is chosen error is thrown
166174
if (!all_devices_found)
167175
ORT_THROW(

0 commit comments

Comments
 (0)