Skip to content

[CODEX-#369]: Refactor is_bad_response to take TLMOptions rather than a TLM model #52

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

- Update response validation methods to use the TLM endpoint in the Codex backend rather than a TLM model.
- Update response validation methods to support accepting and propagating TLMOptions from the cleanlab_tlm library.

## [1.0.1] - 2025-02-26

- Updates to logic for `is_unhelpful_response` util method.
Expand Down
152 changes: 120 additions & 32 deletions src/cleanlab_codex/response_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

from __future__ import annotations

from typing import (
TYPE_CHECKING as _TYPE_CHECKING,
)
from typing import (
Any,
Callable,
Expand All @@ -13,13 +16,33 @@
cast,
)

from codex import AuthenticationError, BadRequestError
from pydantic import BaseModel, ConfigDict, Field

from cleanlab_codex.internal.sdk_client import (
MissingAuthKeyError,
client_from_access_key,
client_from_api_key,
is_access_key,
)
from cleanlab_codex.internal.utils import generate_pydantic_model_docstring
from cleanlab_codex.types.tlm import TLM
from cleanlab_codex.utils.errors import MissingDependencyError
from cleanlab_codex.utils.prompt import default_format_prompt

if _TYPE_CHECKING:
from cleanlab_tlm.tlm import TLMOptions
from codex import Codex as _Codex

from cleanlab_codex.types.tlm import TlmScoreResponse


class MissingAuthError(ValueError):
    """Signals that neither a Codex API key nor an access key was available when
    the TLM-backed untrustworthiness or unhelpfulness checks were invoked.
    """

    def __str__(self) -> str:
        # Fixed message; callers should rely on the exception type, not the text.
        return (
            "A valid Codex API key or access key must be provided when using "
            "the TLM for untrustworthy or unhelpfulness checks."
        )


_DEFAULT_FALLBACK_ANSWER: str = (
"Based on the available information, I cannot provide a complete answer to this question."
)
Expand Down Expand Up @@ -73,9 +96,14 @@ class BadResponseDetectionConfig(BaseModel):
)

