
[V1][TPU] Enable Top K #15489

Draft
wants to merge 4 commits into base: main
Conversation

NickLucche
Contributor

@NickLucche NickLucche commented Mar 25, 2025

Enabling the topk optimization that was introduced in #15242.

Currently facing the very issue foreseen by @njhill here #15242 (comment).

ERROR 03-25 18:27:23 [core.py:343]     random_sampled = self.topk_topp_sampler(
ERROR 03-25 18:27:23 [core.py:343]                      ^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-25 18:27:23 [core.py:343]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 03-25 18:27:23 [core.py:343]     return self._call_impl(*args, **kwargs)
ERROR 03-25 18:27:23 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-25 18:27:23 [core.py:343]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 03-25 18:27:23 [core.py:343]     return forward_call(*args, **kwargs)
ERROR 03-25 18:27:23 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-25 18:27:23 [core.py:343]   File "/home/nick/vllm/vllm/v1/sample/ops/topk_topp_sampler.py", line 119, in forward_tpu
ERROR 03-25 18:27:23 [core.py:343]     topk_values, topk_indices = torch.topk(logits, k, dim=-1)
ERROR 03-25 18:27:23 [core.py:343]                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-25 18:27:23 [core.py:343] TypeError: topk(): argument 'k' (position 2) must be int, not Tensor

Dumping work for reference, will look into it asap.
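For context, a minimal sketch of the failure mode and one possible batched-k workaround (the shapes and the masking trick below are illustrative assumptions, not the code in this PR): `torch.topk` only accepts a Python int for `k`, so when each request carries its own `k` the call has to use a scalar, e.g. the batch-wide maximum, and the extra entries are then masked out per row.

```python
import torch

# Illustrative shapes: 4 requests, vocab of 8 (hypothetical values).
logits = torch.randn(4, 8)
k = torch.tensor([2, 3, 1, 4])  # per-request k (a Tensor)

# torch.topk(logits, k, dim=-1) raises:
#   TypeError: topk(): argument 'k' (position 2) must be int, not Tensor

# One workaround: take topk with the largest k in the batch, then mask
# each row down to its own k before sampling.
max_k = int(k.max())
topk_values, topk_indices = torch.topk(logits, max_k, dim=-1)
# topk values are sorted descending, so positions >= k[i] in row i are the
# extras; mask them to -inf so they cannot be sampled.
mask = torch.arange(max_k).unsqueeze(0) >= k.unsqueeze(1)
topk_values = topk_values.masked_fill(mask, float("-inf"))
```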

Update:

For completeness, I've run microbenchmarks and the new impl is slower (but of course correct):

// before
Running 32 elapsed time: 0.0018310546875
Running 32 elapsed time: 0.0017833709716796875
// after
Running 32 elapsed time: 0.003275632858276367
Running 32 elapsed time: 0.003297090530395508
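For reference, a rough sketch of this kind of microbenchmark on torch_xla (batch size 32 and the vocab size / k value below are assumptions, not the exact script used):

```python
import time
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
logits = torch.randn(32, 32000, device=device)  # batch 32, assumed vocab size
k = 8                                           # assumed scalar k

# Warm up so XLA compilation is excluded from the measurement.
for _ in range(3):
    torch.topk(logits, k, dim=-1)
    xm.mark_step()
xm.wait_device_ops()

start = time.perf_counter()
torch.topk(logits, k, dim=-1)
xm.mark_step()
xm.wait_device_ops()
print(f"Running 32 elapsed time: {time.perf_counter() - start}")
```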

cc @hyeygit


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Mar 25, 2025

mergify bot commented Mar 26, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 26, 2025
@mergify mergify bot removed the needs-rebase label Mar 26, 2025
@NickLucche NickLucche marked this pull request as ready for review March 26, 2025 10:38
@NickLucche NickLucche marked this pull request as draft March 26, 2025 11:14
@NickLucche
Contributor Author

Let's hold until main is fixed to reduce entropy


mergify bot commented Mar 27, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added tpu Related to Google TPUs needs-rebase labels Mar 27, 2025
@hyeygit
Contributor

hyeygit commented Mar 27, 2025

Thank you @NickLucche for this PR and thank you @njhill for fixing the batch case for top-k! In #15242 I only tested with the microbenchmark and test_sampler.py (scalar case only), not realizing that k can be batched. Thank you for the catch and sorry for the miss.

One thing to note is that on TPU torch.topk still involves a full vocab sort (see the XLA lowering). torch.topk is so much faster on TPU because it avoids the full-vocab torch.scatter (as used in apply_top_k_top_p), which is extremely slow on TPU.
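To make that concrete, here is an illustrative contrast between the two approaches (a sketch with a scalar k, not the actual vLLM sampler code):

```python
import torch

def top_k_mask_via_scatter(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Scatter-based masking, similar in spirit to apply_top_k_top_p."""
    # Sort the full vocab, mask everything past the k-th entry, then
    # scatter the masked values back to the original token order.
    # The full-vocab scatter is the part that is extremely slow on TPU.
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    vocab_mask = torch.arange(logits.shape[-1], device=logits.device) >= k
    sorted_logits = sorted_logits.masked_fill(vocab_mask, float("-inf"))
    return torch.empty_like(logits).scatter_(-1, sorted_idx, sorted_logits)

def top_k_via_topk(logits: torch.Tensor, k: int):
    """topk-based variant: XLA still lowers this to a full-vocab sort,
    but it skips the scatter, which is where the speedup comes from."""
    values, indices = torch.topk(logits, k, dim=-1)
    # Sampling then happens over the k retained logits; `indices` maps the
    # sampled position back to the original token id.
    return values, indices
```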

Signed-off-by: NickLucche <[email protected]>
Signed-off-by: NickLucche <[email protected]>
sync
Signed-off-by: NickLucche <[email protected]>