Apply `backend.result_type` to `bincount`, `subtract`, `matmul`, `multiply`, `mean` and `max` #18534
Conversation
All tests passed except the failure in Codecov.
Thank you for the PR!
Codecov Report

```
@@            Coverage Diff             @@
##           master   #18534      +/-   ##
==========================================
+ Coverage   77.58%   77.67%   +0.09%
==========================================
  Files         334      334
  Lines       32211    32302     +91
  Branches     6286     6297     +11
==========================================
+ Hits        24990    25092    +102
+ Misses       5636     5631      -5
+ Partials     1585     1579      -6
```
👍
```python
dtype = getattr(x, "dtype", None)
if hasattr(dtype, "name") and "float" in dtype.name:
    return cast(outputs, dtype)
compute_dtype = dtypes.result_type(x.dtype, "float32")
```
Shouldn't this be `result_type(x.dtype, config.floatx())` rather than hardcoding float32?
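As an aside, the effect of promoting against a fixed `"float32"` floor can be illustrated with NumPy's `np.result_type`, whose float-to-float promotion behaves like Keras's `dtypes.result_type` for these cases (a rough sketch; the exact Keras promotion lattice may differ for mixed int/float inputs):

```python
import numpy as np

# Promoting against a fixed "float32" floor: low-precision floats are
# upcast to float32, while higher-precision floats are preserved.
print(np.result_type("float16", "float32"))  # float32
print(np.result_type("float32", "float32"))  # float32
print(np.result_type("float64", "float32"))  # float64
```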
There is a note in the code:

```python
# `jnp.mean` does not handle low precision (e.g., float16) overflow
# correctly, so we compute with float32 and cast back to the original type.
```

It comes from this PR: keras-team/keras-core#410
I have added a test to verify the overflow behavior:

```python
# test overflow
x = np.array([65504, 65504, 65504], dtype="float16")
self.assertAllClose(knp.mean(x), np.mean(x))
```
| `np.mean(x)` | `jnp.mean(x)` | `jnp.mean(x, dtype="float32")` | `tfnp.mean(x)` | `tfnp.mean(x, dtype="float32")` | `torch.mean(x)` | `torch.mean(x, dtype=torch.float32)` |
|---|---|---|---|---|---|---|
| 65504 | inf | 65504 | inf | 65504 | inf | 65504 |
As a result, we should use float32 to compute the mean for JAX, TensorFlow and Torch, even if `backend.floatx() == "float16"`.
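For reference, the overflow can be reproduced in plain NumPy by forcing a float16 accumulator (a hypothetical demo, not part of the PR): 65504 is the largest finite float16, so summing three of them overflows unless the intermediate precision is at least float32.

```python
import numpy as np

x = np.array([65504, 65504, 65504], dtype="float16")

# NumPy computes float16 means with float32 intermediates by default,
# so the running sum does not overflow.
print(np.mean(x))                    # 65504.0

# Forcing a float16 accumulator makes the intermediate sum exceed the
# float16 maximum (65504), producing inf.
print(np.mean(x, dtype=np.float16))  # inf
```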
keras/backend/numpy/numpy.py (Outdated)
```python
    return np.mean(x, axis=axis, keepdims=keepdims)
x = convert_to_tensor(x)
ori_dtype = standardize_dtype(x.dtype)
compute_dtype = dtypes.result_type(x.dtype, "float32")
```
Likewise here
```python
    return tfnp.mean(x, axis=axis, keepdims=keepdims)
x = convert_to_tensor(x)
ori_dtype = standardize_dtype(x.dtype)
compute_dtype = dtypes.result_type(x.dtype, "float32")
```
Likewise here
```python
x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
return torch.mean(x, axis=axis, keepdims=keepdims)
ori_dtype = standardize_dtype(x.dtype)
compute_dtype = dtypes.result_type(x.dtype, "float32")
```
And here
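The pattern discussed across these three backend snippets can be sketched in plain NumPy (a hypothetical illustration, not the PR's actual code; `mean_with_upcast` is an invented name): promote low-precision inputs to float32 for the reduction, then cast the result back to the original dtype.

```python
import numpy as np

def mean_with_upcast(x, axis=None, keepdims=False):
    """Compute a mean in float32, then cast back to the input dtype."""
    ori_dtype = x.dtype
    # Reduce in float32 so low-precision sums (e.g. float16) cannot overflow.
    out = np.mean(x.astype(np.float32), axis=axis, keepdims=keepdims)
    # Cast back to the original dtype, mirroring the PR's upcast/downcast idea.
    return out.astype(ori_dtype)

x = np.array([65504, 65504, 65504], dtype="float16")
print(mean_with_upcast(x))  # 65504.0 (a pure float16 reduction gives inf)
```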
Thank you for the great contribution! LGTM.
Follow-up to #18482
This PR applies `result_type` to the following ops:
- `bincount`
- `subtract`
- `matmul`
- `multiply`
- `mean`
- `max`

The corresponding unit tests have also been added.