# Shared config (for untrustworthiness and unhelpfulness checks)
tlm: Optional[TLM] = Field(
tlm_options: Optional[TLMOptions] = Field(
default=None,
description="Configuration options for the TLM model used for evaluation.",
)

codex_key: Optional[str] = Field(
default=None,
description="TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks).",
description="Codex Access Key or API Key to use when querying TLM for untrustworthiness and unhelpfulness checks.",
)


Expand All @@ -95,6 +123,8 @@ def is_bad_response(
context: Optional[str] = None,
query: Optional[str] = None,
config: Union[BadResponseDetectionConfig, Dict[str, Any]] = _DEFAULT_CONFIG,
run_untrustworthy_check: Optional[bool] = True,
run_unhelpful_check: Optional[bool] = True,
) -> bool:
"""Run a series of checks to determine if a response is bad.

Expand All @@ -113,6 +143,8 @@ def is_bad_response(
context (str, optional): Optional context/documents used for answering. Required for untrustworthy check.
query (str, optional): Optional user question. Required for untrustworthy and unhelpful checks.
config (BadResponseDetectionConfig, optional): Optional, configuration parameters for validation checks. See [BadResponseDetectionConfig](#class-badresponsedetectionconfig) for details. If not provided, default values will be used.
run_untrustworthy_check (bool, optional): Optional flag to specify whether to run untrustworthy check. This check is run by default.
run_unhelpful_check (bool, optional): Optional flag to specify whether to run unhelpfulness check. This check is run by default.

Returns:
bool: `True` if any validation check fails, `False` if all pass.
Expand All @@ -130,28 +162,30 @@ def is_bad_response(
)
)

can_run_untrustworthy_check = query is not None and context is not None and config.tlm is not None
can_run_untrustworthy_check = query is not None and context is not None and run_untrustworthy_check
if can_run_untrustworthy_check:
# The if condition guarantees these are not None
validation_checks.append(
lambda: is_untrustworthy_response(
response=response,
context=cast(str, context),
query=cast(str, query),
tlm=cast(TLM, config.tlm),
tlm_options=config.tlm_options,
trustworthiness_threshold=config.trustworthiness_threshold,
format_prompt=config.format_prompt,
codex_key=config.codex_key,
)
)

can_run_unhelpful_check = query is not None and config.tlm is not None
can_run_unhelpful_check = query is not None and run_unhelpful_check
if can_run_unhelpful_check:
validation_checks.append(
lambda: is_unhelpful_response(
response=response,
query=cast(str, query),
tlm=cast(TLM, config.tlm),
tlm_options=config.tlm_options,
confidence_score_threshold=config.unhelpfulness_confidence_threshold,
codex_key=config.codex_key,
)
)

Expand Down Expand Up @@ -189,13 +223,76 @@ def is_fallback_response(
return bool(partial_ratio >= threshold)


def _create_codex_client(codex_key_arg: str | None) -> _Codex:
    """
    Helper method to create a Codex client for proxying TLM requests.

    Args:
        codex_key_arg (str | None): A Codex API or Access key to use when querying TLM.
            If ``None``, credentials are resolved from the environment instead.

    Returns:
        _Codex: A Codex client to use to proxy TLM requests.

    Raises:
        MissingAuthError: If no valid API or access key can be resolved.
    """
    if codex_key_arg is None:
        # No explicit key given: fall back to environment-based resolution,
        # preferring an access key over an API key.
        try:
            return client_from_access_key()
        except MissingAuthKeyError:
            pass
        try:
            return client_from_api_key()
        except (MissingAuthKeyError, BadRequestError):
            pass
        # Neither environment credential worked; surface our own auth error
        # (suppress the chained context — it adds no information here).
        raise MissingAuthError from None

    try:
        # An explicit key was provided; dispatch on its format.
        if is_access_key(codex_key_arg):
            return client_from_access_key(codex_key_arg)

        return client_from_api_key(codex_key_arg)
    except (MissingAuthKeyError, BadRequestError):
        raise MissingAuthError from None


def _try_tlm_score(
    client: _Codex,
    prompt: str,
    response: str,
    options: Optional[TLMOptions] = None,
    constrain_outputs: Optional[list[str]] = None,
) -> TlmScoreResponse:
    """
    Helper method to try reaching the TLM scoring Codex endpoint, catching any
    authentication issues and raising our own authentication error.

    Args:
        client (_Codex): The (authenticated) Codex client to use.
        prompt (str): The prompt to pass to tlm.score.
        response (str): The response to pass to tlm.score.
        options (TLMOptions, optional): The TLMOptions to pass to the TLM.
        constrain_outputs (list[str], optional): The constrain_outputs keyword argument to pass to tlm.score.

    Notes:
        We need the try-except here since when users authenticate via an access key, there is no eager check to see if they
        are correctly authenticated (unlike when authenticating via an API key, which performs an immediate check to see
        if the authentication is valid). This means that we could get AuthenticationErrors from the Codex client, that we
        want to catch, and instead raise our own MissingAuthError.

    Returns:
        TlmScoreResponse: The score response from TLM.

    Raises:
        MissingAuthError: If the user is not correctly authenticated.
    """
    try:
        return client.tlm.score(prompt=prompt, response=response, options=options, constrain_outputs=constrain_outputs)
    except AuthenticationError:
        raise MissingAuthError from None


def is_untrustworthy_response(
response: str,
context: str,
query: str,
tlm: TLM,
tlm_options: Optional[TLMOptions] = None,
trustworthiness_threshold: float = _DEFAULT_TRUSTWORTHINESS_THRESHOLD,
format_prompt: Callable[[str, str], str] = default_format_prompt,
codex_key: Optional[str] = None,
) -> bool:
"""Check if a response is untrustworthy.

Expand All @@ -207,36 +304,30 @@ def is_untrustworthy_response(
response (str): The response to check from the assistant.
context (str): The context information available for answering the query.
query (str): The user's question or request.
tlm (TLM): The TLM model to use for evaluation.
tlm_options (TLMOptions): The options to pass to the TLM model used for evaluation.
trustworthiness_threshold (float): Score threshold (0.0-1.0) under which a response is considered untrustworthy.
Lower values allow less trustworthy responses. Default 0.5 means responses with scores less than 0.5 are considered untrustworthy.
format_prompt (Callable[[str, str], str]): Function that takes (query, context) and returns a formatted prompt string.
Users should provide the prompt formatting function for their RAG application here so that the response can
be evaluated using the same prompt that was used to generate the response.
codex_key (str): A Codex API or Access key to use when querying TLM.

Returns:
bool: `True` if the response is deemed untrustworthy by TLM, `False` otherwise.
"""
try:
from cleanlab_tlm import TLM # noqa: F401
except ImportError as e:
raise MissingDependencyError(
import_name=e.name or "cleanlab_tlm",
package_name="cleanlab-tlm",
package_url="https://github.com/cleanlab/cleanlab-tlm",
) from e

prompt = format_prompt(query, context)
result = tlm.get_trustworthiness_score(prompt, response)
client = _create_codex_client(codex_key)
result = _try_tlm_score(client=client, prompt=prompt, response=response, options=tlm_options)
score: float = result["trustworthiness_score"]
return score < trustworthiness_threshold


def is_unhelpful_response(
response: str,
query: str,
tlm: TLM,
tlm_options: Optional[TLMOptions] = None,
confidence_score_threshold: float = _DEFAULT_UNHELPFULNESS_CONFIDENCE_THRESHOLD,
codex_key: Optional[str] = None,
) -> bool:
"""Check if a response is unhelpful by asking [TLM](/tlm) to evaluate it.

Expand All @@ -248,23 +339,15 @@ def is_unhelpful_response(
Args:
response (str): The response to check.
query (str): User query that will be used to evaluate if the response is helpful.
tlm (TLM): The TLM model to use for evaluation.
tlm_options (TLMOptions): The options to pass to the TLM model used for evaluation.
confidence_score_threshold (float): Confidence threshold (0.0-1.0) above which a response is considered unhelpful.
E.g. if confidence_score_threshold is 0.5, then responses with scores higher than 0.5 are considered unhelpful.
codex_key (str): A Codex API or Access key to use when querying TLM.

Returns:
bool: `True` if TLM determines the response is unhelpful with sufficient confidence,
`False` otherwise.
"""
try:
from cleanlab_tlm import TLM # noqa: F401
except ImportError as e:
raise MissingDependencyError(
import_name=e.name or "cleanlab_tlm",
package_name="cleanlab-tlm",
package_url="https://github.com/cleanlab/cleanlab-tlm",
) from e

# IMPORTANT: The current implementation couples three things that must stay in sync:
# 1. The question phrasing ("is unhelpful?")
# 2. The expected_unhelpful_response ("Yes")
Expand Down Expand Up @@ -300,8 +383,13 @@ def is_unhelpful_response(
f"{question}"
)

output = tlm.get_trustworthiness_score(
prompt, response=expected_unhelpful_response, constrain_outputs=["Yes", "No"]
client = _create_codex_client(codex_key)
output = _try_tlm_score(
client=client,
prompt=prompt,
response=expected_unhelpful_response,
options=tlm_options,
constrain_outputs=["Yes", "No"],
)

# Current implementation assumes question is phrased to expect "Yes" for unhelpful responses
Expand Down
35 changes: 17 additions & 18 deletions src/cleanlab_codex/types/tlm.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from typing import Any, Dict, Protocol, Sequence, Union, runtime_checkable


@runtime_checkable
class TLM(Protocol):
def get_trustworthiness_score(
self,
prompt: Union[str, Sequence[str]],
response: Union[str, Sequence[str]],
**kwargs: Any,
) -> Dict[str, Any]: ...

def prompt(
self,
prompt: Union[str, Sequence[str]],
/,
**kwargs: Any,
) -> Dict[str, Any]: ...
"""Types for Codex TLM endpoint."""

from codex.types.tlm_score_response import TlmScoreResponse as _TlmScoreResponse

from cleanlab_codex.internal.utils import generate_class_docstring


class TlmScoreResponse(_TlmScoreResponse): ...


TlmScoreResponse.__doc__ = f"""
Type representing a TLM score response in a Codex project. This is the complete data structure returned from the Codex API, including system-generated fields like ID and timestamps.

{generate_class_docstring(_TlmScoreResponse, name=TlmScoreResponse.__name__)}
"""

__all__ = ["TlmScoreResponse"]
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from tests.fixtures.client import mock_client_from_access_key, mock_client_from_api_key
from tests.fixtures.client import mock_client_from_access_key, mock_client_from_access_key_tlm, mock_client_from_api_key

__all__ = ["mock_client_from_access_key", "mock_client_from_api_key"]
__all__ = ["mock_client_from_access_key", "mock_client_from_api_key", "mock_client_from_access_key_tlm"]
8 changes: 8 additions & 0 deletions tests/fixtures/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,11 @@ def mock_client_from_api_key() -> Generator[MagicMock, None, None]:
mock_client = MagicMock()
mock_init.return_value = mock_client
yield mock_client


@pytest.fixture
def mock_client_from_access_key_tlm() -> Generator[MagicMock, None, None]:
    """Patch ``client_from_access_key`` as imported by ``response_validation``
    and yield the mocked Codex client that it will return.
    """
    with patch("cleanlab_codex.response_validation.client_from_access_key") as mock_init:
        mock_client = MagicMock()
        mock_init.return_value = mock_client
        yield mock_client
Loading
Loading