Fix device handling and logits concatenation in OliveEvaluator
- Add exception handling for unsupported devices in the `device_string_to_torch_device` method (a short sketch of the fallback behavior follows this list).
- Correct logits concatenation in `OnnxEvaluator` by reading from `logits_dict` instead of `logits`.
- Initialize `logits_dict` in `PyTorchEvaluator` to collect named (dict-style) outputs per batch.
- Update the `_inference` method in `PyTorchEvaluator` to handle tensor, list/tuple, and dict results and concatenate logits accordingly.
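
A minimal sketch of the device fallback (not code from this commit; plain strings stand in for Olive's `Device` enum, and the hypothetical helper below only mirrors the `try`/`except` added to `device_string_to_torch_device`):

import torch

def to_torch_device(device: str) -> torch.device:
    # Hypothetical stand-in for _OliveEvaluator.device_string_to_torch_device:
    # unsupported device strings no longer raise, they fall back to CPU.
    try:
        return torch.device("cuda") if device == "gpu" else torch.device(device)
    except (ValueError, TypeError, RuntimeError):
        print(f"Device {device} is not supported in torch, falling back to CPU.")
        return torch.device("cpu")

print(to_torch_device("cpu").type)             # cpu
print(to_torch_device("gpu").type)             # cuda
print(to_torch_device("unknown device").type)  # cpu, via the fallback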
tezheng committed Feb 14, 2025
1 parent 00415b6 commit 09138a8
Showing 2 changed files with 78 additions and 8 deletions.
26 changes: 18 additions & 8 deletions olive/evaluator/olive_evaluator.py
@@ -196,7 +196,11 @@ def compute_throughput(metric: Metric, latencies: Any) -> MetricResult:
 class _OliveEvaluator(OliveEvaluator):
     @staticmethod
     def device_string_to_torch_device(device: Device):
-        return torch.device("cuda") if device == Device.GPU else torch.device(device)
+        try:
+            return torch.device("cuda") if device == Device.GPU else torch.device(device)
+        except (ValueError, TypeError, RuntimeError):
+            logger.warning("Device %s is not supported in torch, fallback to CPU instead.", device)
+            return torch.device("cpu")
 
     @classmethod
     def io_bind_enabled(cls, metric: Metric, inference_settings: Dict) -> bool:
@@ -462,7 +466,7 @@ def _inference(
         if is_single_tensor_output:
             logits = torch.cat(logits, dim=0)
         else:
-            logits = {k: torch.cat(logits[k], dim=0) for k in output_names}
+            logits = {k: torch.cat(logits_dict[k], dim=0) for k in output_names}
 
         tuning_result_file = inference_settings.get("tuning_result_file")
         if tuning_result_file:
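
To illustrate the one-line fix above (a toy sketch with made-up tensors, not Olive's actual buffers): when multi-output results are accumulated per output name, the concatenation must index the per-name dict; indexing the flat `logits` list with a string key would raise a `TypeError`.

import collections
import torch

output_names = ["output_1", "output_2"]
logits = []                                  # flat list used on the single-tensor path
logits_dict = collections.defaultdict(list)  # per-output-name buffers on the dict path

for _ in range(2):  # two toy batches
    for name in output_names:
        logits_dict[name].append(torch.randn(4, 3))

# Old expression: {k: torch.cat(logits[k], dim=0) for k in output_names}
# fails because logits is a list and k is a string (TypeError).
merged = {k: torch.cat(logits_dict[k], dim=0) for k in output_names}
print({k: tuple(v.shape) for k, v in merged.items()})  # {'output_1': (8, 3), 'output_2': (8, 3)}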
@@ -736,6 +740,7 @@ def _inference(
         preds = []
         targets = []
         logits = []
+        logits_dict = collections.defaultdict(list)
         device = _OliveEvaluator.device_string_to_torch_device(device)
         run_kwargs = metric.get_run_kwargs()
         if device:
@@ -748,15 +753,20 @@
             # it is expensive to convert to list and then convert back to torch tensor
             preds.append(outputs.cpu())
             targets.append(labels.cpu())
-            logits.append(
-                result.logits.cpu()
-                if not isinstance(result, torch.Tensor) and getattr(result, "logits", None) is not None
-                else result.cpu()
-            )
+            if isinstance(result, torch.Tensor):
+                logits.append(result.cpu())
+            elif isinstance(result, (list, tuple)):
+                logits.append([r.cpu() for r in result])
+            elif isinstance(result, dict):
+                for k in result:
+                    logits_dict[k].append(result[k].cpu())
         # concatenate along the batch dimension
         preds = torch.cat(preds, dim=0)
         targets = torch.cat(targets, dim=0)
-        logits = torch.cat(logits, dim=0)
+        if not logits_dict:
+            logits = torch.cat(logits, dim=0)
+        else:
+            logits = {k: torch.cat(logits_dict[k], dim=0) for k in logits_dict}
         # move model to cpu
         if device:
             session.to("cpu")
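
A self-contained sketch of the accumulation logic added above, with invented batch outputs (the real code receives them from the model session): tensor results keep using the flat `logits` list, while dict results are gathered per key in `logits_dict` and concatenated along the batch dimension at the end.

import collections
import torch

# Invented per-batch results; a real run would get these from the session.
batch_results = [
    {"output_1": torch.randn(4, 10), "output_2": torch.randn(4, 5)},
    {"output_1": torch.randn(4, 10), "output_2": torch.randn(4, 5)},
]

logits = []
logits_dict = collections.defaultdict(list)

for result in batch_results:
    if isinstance(result, torch.Tensor):
        logits.append(result.cpu())
    elif isinstance(result, (list, tuple)):
        logits.append([r.cpu() for r in result])
    elif isinstance(result, dict):
        for k in result:
            logits_dict[k].append(result[k].cpu())

# Concatenate along the batch dimension, mirroring the hunk above.
if not logits_dict:
    logits = torch.cat(logits, dim=0)
else:
    logits = {k: torch.cat(logits_dict[k], dim=0) for k in logits_dict}

print({k: tuple(v.shape) for k, v in logits.items()})  # {'output_1': (8, 10), 'output_2': (8, 5)}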
60 changes: 60 additions & 0 deletions test/unit_test/evaluator/test_olive_evaluator.py
@@ -29,6 +29,7 @@
     OpenVINOEvaluator,
     PyTorchEvaluator,
     SNPEEvaluator,
+    _OliveEvaluator,
 )
 from olive.exception import OliveEvaluationError
 from olive.hardware.accelerator import Device
@@ -404,6 +405,65 @@ def test_evaluator_get_inference_session(self, metric_inference_settings, model_
         assert metric.get_inference_settings("onnx") == metric_inference_settings
         assert model.inference_settings == model_inference_settings
 
+    @pytest.mark.parametrize(
+        ("input_device", "torch_device"),
+        [("cpu", "cpu"), ("gpu", "cuda"), ("cuda", "cuda"), ("npu", "cpu"), ("unknown device", "cpu")],
+    )
+    def test_evaluator_unknown_device(self, input_device, torch_device):
+        assert _OliveEvaluator.device_string_to_torch_device(input_device).type == torch_device
+
+    @patch("onnxruntime.InferenceSession")
+    def test_onnx_evaluator_inference_dict_output(self, inference_session_mock):
+        mock_session = MagicMock()
+        mock_session.get_providers.return_value = ["CPUExecutionProvider"]
+        mock_session.run.return_value = ([0.5, 0.6], [0.6, 0.7])
+        inference_session_mock.return_value = mock_session
+
+        model = get_onnx_model()
+        # pylint: disable=protected-access
+        model._io_config = {
+            "input_names": ["dummy"],
+            "input_types": ["int32"],
+            "output_names": ["output_1", "output_2"],
+        }
+        metric = get_accuracy_metric(AccuracySubType.PRECISION)
+        dataloader, *_ = OliveEvaluator.get_user_config(model.framework, metric)
+        evaluator = OnnxEvaluator()
+
+        # pylint: disable=protected-access
+        outputs, *_ = evaluator._inference(
+            model,
+            metric,
+            dataloader,
+            lambda x: x["output_1"],
+        )
+        mock_session.run.assert_called_once()
+        assert set(outputs.logits.keys()) == set(model.io_config["output_names"])
+
+    def test_torch_evaluator_inference_dict_output(self):
+        import torch
+
+        model = MagicMock()
+        model.run_session.return_value = {
+            "output_1": torch.Tensor([0.5, 0.6]),
+            "output_2": torch.Tensor([0.6, 0.7]),
+        }
+        metric = MagicMock()
+        metric.get_run_kwargs.return_value = {}
+        dataloader = MagicMock()
+        dataloader.__iter__.return_value = [(torch.Tensor([1, 1]), torch.Tensor([0]))]
+        evaluator = PyTorchEvaluator()
+
+        # pylint: disable=protected-access
+        outputs, *_ = evaluator._inference(
+            model,
+            metric,
+            dataloader,
+            lambda x: x["output_1"],
+        )
+        model.run_session.assert_called_once()
+        assert set(outputs.logits.keys()) == {"output_1", "output_2"}
+
 
 class TestOliveEvaluatorConfig:
     @pytest.mark.parametrize(
