
Output of torch.sum with unsigned input should be unsigned #242

mdhaber opened this issue Jan 24, 2025 · 3 comments
mdhaber commented Jan 24, 2025

According to the standard, the documentation of `sum` states the following for the `dtype` parameter:

> If None, the returned array must have the same data type as x, unless x has an integer data type supporting a smaller range of values than the default integer data type... In those latter cases: ... if x has an unsigned integer data type, the returned array must have an unsigned integer data type having the same number of bits as the default integer data type.
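For comparison, NumPy's `sum` follows this rule; a minimal check (assuming NumPy ≥ 2.0, where the default integer type is int64):

```python
import numpy as np

# Per the standard, summing a small unsigned integer dtype yields an
# unsigned result with the same width as the default integer type.
print(np.sum(np.asarray([1, 2, 3], dtype=np.uint8)).dtype)  # uint64
print(np.sum(np.asarray([1, 2, 3], dtype=np.int8)).dtype)   # int64
```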

If I understand correctly, then the sums for the unsigned dtypes below should have uint64 dtype:

```python
from array_api_compat import torch as xp

for dtype in [xp.int8, xp.int16, xp.int32, xp.int64,
              xp.uint8, xp.uint16, xp.uint32, xp.uint64,
              xp.float32, xp.float64, xp.complex32, xp.complex64]:
    x = xp.asarray([1, 2, 3], dtype=dtype)
    try:
        print(xp.sum(x).dtype, dtype)
    except RuntimeError as e:
        print(e)
```

But the output is:

```
torch.int64 torch.int8
torch.int64 torch.int16
torch.int64 torch.int32
torch.int64 torch.int64
torch.int64 torch.uint8
torch.int64 torch.uint16
torch.int64 torch.uint32
torch.int64 torch.uint64
torch.float32 torch.float32
torch.float64 torch.float64
"sum_cpu" not implemented for 'ComplexHalf'
torch.complex64 torch.complex64
```

I think this is at least partially fixable within array-api-compat.

Also, torch doesn't seem to natively support sum for most uint dtypes, or for complex32. If we change `xp.sum(x)` to `xp.sum(x, dtype=dtype)` in the code above, the output is:

```
torch.int8 torch.int8
torch.int16 torch.int16
torch.int32 torch.int32
torch.int64 torch.int64
torch.uint8 torch.uint8
"sum_cpu" not implemented for 'UInt16'
"sum_cpu" not implemented for 'UInt32'
"sum_cpu" not implemented for 'UInt64'
torch.float32 torch.float32
torch.float64 torch.float64
"sum_cpu" not implemented for 'ComplexHalf'
torch.complex64 torch.complex64
```

It would be helpful if array-api-compat would implement sum for these types even if that means upcasting to a supported type before summing and then downcasting. (There is a slightly larger chance of overflow with int64 than with uint64, and it's possible that the conversion will not be safe, so it's up for discussion what should happen in those cases.)

Does array-api-compat have a mechanism for reporting the shortcomings it has to patch to the underlying libraries? If not, should I report this to PyTorch (if it is not already reported)?

@lucascolley (Member) commented:
> Does array-api-compat have a mechanism for reporting the shortcomings it has to patch to the underlying libraries?

So far, I think things have just been reported upstream on an ad-hoc basis.

@ev-br (Member) commented Jan 26, 2025

+1 to reporting it upstream. Ideally, there's an upstream issue, plus a reference to it in a workaround in -compat.

EDIT: This isn't a hard requirement, of course.


mdhaber commented Jan 29, 2025

Looks like support for uints is already tracked at pytorch/pytorch#58743. It happens to be near the top. Oh, and it has its own issue, pytorch/pytorch#58734.
