API for TensorIDs in GradStore for O(num_relevant_vars) backward_step #2377
Closed
spaghetti-source started this conversation in Ideas
-
Sounds good, feel free to make a PR adding |
-
Hello folks.
I implemented word2vec in candle, but training was very slow. The reason is that candle_nn's `SGD::step` takes O(num_variables) time regardless of the number of relevant variables. (Note: in word2vec training, we take the words in a sentence and update only their embeddings.) Below is the relevant code fragment in candle.

To mitigate this issue, I propose adding an API to retrieve the list of `TensorId`s in the `GradStore`. The following is my suggested implementation.

By having this API, we can implement our own SGD as follows to mitigate the issue.
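As a rough, self-contained illustration of the idea (not the candle code from the post), the sketch below models the two update strategies with toy types built on the standard library only. Everything in it is an assumption made for illustration: the `TensorId` and `GradStore` structs are simplified stand-ins rather than candle's real types, variables are plain `Vec<f32>` buffers, and `ids`, `dense_step`, and `sparse_step` are hypothetical names. Each step function returns how many entries it visited, so the O(num_variables) versus O(num_relevant_vars) difference is directly observable.

```rust
use std::collections::HashMap;

/// Toy stand-in for a tensor identifier (hypothetical, not candle's type).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TensorId(pub usize);

/// Toy stand-in for a gradient store: it only holds gradients for the
/// variables that actually took part in the backward pass.
#[derive(Default)]
pub struct GradStore {
    grads: HashMap<TensorId, Vec<f32>>,
}

impl GradStore {
    pub fn insert(&mut self, id: TensorId, grad: Vec<f32>) {
        self.grads.insert(id, grad);
    }

    pub fn get(&self, id: TensorId) -> Option<&Vec<f32>> {
        self.grads.get(&id)
    }

    /// Sketch of the proposed accessor: the ids of all tensors that
    /// currently have a gradient in the store.
    pub fn ids(&self) -> impl Iterator<Item = &TensorId> + '_ {
        self.grads.keys()
    }
}

/// Dense update: walks every variable and looks up its gradient,
/// so the cost grows with the total number of variables.
pub fn dense_step(
    vars: &mut HashMap<TensorId, Vec<f32>>,
    grads: &GradStore,
    lr: f32,
) -> usize {
    let mut visited = 0;
    for (id, value) in vars.iter_mut() {
        visited += 1;
        if let Some(g) = grads.get(*id) {
            for (v, gi) in value.iter_mut().zip(g) {
                *v -= lr * gi;
            }
        }
    }
    visited
}

/// Sparse update: walks only the ids present in the gradient store,
/// so the cost grows with the number of variables that have gradients.
pub fn sparse_step(
    vars: &mut HashMap<TensorId, Vec<f32>>,
    grads: &GradStore,
    lr: f32,
) -> usize {
    let mut visited = 0;
    for id in grads.ids() {
        visited += 1;
        if let (Some(value), Some(g)) = (vars.get_mut(id), grads.get(*id)) {
            for (v, gi) in value.iter_mut().zip(g) {
                *v -= lr * gi;
            }
        }
    }
    visited
}
```

With, say, 1000 embedding rows and gradients for only two of them, `sparse_step` visits two entries while `dense_step` visits all 1000, yet both apply the same `value -= lr * grad` update to the touched rows.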
This provides a significant speed-up in use cases that update only the relevant variables, such as word2vec. I implemented a simple benchmark (https://gist.github.com/spaghetti-source/f0630f1d0ad1b98f736a1d8e9719ff6d) and observed the following speed-up on my local computer.