-
Notifications
You must be signed in to change notification settings - Fork 276
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
numpy to keras ops #2126
base: master
Are you sure you want to change the base?
numpy to keras ops #2126
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please format the code!
@@ -6,8 +6,10 @@ | |||
|
|||
try: | |||
import tensorflow as tf | |||
import keras.ops as ops |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No reason to import keras conditionally. We don't want to assume tf is install, we can assume keras is installed.
@@ -242,4 +248,4 @@ def get_config(self): | |||
"max_audio_length": self.max_audio_length, | |||
} | |||
) | |||
return config | |||
return config |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't remove training newline
@@ -84,6 +86,20 @@ def audio_shape(self): | |||
"""Returns the preprocessed size of a single audio sample.""" | |||
return (self.max_audio_length, self.num_mels) | |||
|
|||
|
|||
|
|||
def _get_rfftfreq_keras(self): # Inside the class definition |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is this comment support to mean?
@@ -92,25 +108,15 @@ def _get_mel_filters(self): | |||
|
|||
# TODO: Convert to TensorFlow ops (if possible). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should remove this TODO I think.
logstep * (mels[log_t] - min_log_mel) | ||
) | ||
|
||
freqs = ops.where(log_t, min_log_hz * ops.exp(logstep * (mels - min_log_mel)), freqs) # using tf.where for conditional replacement |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove the comment, and reformat to 80 chars
mel_f = freqs | ||
|
||
fdiff = np.diff(mel_f) | ||
ramps = np.subtract.outer(mel_f, fftfreqs) | ||
fdiff = mel_f[1:] - mel_f[:-1] #keras diff. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove these comments
The mel_filters function was successfully migrated to Keras ops, replacing NumPy. All tests now pass, ensuring consistent performance within the Keras framework.