-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Wrap tf variables in keras variables for TFSMLayer #20995
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #20995 +/- ##
==========================================
- Coverage 82.45% 82.43% -0.02%
==========================================
Files 562 562
Lines 53316 53404 +88
Branches 8259 8275 +16
==========================================
+ Hits 43960 44023 +63
- Misses 7340 7361 +21
- Partials 2016 2020 +4
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix!
keras/src/backend/tensorflow/core.py
Outdated
@@ -124,6 +127,8 @@ def _map_aggregation(self, aggregation): | |||
|
|||
|
|||
def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): | |||
if isinstance(x, tf.Variable): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This shortcut needs to make sure the dtype
matches or is None
, otherwise it's incorrect.
I think you'll have to move this up:
if dtype is not None:
dtype = standardize_dtype(dtype)
and then do:
if isinstance(x, tf.Variable) and (dtype is None or x.dtype == dtype):
return x
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe that's why the unit tests are failing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Actually looks like we can just stick to the main path in convert_to_tensor, as tf.is_tensor(variable) == True
. I think we were missing a conversion when comparing dtypes though, we were comparing tf dtype, to standardized string dtype in one spot.
Trying a fix.
Fixes #20955