@@ -64,9 +64,11 @@ std::string ParseDeviceType(std::shared_ptr<OVCore> ov_core, const ProviderOptio
64
64
std::string selected_device;
65
65
std::vector<std::string> luid_list;
66
66
std::string device_mode = " " ;
67
+ std::map<std::string, std::string> ov_luid_map;
67
68
68
69
if (provider_options.contains (" device_type" )) {
69
70
selected_device = provider_options.at (" device_type" );
71
+ std::erase (selected_device, ' ' );
70
72
if (selected_device == " AUTO" ) return selected_device;
71
73
72
74
if (auto delimit = selected_device.find (" :" ); delimit != std::string::npos) {
@@ -105,13 +107,16 @@ std::string ParseDeviceType(std::shared_ptr<OVCore> ov_core, const ProviderOptio
105
107
#endif
106
108
}
107
109
110
+ // Get the LUID passed from the provider option in a comma seperated string list
111
+ // Compare each of the LUID's against the LUID obtained using ov property and map with the right device
108
112
if (provider_options.contains (" device_luid" )) {
109
113
std::string luid_str = provider_options.at (" device_luid" );
110
- ORT_ENFORCE (selected_device.find (" GPU" ) != std::string::npos, " LUID is supported only for GPU" );
111
114
std::erase (luid_str, ' ' );
112
115
luid_list = split (luid_str, ' ,' );
113
116
}
117
+
114
118
bool all_devices_found = true ;
119
+
115
120
for (auto device : devices_to_check) {
116
121
bool device_found = false ;
117
122
// Check deprecated device format (CPU_FP32, GPU.0_FP16, etc.) and remove the suffix in place
@@ -124,36 +129,18 @@ std::string ParseDeviceType(std::shared_ptr<OVCore> ov_core, const ProviderOptio
124
129
// Check if device index is appended (.0, .1, etc.), if so, remove it
125
130
if (auto delimit = device_prefix.find (" ." ); delimit != std::string::npos)
126
131
device_prefix = device_prefix.substr (0 , delimit);
127
-
128
132
if (supported_device_types.contains (device_prefix)) {
129
133
try {
130
134
std::vector<std::string> available_devices = ov_core->GetAvailableDevices (device_prefix);
131
135
// Here we need to find the full device name (with .idx, but without _precision)
132
136
if (std::find (std::begin (available_devices), std::end (available_devices), device) != std::end (available_devices))
133
137
device_found = true ;
134
- if (device_prefix == " GPU" && luid_list.size () > 0 ) {
135
- std::map<std::string, std::string> ov_luid_map;
136
- for (auto gpu_dev : available_devices) {
137
- ov::device::LUID ov_luid = OVCore::Get ()->core .get_property (gpu_dev, ov::device::luid);
138
+ if (device_prefix != " CPU" && luid_list.size () > 0 ) {
139
+ for (auto dev : available_devices) {
140
+ ov::device::LUID ov_luid = OVCore::Get ()->core .get_property (dev, ov::device::luid);
138
141
std::stringstream ov_luid_str;
139
142
ov_luid_str << ov_luid;
140
- ov_luid_map.emplace (ov_luid_str.str (), gpu_dev);
141
- }
142
- for (auto luid_str : luid_list) {
143
- if (ov_luid_map.contains (luid_str)) {
144
- auto ov_dev = ov_luid_map.at (luid_str);
145
- if (!device_mode.empty ()) {
146
- selected_device = device_mode + " :" + ov_dev;
147
- for (auto dev_str : devices_to_check) {
148
- if (dev_str.find (" GPU" ) != std::string::npos)
149
- selected_device = selected_device + " ," + dev_str;
150
- }
151
- } else {
152
- selected_device = ov_dev;
153
- }
154
- } else {
155
- ORT_THROW (" Invalid device_luid is set" );
156
- }
143
+ ov_luid_map.emplace (ov_luid_str.str (), dev);
157
144
}
158
145
}
159
146
} catch (const char * msg) {
@@ -162,6 +149,27 @@ std::string ParseDeviceType(std::shared_ptr<OVCore> ov_core, const ProviderOptio
162
149
}
163
150
all_devices_found = all_devices_found && device_found;
164
151
}
152
+ if (luid_list.size () > 0 ) {
153
+ std::string ov_luid_devices;
154
+ for (auto luid_str : luid_list) {
155
+ if (ov_luid_map.contains (luid_str)) {
156
+ if (!ov_luid_devices.empty ()) ov_luid_devices = ov_luid_devices + " ," ;
157
+ ov_luid_devices = ov_luid_devices + ov_luid_map.at (luid_str);
158
+ } else {
159
+ ORT_THROW (" Invalid device_luid is set" );
160
+ }
161
+ }
162
+ if (!device_mode.empty ()) {
163
+ selected_device = device_mode + " :" + ov_luid_devices;
164
+ for (auto dev_str : devices_to_check) {
165
+ auto default_dev = split (dev_str, ' .' )[0 ];
166
+ if (ov_luid_devices.find (default_dev) == std::string::npos)
167
+ selected_device = selected_device + " ," + dev_str;
168
+ }
169
+ } else {
170
+ selected_device = ov_luid_devices;
171
+ }
172
+ }
165
173
// If invalid device is chosen error is thrown
166
174
if (!all_devices_found)
167
175
ORT_THROW (
0 commit comments