Added jetstream_total_tokens_in_current_batch metric #128

Open · wants to merge 4 commits into base: main
9 changes: 9 additions & 0 deletions jetstream/core/metrics/prometheus.py
@@ -255,3 +255,12 @@ def get_request_output_length(self):

  def get_request_success_count_metric(self):
    return self._request_success_count.labels(id=self._id)

  _total_tokens_in_current_batch = Gauge(
      name="jetstream_total_tokens_in_current_batch",
      documentation="Total number of tokens in the decode batch",
      labelnames=["id", "idx"],
  )

  def get_total_tokens_in_current_batch_metric(self, idx: int):
    return self._total_tokens_in_current_batch.labels(id=self._id, idx=idx)

Contributor commented on the Gauge definition:
We have padding in each batch, so I feel this metric is not critical or useful. We care more about total response tokens per request. @JoeZijunZhou, please share your thoughts on this metric.

Contributor replied:
Following up in the internal doc. The GKE team added this metric in the design doc.
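For readers unfamiliar with prometheus_client, here is a minimal, self-contained sketch of how a labeled Gauge of this shape behaves. The standalone setup and the id value "server-0" are hypothetical illustrations, not part of the PR:

from prometheus_client import Gauge, generate_latest

# Same shape as the gauge added in this PR, defined standalone here.
total_tokens = Gauge(
    name="jetstream_total_tokens_in_current_batch",
    documentation="Total number of tokens in the decode batch",
    labelnames=["id", "idx"],
)

# Each (id, idx) pair is its own time series; set() overwrites the previous
# value, so the gauge always reflects the most recent decode batch.
total_tokens.labels(id="server-0", idx=0).set(128)
total_tokens.labels(id="server-0", idx=1).set(96)

# Dump the exposition-format text a Prometheus scrape would see.
print(generate_latest().decode())

Because it is a gauge rather than a counter, the exported value is a point-in-time snapshot per detokenize thread (idx), not a cumulative token count.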
8 changes: 8 additions & 0 deletions jetstream/core/orchestrator.py
@@ -815,6 +815,7 @@ def _detokenize_thread(self, idx: int):
        # is a result tokens, and we can't annotate the tuple.
        result_tokens = result_tokens.convert_to_numpy()

        total_tokens_in_batch = 0
        for slot, request in my_live_requests.items():
          if request is not None:
            results, complete = token_utils.process_result_tokens(
@@ -826,6 +827,9 @@
                complete=request.complete,
            )
            request.complete = complete
            total_tokens_in_batch += result_tokens.get_result_at_slot(
                slot
            ).lengths
            # Return some output samples.
            request.enqueue_samples(results)
            if request.complete.all():
@@ -873,6 +877,10 @@
            generate_timestep_added,
            (time.time() - start_detokenize_time) * 10**3,
        )
        if self._metrics_collector:
          self._metrics_collector.get_total_tokens_in_current_batch_metric(
              idx=idx
          ).set(total_tokens_in_batch)
      else:
        # We want to update a slot with the new channel.
        slot, active_request = data
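To make the accumulation above concrete, here is a minimal sketch with a fake stand-in for the engine's result batch. FakeResultTokens, SlotResult, and the per-slot counts are invented for illustration, and the sketch assumes .lengths is a scalar per slot, as the PR's += usage suggests:

from dataclasses import dataclass

@dataclass
class SlotResult:
  lengths: int  # tokens produced for this slot in the current decode step

class FakeResultTokens:
  """Stand-in for the engine's result batch, keyed by decode slot."""

  def __init__(self, per_slot: dict[int, int]):
    self._per_slot = per_slot

  def get_result_at_slot(self, slot: int) -> SlotResult:
    return SlotResult(lengths=self._per_slot[slot])

result_tokens = FakeResultTokens({0: 3, 1: 1, 2: 2})
live_slots = {0: object(), 1: object(), 2: object()}  # slot -> live request

# Mirrors the loop in _detokenize_thread: sum per-slot lengths, after which
# the gauge for this detokenize thread (idx) is .set() to the total.
total_tokens_in_batch = 0
for slot, request in live_slots.items():
  if request is not None:
    total_tokens_in_batch += result_tokens.get_result_at_slot(slot).lengths

assert total_tokens_in_batch == 6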