Skip to content

Commit

Permalink
[Preserve dtype of array when converting to torch (#1349)
Browse files Browse the repository at this point in the history
We have noticing the following error with a recent version of outlines
when used with MLX:
```
TypeError: argument 'token_id': 'float' object cannot be interpreted as an integer

At:
  /.../outlines_core/fsm/guide.py(294): get_next_state
  /.../outlines/processors/structured.py(101): process_logits
  /.../outlines/processors/base_logits_processor.py(90): __call__
```

The issue is that the MLX array of tokens, which are integers, are being
force-converted to floats, even though outlines expects an integer
array. This is because all MLX arrays are being converted to `float32`,
even when it's not necessarily appropriate, like in this case. Looking
at the [commented
link](https://ml-explore.github.io/mlx/build/html/usage/numpy.html#pytorch),
the advice was to convert to `float32` only for `bfloat16`, because
numpy does not support `bfloat16`. Now the MLX `_to_torch`
implementation matches the other array libraries, none of the other
libraries are being force-casted to float
  • Loading branch information
neilmehta24 authored Jan 15, 2025
1 parent 7b9012b commit 088f439
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
6 changes: 3 additions & 3 deletions outlines/processors/base_logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ def _to_torch(tensor_like: Array) -> torch.Tensor:
import mlx.core as mx

# https://ml-explore.github.io/mlx/build/html/usage/numpy.html#pytorch
return torch.from_dlpack(
np.array(tensor_like.astype(mx.float32), copy=False)
)
if tensor_like.dtype == mx.bfloat16:
tensor_like = tensor_like.astype(mx.float32)
return torch.from_dlpack(np.array(tensor_like, copy=False))

elif is_jax_array_type(type(tensor_like)):
import jax
Expand Down
8 changes: 7 additions & 1 deletion tests/processors/test_base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import mlx.core as mx

arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.float32)
arrays["mlx_bfloat16"] = mx.array([[1, 2], [3, 4]], dtype=mx.bfloat16)
except ImportError:
pass

Expand Down Expand Up @@ -59,7 +60,12 @@ def test_from_torch(array_type, processor):
torch_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
data = processor._from_torch(torch_tensor, type(arrays[array_type]))
assert isinstance(data, type(arrays[array_type]))
assert np.allclose(data, arrays[array_type])
if array_type == "mlx_bfloat16":
# For bfloat16, we expect the output to be float32 due to the conversion
assert data.dtype == mx.float32
assert np.allclose(np.array(data), np.array([[1, 2], [3, 4]], dtype=np.float32))
else:
assert np.allclose(data, arrays[array_type])


@pytest.mark.parametrize("array_type", arrays.keys())
Expand Down

0 comments on commit 088f439

Please sign in to comment.