From 306da8e084d2aff06507a9ceb8ade38b1f9f51bc Mon Sep 17 00:00:00 2001
From: Zheng Te <1221537+tezheng@users.noreply.github.com>
Date: Fri, 14 Feb 2025 14:50:48 +0800
Subject: [PATCH] Fix device handling and logits concatenation in OliveEvaluator

- Add exception handling for unsupported devices in `device_string_to_torch_device`, falling back to CPU with a warning.
- Correct logits concatenation in `OnnxEvaluator` by reading from `logits_dict` instead of `logits` for multi-output models.
- Initialize `logits_dict` in `PyTorchEvaluator` and update its `_inference` method to handle tensor, list/tuple, and dict result types, concatenating logits accordingly.
---
 olive/evaluator/olive_evaluator.py | 30 ++++++++++++++++++++++--------
 1 file changed, 22 insertions(+), 8 deletions(-)

diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py
index 2503503fb..14f4c9dc6 100644
--- a/olive/evaluator/olive_evaluator.py
+++ b/olive/evaluator/olive_evaluator.py
@@ -196,7 +196,15 @@ def compute_throughput(metric: Metric, latencies: Any) -> MetricResult:
 class _OliveEvaluator(OliveEvaluator):
     @staticmethod
     def device_string_to_torch_device(device: Device):
-        return torch.device("cuda") if device == Device.GPU else torch.device(device)
+        try:
+            return (
+                torch.device("cuda")
+                if device == Device.GPU
+                else torch.device(device)
+            )
+        except (ValueError, TypeError, RuntimeError):
+            logger.warning(f"Device '{device}' is not supported in torch, fallback to CPU instead.")
+            return torch.device("cpu")
 
     @classmethod
     def io_bind_enabled(cls, metric: Metric, inference_settings: Dict) -> bool:
@@ -462,7 +470,7 @@ def _inference(
         if is_single_tensor_output:
             logits = torch.cat(logits, dim=0)
         else:
-            logits = {k: torch.cat(logits[k], dim=0) for k in output_names}
+            logits = {k: torch.cat(logits_dict[k], dim=0) for k in output_names}
 
         tuning_result_file = inference_settings.get("tuning_result_file")
         if tuning_result_file:
@@ -736,6 +744,7 @@ def _inference(
         preds = []
         targets = []
         logits = []
+        logits_dict = collections.defaultdict(list)
         device = _OliveEvaluator.device_string_to_torch_device(device)
         run_kwargs = metric.get_run_kwargs()
         if device:
@@ -748,15 +757,20 @@ def _inference(
             # it is expensive to convert to list and then convert back to torch tensor
             preds.append(outputs.cpu())
             targets.append(labels.cpu())
-            logits.append(
-                result.logits.cpu()
-                if not isinstance(result, torch.Tensor) and getattr(result, "logits", None) is not None
-                else result.cpu()
-            )
+            if isinstance(result, torch.Tensor):
+                logits.append(result.cpu())
+            elif isinstance(result, (list, tuple)):
+                logits.append([r.cpu() for r in result])
+            elif isinstance(result, dict):
+                for k in result.keys():
+                    logits_dict[k].append(result[k].cpu())
         # concatenate along the batch dimension
         preds = torch.cat(preds, dim=0)
         targets = torch.cat(targets, dim=0)
-        logits = torch.cat(logits, dim=0)
+        if not logits_dict:
+            logits = torch.cat(logits, dim=0)
+        else:
+            logits = {k: torch.cat(logits_dict[k], dim=0) for k in logits_dict.keys()}
         # move model to cpu
         if device:
             session.to("cpu")
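
Note (not part of the patch): a minimal, self-contained sketch of the new fallback behavior, assuming a stand-in `Device` enum in place of Olive's real one; "npu" is used here only as an example of a device string that vanilla torch rejects.

import logging
from enum import Enum

import torch

logger = logging.getLogger(__name__)

class Device(str, Enum):  # hypothetical stand-in for Olive's Device enum
    CPU = "cpu"
    GPU = "gpu"
    NPU = "npu"

def device_string_to_torch_device(device: Device) -> torch.device:
    try:
        # Device.GPU is spelled "gpu", which torch does not recognize; map it to "cuda"
        return torch.device("cuda") if device == Device.GPU else torch.device(device)
    except (ValueError, TypeError, RuntimeError):
        # torch.device raises RuntimeError for unknown device strings and
        # TypeError for non-string input; both now degrade to CPU
        logger.warning("Device '%s' is not supported in torch, falling back to CPU.", device)
        return torch.device("cpu")

print(device_string_to_torch_device(Device.GPU))  # cuda
print(device_string_to_torch_device(Device.NPU))  # cpu, after a warning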
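
Likewise, a sketch of the new per-type logits collection in `PyTorchEvaluator._inference`, run on fabricated batch results; `collect_and_concat` is an illustrative helper, not Olive API, and only the tensor and dict paths are exercised here (list/tuple results are collected as nested lists, which `torch.cat` would not accept without flattening).

import collections

import torch

def collect_and_concat(batch_results):
    # Mirrors the patched logic: plain tensors go to `logits`, dict-style
    # outputs (e.g. transformers.ModelOutput, an OrderedDict subclass) go
    # to `logits_dict`, keyed by output name.
    logits, logits_dict = [], collections.defaultdict(list)
    for result in batch_results:
        if isinstance(result, torch.Tensor):
            logits.append(result.cpu())
        elif isinstance(result, (list, tuple)):
            logits.append([r.cpu() for r in result])
        elif isinstance(result, dict):
            for k in result:
                logits_dict[k].append(result[k].cpu())
    # concatenate along the batch dimension, per output name when dicts were seen
    if not logits_dict:
        return torch.cat(logits, dim=0)
    return {k: torch.cat(v, dim=0) for k, v in logits_dict.items()}

print(collect_and_concat([torch.ones(2, 4), torch.ones(3, 4)]).shape)
# torch.Size([5, 4])
out = collect_and_concat([{"logits": torch.ones(2, 4)}, {"logits": torch.ones(3, 4)}])
print(out["logits"].shape)
# torch.Size([5, 4])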