Improved stability of litellm models for reasoning models. #538

Open · wants to merge 2 commits into main
21 changes: 17 additions & 4 deletions src/lighteval/models/litellm_model.py
@@ -22,6 +22,7 @@

 import logging
 import os
+import re
 import time
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
@@ -93,20 +94,25 @@ def __init__(self, config, env_config) -> None:
         litellm.drop_params = True
         litellm.set_verbose = False

+    def is_reasoning_model(self):
+        return "o1" in self.model or "o3" in self.model or "R1" in self.model
+
     def _prepare_stop_sequence(self, stop_sequence):
         """Prepare and validate stop sequence."""
         if self.provider == "anthropic":
             # Filter out whitespace-only stop sequences
             if stop_sequence:
                 stop_sequence = [s for s in stop_sequence if s and s.strip()]
+                if not stop_sequence:  # If empty after filtering
+                    stop_sequence = ["\n"]
         return stop_sequence

     def _prepare_max_new_tokens(self, max_new_tokens):
         """Calculate completion tokens based on max_new_tokens."""
         if not max_new_tokens or max_new_tokens <= 0:
             return None

-        if "o1" in self.model:
+        if self.is_reasoning_model():
             # We need to allow more tokens to include reasoning tokens
             max_new_tokens = min(max_new_tokens * 10, 32000)
         return max_new_tokens
@@ -132,8 +138,8 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence):
                     "n": num_samples,
                     "caching": True,
                 }
-                if "o1" in self.model:
-                    logger.warning("O1 models do not support temperature, top_p, stop sequence. Disabling.")
+                if self.is_reasoning_model():
+                    logger.warning("Reasoning models do not support temperature, top_p, stop sequence. Disabling.")
                 else:
                     kwargs["temperature"] = self.TEMPERATURE
                     kwargs["top_p"] = self.TOP_P
@@ -142,10 +148,17 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence):
                 response = litellm.completion(**kwargs)

                 # If response is empty, retry without caching (maybe the error is recoverable and solved with a retry)
-                if response.choices[0].message.content is None:
+                content = response.choices[0].message.content
+                if not content:
                     kwargs["caching"] = False
                     logger.info("Response is empty, retrying without caching")
                     response = litellm.completion(**kwargs)
+
+                if content is not None and "<think>" in content:
+                    logger.debug(f"Removing <think> tags from response: {content}")
Member (review comment):
Why are we removing think tags from the answer here? I think it should be done in the metric function, no?

Contributor Author (reply):
If we are evaluating a reasoning model, the grader will look at the thinking tokens unless we remove them. We would otherwise need to remove them in every metric function.
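
For illustration, a minimal standalone sketch of the stripping discussed in this thread (the sample completion string is made up; the pattern and flags match the lines added below):

import re

# Hypothetical raw completion from a reasoning model (illustrative only).
raw = "<think>Compute 12 * 12 = 144.</think>The answer is 144."

# Drop everything between <think> and </think>, across newlines (re.DOTALL),
# then trim surrounding whitespace so only the final answer reaches the grader.
visible = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()

print(visible)  # -> The answer is 144.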

+                    response.choices[0].message.content = re.sub(
+                        r"<think>.*?</think>", "", content, flags=re.DOTALL
+                    ).strip()
                 return response
             except litellm.BadRequestError as e:
                 if "message" in e.__dict__:
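
As a quick check of the new token budgeting for reasoning models (a standalone sketch, not part of the PR; model names and request sizes are made up):

# Minimal sketch of the budgeting rule introduced above (illustrative names only).
def budget(max_new_tokens, model):
    if not max_new_tokens or max_new_tokens <= 0:
        return None
    if "o1" in model or "o3" in model or "R1" in model:
        # Reasoning models also spend output tokens on hidden reasoning.
        return min(max_new_tokens * 10, 32000)
    return max_new_tokens

print(budget(2048, "o3-mini"))      # 20480
print(budget(2048, "gpt-4-turbo"))  # 2048
print(budget(8192, "deepseek-R1"))  # 32000 (capped)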