Skip to content

Commit

Permalink
feat: Add GroqChatTarget (Azure#704)
Browse files Browse the repository at this point in the history
  • Loading branch information
jsdlm committed Feb 12, 2025
1 parent 6d07f5b commit d1355cf
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pyrit/prompt_target/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget
from pyrit.prompt_target.crucible_target import CrucibleTarget
from pyrit.prompt_target.gandalf_target import GandalfLevel, GandalfTarget
from pyrit.prompt_target.groq_chat_target import GroqChatTarget
from pyrit.prompt_target.http_target.http_target import HTTPTarget
from pyrit.prompt_target.http_target.http_target_callback_functions import (
get_http_target_json_response_callback_function,
Expand All @@ -34,6 +35,7 @@
"CrucibleTarget",
"GandalfLevel",
"GandalfTarget",
"GroqChatTarget",
"get_http_target_json_response_callback_function",
"get_http_target_regex_matching_callback_function",
"HTTPTarget",
Expand Down
56 changes: 56 additions & 0 deletions pyrit/prompt_target/groq_chat_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from pyrit.prompt_target.openai.openai_chat_target import OpenAIChatTarget
from pyrit.models import ChatMessageListDictContent
from pyrit.exceptions import PyritException, EmptyResponseException
from openai.types.chat import ChatCompletion
from pyrit.exceptions import pyrit_target_retry

logger = logging.getLogger(__name__)

class GroqChatTarget(OpenAIChatTarget):

@pyrit_target_retry
async def _complete_chat_async(self, messages: list[ChatMessageListDictContent], is_json_response: bool) -> str:
"""
Completes asynchronous chat request.
Sends a chat message to the OpenAI chat model and retrieves the generated response.
Args:
messages (list[ChatMessageListDictContent]): The chat message objects containing the role and content.
is_json_response (bool): Boolean indicating if the response should be in JSON format.
Returns:
str: The generated response message.
"""
response: ChatCompletion = await self._async_client.chat.completions.create(
model=self._deployment_name,
max_completion_tokens=self._max_completion_tokens,
max_tokens=self._max_tokens,
temperature=self._temperature,
top_p=self._top_p,
frequency_penalty=self._frequency_penalty,
presence_penalty=self._presence_penalty,
n=1,
stream=False,
seed=self._seed,
messages=[{"role": msg.role, "content": msg.content[0].get("text")} for msg in messages], # type: ignore
response_format={"type": "json_object"} if is_json_response else None,
)
finish_reason = response.choices[0].finish_reason
extracted_response: str = ""
# finish_reason="stop" means API returned complete message and
# "length" means API returned incomplete message due to max_tokens limit.
if finish_reason in ["stop", "length"]:
extracted_response = self._parse_chat_completion(response)
# Handle empty response
if not extracted_response:
logger.log(logging.ERROR, "The chat returned an empty response.")
raise EmptyResponseException(message="The chat returned an empty response.")
else:
raise PyritException(message=f"Unknown finish_reason {finish_reason}")

return extracted_response

0 comments on commit d1355cf

Please sign in to comment.