Fix device handling and logits concatenation in OliveEvaluator
- Add exception handling for unsupported devices in the `device_string_to_torch_device` method, falling back to CPU.
- Correct logits concatenation in `OnnxEvaluator` by reading from `logits_dict` instead of `logits`.
- Initialize `logits_dict` in `PyTorchEvaluator` and update its `_inference` method to handle tensor, list/tuple, and dict results, concatenating logits accordingly.
tezheng committed Feb 14, 2025
1 parent 00415b6 commit 306da8e
Showing 1 changed file with 22 additions and 8 deletions.
olive/evaluator/olive_evaluator.py (22 additions, 8 deletions)
@@ -196,7 +196,15 @@ def compute_throughput(metric: Metric, latencies: Any) -> MetricResult:
 class _OliveEvaluator(OliveEvaluator):
     @staticmethod
     def device_string_to_torch_device(device: Device):
-        return torch.device("cuda") if device == Device.GPU else torch.device(device)
+        try:
+            return (
+                torch.device("cuda")
+                if device == Device.GPU
+                else torch.device(device)
+            )
+        except (ValueError, TypeError, RuntimeError):
+            logger.warning("Device '%s' is not supported in torch, falling back to CPU.", device)
+            return torch.device("cpu")

     @classmethod
     def io_bind_enabled(cls, metric: Metric, inference_settings: Dict) -> bool:
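
For context, a minimal standalone sketch of the fallback this hunk introduces; it stands in for the patched static method, with Olive's `Device` enum replaced by a plain string and the device names purely illustrative:

    import logging

    import torch

    logger = logging.getLogger(__name__)


    def device_string_to_torch_device(device: str) -> torch.device:
        # Map Olive's "gpu" alias to CUDA; hand any other name to torch directly,
        # falling back to CPU when torch rejects the device string.
        try:
            return torch.device("cuda") if device == "gpu" else torch.device(device)
        except (ValueError, TypeError, RuntimeError):
            logger.warning("Device '%s' is not supported in torch, falling back to CPU.", device)
            return torch.device("cpu")


    print(device_string_to_torch_device("gpu"))           # cuda
    print(device_string_to_torch_device("not-a-device"))  # cpu (after a warning)

Previously an unsupported device string raised straight out of `torch.device`; the `try`/`except` turns that into a logged CPU fallback.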
@@ -462,7 +470,7 @@ def _inference(
         if is_single_tensor_output:
             logits = torch.cat(logits, dim=0)
         else:
-            logits = {k: torch.cat(logits[k], dim=0) for k in output_names}
+            logits = {k: torch.cat(logits_dict[k], dim=0) for k in output_names}

         tuning_result_file = inference_settings.get("tuning_result_file")
         if tuning_result_file:
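
To see why the one-line fix matters, a toy illustration (names and shapes hypothetical) of the structure the corrected comprehension reads from: for multi-output models the per-name batch tensors accumulate in `logits_dict`, not in `logits`, so indexing `logits[k]` looked up the wrong container:

    import torch

    output_names = ["start_logits", "end_logits"]
    # Hypothetical per-batch outputs: each output name maps to a list of batch tensors.
    logits_dict = {
        "start_logits": [torch.zeros(2, 4), torch.ones(2, 4)],
        "end_logits": [torch.ones(2, 4), torch.zeros(2, 4)],
    }
    logits = {k: torch.cat(logits_dict[k], dim=0) for k in output_names}
    assert logits["start_logits"].shape == (4, 4)
    assert logits["end_logits"].shape == (4, 4)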
@@ -736,6 +744,7 @@ def _inference(
         preds = []
         targets = []
         logits = []
+        logits_dict = collections.defaultdict(list)
         device = _OliveEvaluator.device_string_to_torch_device(device)
         run_kwargs = metric.get_run_kwargs()
         if device:
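
The new `logits_dict` is a `defaultdict(list)` so the first append for a previously unseen output name creates its list implicitly; a quick sketch:

    import collections

    logits_dict = collections.defaultdict(list)
    logits_dict["start_logits"].append(1)  # no KeyError: the list is created on first access
    assert dict(logits_dict) == {"start_logits": [1]}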
@@ -748,15 +757,20 @@
             # it is expensive to convert to list and then convert back to torch tensor
             preds.append(outputs.cpu())
             targets.append(labels.cpu())
-            logits.append(
-                result.logits.cpu()
-                if not isinstance(result, torch.Tensor) and getattr(result, "logits", None) is not None
-                else result.cpu()
-            )
+            if isinstance(result, torch.Tensor):
+                logits.append(result.cpu())
+            elif isinstance(result, (list, tuple)):
+                logits.append([r.cpu() for r in result])
+            elif isinstance(result, dict):
+                for k, v in result.items():
+                    logits_dict[k].append(v.cpu())
         # concatenate along the batch dimension
         preds = torch.cat(preds, dim=0)
         targets = torch.cat(targets, dim=0)
-        logits = torch.cat(logits, dim=0)
+        if not logits_dict:
+            logits = torch.cat(logits, dim=0)
+        else:
+            logits = {k: torch.cat(v, dim=0) for k, v in logits_dict.items()}
         # move model to cpu
         if device:
             session.to("cpu")
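
Putting the PyTorch path together, a self-contained sketch (toy batch results; names and shapes are hypothetical) of how tensor, list/tuple, and dict results are accumulated and then concatenated along the batch dimension:

    import collections

    import torch

    logits = []
    logits_dict = collections.defaultdict(list)

    # Pretend the session returned a dict of tensors for each of two batches.
    batch_results = [
        {"start_logits": torch.zeros(2, 5), "end_logits": torch.ones(2, 5)},
        {"start_logits": torch.ones(2, 5), "end_logits": torch.zeros(2, 5)},
    ]
    for result in batch_results:
        if isinstance(result, torch.Tensor):
            logits.append(result.cpu())
        elif isinstance(result, (list, tuple)):
            logits.append([r.cpu() for r in result])
        elif isinstance(result, dict):
            for k, v in result.items():
                logits_dict[k].append(v.cpu())

    if not logits_dict:
        logits = torch.cat(logits, dim=0)  # single-tensor models
    else:
        logits = {k: torch.cat(v, dim=0) for k, v in logits_dict.items()}

    assert logits["start_logits"].shape == (4, 5)

Collecting dict outputs per name keeps each output independently concatenable instead of forcing everything through a single tensor list.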
