Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Preserve dtype of array when converting to torch (#1349)
We have noticing the following error with a recent version of outlines when used with MLX: ``` TypeError: argument 'token_id': 'float' object cannot be interpreted as an integer At: /.../outlines_core/fsm/guide.py(294): get_next_state /.../outlines/processors/structured.py(101): process_logits /.../outlines/processors/base_logits_processor.py(90): __call__ ``` The issue is that the MLX array of tokens, which are integers, are being force-converted to floats, even though outlines expects an integer array. This is because all MLX arrays are being converted to `float32`, even when it's not necessarily appropriate, like in this case. Looking at the [commented link](https://ml-explore.github.io/mlx/build/html/usage/numpy.html#pytorch), the advice was to convert to `float32` only for `bfloat16`, because numpy does not support `bfloat16`. Now the MLX `_to_torch` implementation matches the other array libraries, none of the other libraries are being force-casted to float
- Loading branch information