According to the standard, the documentation of `sum` states for the `dtype` parameter:

> If `None`, the returned array must have the same data type as `x`, unless `x` has an integer data type supporting a smaller range of values than the default integer data type... In those latter cases: ... if `x` has an unsigned integer data type, the returned array must have an unsigned integer data type having the same number of bits as the default integer data type.
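To make the quoted rule concrete, here is a minimal pure-Python sketch of the dtype that `sum(x)` should return when `dtype=None`. The helper `expected_sum_dtype` is hypothetical, it assumes the default integer dtype is `int64`, and its signed-integer branch is my reading of the part of the rule elided from the quote above (smaller signed integers go to the default integer dtype).

```python
# Hypothetical helper: the dtype the array API standard requires for
# sum(x, dtype=None), given the name of x's dtype. Assumes the default
# integer dtype is int64 unless told otherwise.
DEFAULT_INT_BITS = 64


def expected_sum_dtype(dtype_name: str, default_int_bits: int = DEFAULT_INT_BITS) -> str:
    """Return the name of the dtype that sum(x) should have when dtype=None."""
    if dtype_name.startswith("uint"):
        bits = int(dtype_name[len("uint"):])
        if bits < default_int_bits:
            # Unsigned case from the quote: an unsigned dtype with the same
            # number of bits as the default integer dtype.
            return f"uint{default_int_bits}"
        return dtype_name
    if dtype_name.startswith("int"):
        bits = int(dtype_name[len("int"):])
        if bits < default_int_bits:
            # Signed case (elided in the quote, assumed here): the default
            # integer dtype itself.
            return f"int{default_int_bits}"
        return dtype_name
    # Floating-point and complex dtypes keep the input dtype.
    return dtype_name
```

Under these assumptions, `expected_sum_dtype("uint8")` gives `"uint64"`, while `"float32"` is returned unchanged.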
If I understand correctly, then the sums for the unsigned dtypes below should have `uint64` dtype. But the output is:
I think this is at least partially fixable within array-api-compat.
Also, `torch` doesn't seem to natively support `sum` for most `uint` dtypes or `complex32`. If we change `xp.sum(x)` to `xp.sum(x, dtype=dtype)` in the code above, the output is:
```
torch.int8 torch.int8
torch.int16 torch.int16
torch.int32 torch.int32
torch.int64 torch.int64
torch.uint8 torch.uint8
"sum_cpu" not implemented for 'UInt16'
"sum_cpu" not implemented for 'UInt32'
"sum_cpu" not implemented for 'UInt64'
torch.float32 torch.float32
torch.float64 torch.float64
"sum_cpu" not implemented for 'ComplexHalf'
torch.complex64 torch.complex64
```
It would be helpful if `array-api-compat` implemented `sum` for these types, even if that means upcasting to a supported type before summing and then downcasting. (There is a slightly larger chance of overflow with `int64` than with `uint64`, and the conversion may not be safe, so what should happen in those cases is up for discussion.)
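The arithmetic behind that caveat can be sketched in plain Python (this is not `torch` or `array-api-compat` code): unsigned values are accumulated as if in an `int64`, and the result is then converted to an unsigned dtype of a given width. The helper `sum_unsigned_via_int64` is hypothetical, and raising `OverflowError` is only one possible answer to what should happen in the unsafe cases; wrapping around or returning a wider dtype would be others.

```python
# Pure-Python model of "upcast to int64, sum, then downcast". Python ints are
# unbounded, so the fixed-width limits are checked explicitly.
INT64_MAX = 2**63 - 1


def uint_max(bits: int) -> int:
    """Largest value representable in an unsigned integer of `bits` bits."""
    return 2**bits - 1


def sum_unsigned_via_int64(values, out_bits: int = 64) -> int:
    """Sum non-negative ints as if in an int64 accumulator, then cast to uint<out_bits>.

    Raises OverflowError where the int64 accumulator would overflow, or where
    the final result does not fit in the requested unsigned dtype.
    """
    total = 0
    for v in values:
        if v < 0:
            raise ValueError("unsigned inputs must be non-negative")
        total += v
        if total > INT64_MAX:
            # A native uint64 accumulator could still hold this value; int64 cannot.
            raise OverflowError("int64 accumulator overflow")
    if total > uint_max(out_bits):
        # Downcasting would wrap around, so the conversion is not safe.
        raise OverflowError(f"result does not fit in uint{out_bits}")
    return total
```

For example, `sum_unsigned_via_int64([2**62, 2**62])` raises even though the true sum, `2**63`, fits in `uint64`; that is the extra overflow risk of an `int64` accumulator.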
Does `array-api-compat` have a mechanism for reporting to the underlying libraries the shortcomings it has to patch? If not, should I report this to PyTorch (if it has not already been reported)?