-
Notifications
You must be signed in to change notification settings - Fork 187
Commit
- Add exception handling for unsupported devices in `device_string_to_torch_device` method.
- Correct logits concatenation in `OnnxEvaluator` by using `logits_dict` instead of `logits`.
- Initialize `logits_dict` in `PyTorchEvaluator` to handle different result types.
- Update `_inference` method in `PyTorchEvaluator` to handle different result types and concatenate logits correctly.
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -196,7 +196,15 @@ def compute_throughput(metric: Metric, latencies: Any) -> MetricResult: | |
class _OliveEvaluator(OliveEvaluator): | ||
@staticmethod | ||
def device_string_to_torch_device(device: Device): | ||
return torch.device("cuda") if device == Device.GPU else torch.device(device) | ||
try: | ||
return ( | ||
torch.device("cuda") | ||
if device == Device.GPU | ||
else torch.device(device) | ||
) | ||
except (ValueError, TypeError, RuntimeError): | ||
logger.warning(f"Device '{device}' is not supported in torch, fallback to CPU instead.") | ||
Check warning Code scanning / lintrunner PYLINT/W1203 Warning
Use lazy % formatting in logging functions (logging-fstring-interpolation)
See logging-fstring-interpolation. Check warning Code scanning / lintrunner RUFF/G004 Warning
Logging statement uses f-string.
See https://docs.astral.sh/ruff/rules/logging-f-string |
||
return torch.device("cpu") | ||
|
||
@classmethod | ||
def io_bind_enabled(cls, metric: Metric, inference_settings: Dict) -> bool: | ||
|
@@ -462,7 +470,7 @@ def _inference( | |
if is_single_tensor_output: | ||
logits = torch.cat(logits, dim=0) | ||
else: | ||
logits = {k: torch.cat(logits[k], dim=0) for k in output_names} | ||
logits = {k: torch.cat(logits_dict[k], dim=0) for k in output_names} | ||
|
||
tuning_result_file = inference_settings.get("tuning_result_file") | ||
if tuning_result_file: | ||
|
@@ -736,6 +744,7 @@ def _inference( | |
preds = [] | ||
targets = [] | ||
logits = [] | ||
logits_dict = collections.defaultdict(list) | ||
device = _OliveEvaluator.device_string_to_torch_device(device) | ||
run_kwargs = metric.get_run_kwargs() | ||
if device: | ||
|
@@ -748,15 +757,20 @@ def _inference( | |
# it is expensive to convert to list and then convert back to torch tensor | ||
preds.append(outputs.cpu()) | ||
targets.append(labels.cpu()) | ||
logits.append( | ||
result.logits.cpu() | ||
if not isinstance(result, torch.Tensor) and getattr(result, "logits", None) is not None | ||
else result.cpu() | ||
) | ||
if isinstance(result, torch.Tensor): | ||
logits.append(result.cpu()) | ||
elif isinstance(result, (list, tuple)): | ||
logits.append([r.cpu() for r in result]) | ||
elif isinstance(result, dict): | ||
for k in result.keys(): | ||
Check warning Code scanning / lintrunner RUFF/SIM118 Warning
Use `key in dict` instead of `key in dict.keys()`.
See https://docs.astral.sh/ruff/rules/in-dict-keys |
||
logits_dict[k].append(result[k].cpu()) | ||
# concatenate along the batch dimension | ||
preds = torch.cat(preds, dim=0) | ||
targets = torch.cat(targets, dim=0) | ||
logits = torch.cat(logits, dim=0) | ||
if not logits_dict: | ||
logits = torch.cat(logits, dim=0) | ||
else: | ||
logits = {k: torch.cat(logits_dict[k], dim=0) for k in logits_dict.keys()} | ||
Check notice Code scanning / lintrunner PYLINT/C0206 Note
Consider iterating with .items() (consider-using-dict-items)
See consider-using-dict-items. Check warning Code scanning / lintrunner RUFF/SIM118 Warning
Use `key in dict` instead of `key in dict.keys()`.
See https://docs.astral.sh/ruff/rules/in-dict-keys |
||
# move model to cpu | ||
if device: | ||
session.to("cpu") | ||
|