[cleanup/fix] Make inference parameters explicit, etc. #198
Conversation
xy12181 commented on Feb 19, 2025 (edited):
- Rename `max_num_seq` to `batch_size` (a sketch of the rename follows this list)
- Fix a process runner typo
- ......
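For illustration, a minimal sketch of the rename, assuming a hypothetical `InferenceConfig` dataclass (the class name and default value are not from this PR):

```python
import dataclasses

@dataclasses.dataclass
class InferenceConfig:  # hypothetical class, for illustration only
    # Previously named `max_num_seq`; `batch_size` states directly that
    # this is the number of sequences processed together per step.
    batch_size: int = 8
```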
```
@@ -114,7 +114,7 @@ class LinearParallelConfig:
 @dataclasses.dataclass
 class RMSNormParallelConfig:
     mesh: jax.sharding.Mesh
     activation_shared: bool = False
```
@zhihaoshan-google: this should be `sharded`, not `shared`, right?
Correct! Sorry for the mistake.
LGTM, thanks for the correction!
@vipannalla could you please help review this PR?
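With the fix agreed above, the field would read `activation_sharded`. A minimal sketch of the corrected dataclass and how a config might be built from a device mesh (the usage and mesh axis name are assumptions, not from this PR):

```python
import dataclasses

import jax
import numpy as np

@dataclasses.dataclass
class RMSNormParallelConfig:
    mesh: jax.sharding.Mesh
    # Corrected spelling: the flag says whether activations are *sharded*
    # across the mesh, not shared.
    activation_sharded: bool = False

# Assumed usage: build a 1D mesh over all local devices and enable sharding.
devices = np.array(jax.devices())
mesh = jax.sharding.Mesh(devices, axis_names=("model",))
config = RMSNormParallelConfig(mesh=mesh, activation_sharded=True)
```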
Mini offline benchmarking run for this change:

```
$ python experimental/jax/inference/entrypoint/mini_offline_benchmarking.py
Offline inference begins: 2025-02-19 18:51:29.910043
...
Offline inference ends: 2025-02-19 19:15:14.701318
Benchmarking result:
Total requests: 24000
Total input tokens: 5311141
Total output tokens: 7030596
Input token thruput: 3727.18 tokens/sec
Output token thruput: 4933.84 tokens/sec
```
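As a sanity check on figures like these, throughput is just token count over wall-clock duration. A minimal sketch of that arithmetic, using the timestamps and counts from the log above (small gaps versus the reported figures would come from the exact timing window the benchmark itself measures):

```python
from datetime import datetime

FMT = "%Y-%m-%d %H:%M:%S.%f"
start = datetime.strptime("2025-02-19 18:51:29.910043", FMT)
end = datetime.strptime("2025-02-19 19:15:14.701318", FMT)
duration_s = (end - start).total_seconds()  # ~1424.8 seconds

input_tokens = 5311141
output_tokens = 7030596

# Tokens per second over the whole run; the benchmark's own numbers may
# use a slightly narrower timing window, so expect small differences.
print(f"Input token thruput:  {input_tokens / duration_s:.2f} tokens/sec")
print(f"Output token thruput: {output_tokens / duration_s:.2f} tokens/sec")
```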