Add multi-language system prompts and BedrockChatAdapter implementation #576
@@ -1,37 +1,57 @@
import os
import re
from enum import Enum
import genai_core.clients
from aws_lambda_powertools import Logger
from enum import Enum
from typing import Any, Dict, List
from genai_core.registry import registry
from genai_core.types import ChatbotMode
from genai_core.langchain import WorkspaceRetriever, DynamoDBChatMessageHistory
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains.conversation.base import ConversationChain
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.prompts.prompt import PromptTemplate
from langchain.chains.conversational_retrieval.prompts import (
    QA_PROMPT,
    CONDENSE_QUESTION_PROMPT,
)
from typing import Dict, List, Any

from genai_core.langchain import WorkspaceRetriever, DynamoDBChatMessageHistory
from genai_core.types import ChatbotMode

from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.outputs import LLMResult, ChatGeneration
from langchain_core.messages import BaseMessage
from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.human import HumanMessage
from langchain_aws import ChatBedrockConverse
from adapters.shared.prompts.system_prompts import (
    prompts,
    locale,
)  # Import prompts and language

logger = Logger()

# Setting programmatic log level
# logger.setLevel("DEBUG")


class Mode(Enum):
    CHAIN = "chain"


def get_guardrails() -> dict:
Review comment: This is only applicable for bedrock. Why did you add it here?
if "BEDROCK_GUARDRAILS_ID" in os.environ: | ||
logger.debug("Guardrails ID found in environment variables.") | ||
return { | ||
"guardrailIdentifier": os.environ["BEDROCK_GUARDRAILS_ID"], | ||
"guardrailVersion": os.environ.get("BEDROCK_GUARDRAILS_VERSION", "DRAFT"), | ||
} | ||
logger.info("No guardrails ID found.") | ||
return {} | ||
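Aside, not part of the diff: a minimal sketch of how get_guardrails behaves; the guardrail ID value below is made up.

# Illustration only: hypothetical guardrail ID.
os.environ["BEDROCK_GUARDRAILS_ID"] = "example-guardrail-id"
print(get_guardrails())
# {'guardrailIdentifier': 'example-guardrail-id', 'guardrailVersion': 'DRAFT'}
# With BEDROCK_GUARDRAILS_ID unset, the function logs "No guardrails ID found."
# and returns {}.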


class LLMStartHandler(BaseCallbackHandler):
    prompts = []
    usage = None

@@ -60,12 +80,12 @@ def on_llm_end(
                    "total_tokens": 0,
                }
            self.usage = {
                "input_tokens": self.usage.get("input_tokens")
                + generation.message.usage_metadata.get("input_tokens"),
                "output_tokens": self.usage.get("output_tokens")
                + generation.message.usage_metadata.get("output_tokens"),
                "total_tokens": self.usage.get("total_tokens")
                + generation.message.usage_metadata.get("total_tokens"),
                "input_tokens": self.usage.get("input_tokens", 0)
                + generation.message.usage_metadata.get("input_tokens", 0),
                "output_tokens": self.usage.get("output_tokens", 0)
                + generation.message.usage_metadata.get("output_tokens", 0),
                "total_tokens": self.usage.get("total_tokens", 0)
                + generation.message.usage_metadata.get("total_tokens", 0),
            }
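Aside, not part of the diff: the new lines replace bare .get(key) lookups with .get(key, 0). A standalone sketch of why the explicit default matters when a usage_metadata payload omits a field:

usage_metadata = {"input_tokens": 12}  # hypothetical payload with missing fields
running = {"input_tokens": 3, "output_tokens": 0, "total_tokens": 3}

# Without a default, the missing key yields None and the addition raises TypeError:
#   running.get("output_tokens") + usage_metadata.get("output_tokens")  # int + None

# With an explicit default, missing fields are simply counted as zero:
output_tokens = running.get("output_tokens", 0) + usage_metadata.get("output_tokens", 0)
print(output_tokens)  # 0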

@@ -199,7 +219,7 @@ def run_with_chain_v2(self, user_prompt, workspace_id=None):
                input={"input": user_prompt}, config=config
            )
            if "answer" in response:
                answer = response.get("answer")  # Rag flow
                answer = response.get("answer")  # RAG flow
            else:
                answer = response.content
        except Exception as e:

@@ -239,11 +259,11 @@ def run_with_chain_v2(self, user_prompt, workspace_id=None):
        # Used by Cloudwatch filters to generate a metric of token usage.
        logger.info(
            "Usage Metric",
            # Each unique value of model id will create a
            # new cloudwatch metric (each one has a cost)
            model=self.model_id,
            metric_type="token_usage",
            value=self.callback_handler.usage.get("total_tokens"),
            extra={
Review comment: Changing the JSON format here would break the metric in the dashboard. Please undo.
"model": self.model_id, | ||
"metric_type": "token_usage", | ||
"value": self.callback_handler.usage.get("total_tokens"), | ||
}, | ||
) | ||
|
||
return { | ||
|
@@ -342,3 +362,245 @@ def run(self, prompt, workspace_id=None, *args, **kwargs):
            return self.run_with_chain(prompt, workspace_id)

        raise ValueError(f"unknown mode {self._mode}")


class BedrockChatAdapter(ModelAdapter):
Review comment: I think this is a copy of https://github.com/aws-samples/aws-genai-llm-chatbot/blob/main/lib/model-interfaces/langchain/functions/request-handler/adapters/bedrock/base.py. I would revert this change in this file.
    def __init__(self, model_id, *args, **kwargs):
        self.model_id = model_id
        logger.info(f"Initializing BedrockChatAdapter with model_id: {model_id}")
        super().__init__(*args, **kwargs)

    def get_qa_prompt(self):
        # Fetch the QA prompt based on the current language
        qa_system_prompt = prompts[locale]["qa_prompt"]
        # Append the context placeholder if needed
        qa_system_prompt_with_context = qa_system_prompt + "\n\n{context}"
        logger.info(
            f"Generating QA prompt template with: {qa_system_prompt_with_context}"
        )

        # Create the ChatPromptTemplate
        chat_prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", qa_system_prompt_with_context),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )

        # Trace the ChatPromptTemplate by logging its content
        logger.debug(f"ChatPromptTemplate messages: {chat_prompt_template.messages}")

        return chat_prompt_template
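Aside, not part of the diff: a small usage sketch of the QA template built above. It assumes langchain_core's ChatPromptTemplate.format_messages API and uses a made-up prompts/locale mapping in place of the imported one.

# Hypothetical stand-ins for the imported prompts/locale:
prompts = {"en": {"qa_prompt": "Use the following context to answer the question."}}
locale = "en"

qa_template = ChatPromptTemplate.from_messages(
    [
        ("system", prompts[locale]["qa_prompt"] + "\n\n{context}"),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
messages = qa_template.format_messages(
    context="(retrieved passages)",
    chat_history=[HumanMessage(content="Hi"), AIMessage(content="Hello!")],
    input="What do the passages say?",
)
# messages[0] is the system message with the context appended,
# followed by the chat history and the human turn.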

    def get_prompt(self):
        # Fetch the conversation prompt based on the current language
        conversation_prompt = prompts[locale]["conversation_prompt"]
        logger.info("Generating general conversation prompt template.")
        chat_prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", conversation_prompt),
                MessagesPlaceholder(variable_name="chat_history"),
                ("human", "{input}"),
            ]
        )
        # Trace the ChatPromptTemplate by logging its content
        logger.debug(f"ChatPromptTemplate messages: {chat_prompt_template.messages}")
        return chat_prompt_template

    def get_condense_question_prompt(self):
        # Fetch the prompt based on the current language
        condense_question_prompt = prompts[locale]["condense_question_prompt"]
        logger.info("Generating condense question prompt template.")
        chat_prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", condense_question_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        # Trace the ChatPromptTemplate by logging its content
        logger.debug(f"ChatPromptTemplate messages: {chat_prompt_template.messages}")
        return chat_prompt_template

    def get_llm(self, model_kwargs={}, extra={}):
        bedrock = genai_core.clients.get_bedrock_client()
        params = {}

        # Collect temperature, topP, and maxTokens if available
        temperature = model_kwargs.get("temperature")
        top_p = model_kwargs.get("topP")
        max_tokens = model_kwargs.get("maxTokens")

        if temperature:
            params["temperature"] = temperature
        if top_p:
            params["top_p"] = top_p
        if max_tokens:
            params["max_tokens"] = max_tokens

        # Fetch guardrails if any
        guardrails = get_guardrails()
        if len(guardrails.keys()) > 0:
            params["guardrails"] = guardrails

        # Log all parameters in a single log entry, including full guardrails
        logger.info(
            f"Creating LLM chain for model {self.model_id}",
            extra={
                "model_kwargs": model_kwargs,
                "temperature": temperature,
                "top_p": top_p,
                "max_tokens": max_tokens,
                "guardrails": guardrails,
            },
        )

        # Return ChatBedrockConverse instance with the collected params
        return ChatBedrockConverse(
            client=bedrock,
            model=self.model_id,
            disable_streaming=not model_kwargs.get("streaming", True)
            or self.disable_streaming,
            callbacks=[self.callback_handler],
            **params,
            **extra,
        )
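Aside, not part of the diff: a sketch of how the UI-style model_kwargs keys map onto the ChatBedrockConverse parameters collected above; the adapter instance and values are hypothetical. Note that the truthiness checks mean an explicit 0 (for example temperature=0) is skipped rather than forwarded.

# Hypothetical call; `adapter` would be a BedrockChatAdapter constructed elsewhere.
llm = adapter.get_llm(
    model_kwargs={
        "temperature": 0.7,  # forwarded as params["temperature"]
        "topP": 0.9,         # forwarded as params["top_p"]
        "maxTokens": 512,    # forwarded as params["max_tokens"]
        "streaming": True,   # streaming stays on unless disable_streaming is set
    }
)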


class BedrockChatNoStreamingAdapter(BedrockChatAdapter):
    """Some models do not support system streaming using the converse API"""

    def __init__(self, *args, **kwargs):
        logger.info(
            "Initializing BedrockChatNoStreamingAdapter with disabled streaming."
        )
        super().__init__(disable_streaming=True, *args, **kwargs)


class BedrockChatNoSystemPromptAdapter(BedrockChatAdapter):
    """Some models do not support system and message history in the conversation API"""

    def get_prompt(self):
        # Fetch the conversation prompt and translated
        # words based on the current language
        conversation_prompt = prompts[locale]["conversation_prompt"]
        question_word = prompts[locale]["question_word"]
        assistant_word = prompts[locale]["assistant_word"]
        logger.info("Generating no-system-prompt template for conversation.")

        # Combine conversation prompt, chat history, and input into the template
        template = f"""{conversation_prompt}

{question_word}: {{input}}

{assistant_word}:"""

        # Create the PromptTemplateWithHistory instance
        prompt_template = PromptTemplateWithHistory(
            input_variables=["input", "chat_history"], template=template
        )

        # Log the content of PromptTemplateWithHistory before returning
        logger.debug(f"PromptTemplateWithHistory template: {prompt_template.template}")

        return prompt_template
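Aside, not part of the diff: the doubled braces in the f-string above keep {{input}} as a literal {input} placeholder for the later PromptTemplate substitution, while single-brace names are filled immediately. A tiny sketch with hypothetical prompt strings:

conversation_prompt = "You are a helpful assistant."  # hypothetical
question_word = "Question"  # hypothetical
assistant_word = "Assistant"  # hypothetical

template = f"""{conversation_prompt}

{question_word}: {{input}}

{assistant_word}:"""

print(template)
# You are a helpful assistant.
#
# Question: {input}
#
# Assistant: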

    def get_condense_question_prompt(self):
        # Fetch the prompt and translated words based on the current language
        condense_question_prompt = prompts[locale]["condense_question_prompt"]
        logger.info(f"condense_question_prompt: {condense_question_prompt}")

        follow_up_input_word = prompts[locale]["follow_up_input_word"]
        logger.info(f"follow_up_input_word: {follow_up_input_word}")

        standalone_question_word = prompts[locale]["standalone_question_word"]
        logger.info(f"standalone_question_word: {standalone_question_word}")

        chat_history_word = prompts[locale]["chat_history_word"]
        logger.info(f"chat_history_word: {chat_history_word}")

        logger.info("Generating no-system-prompt template for condensing question.")

        # Combine the prompt with placeholders
        template = f"""{condense_question_prompt}
{chat_history_word}: {{chat_history}}
{follow_up_input_word}: {{input}}
{standalone_question_word}:"""
        # Log the content of template
        logger.info(f"get_condense_question_prompt: Template content: {template}")
        # Create the PromptTemplateWithHistory instance
        prompt_template = PromptTemplateWithHistory(
            input_variables=["input", "chat_history"], template=template
        )

        # Log the content of PromptTemplateWithHistory before returning
        logger.debug(f"PromptTemplateWithHistory template: {prompt_template.template}")

        return prompt_template

    def get_qa_prompt(self):
        # Fetch the QA prompt and translated words based on the current language
        qa_system_prompt = prompts[locale]["qa_prompt"]
        question_word = prompts[locale]["question_word"]
        helpful_answer_word = prompts[locale]["helpful_answer_word"]
        logger.info("Generating no-system-prompt QA template.")

        # Combine the prompt with placeholders
        template = f"""{qa_system_prompt}

{{context}}

{question_word}: {{input}}
{helpful_answer_word}:"""

        # Create the PromptTemplateWithHistory instance
        prompt_template = PromptTemplateWithHistory(
            input_variables=["input", "context"], template=template
        )

        # Log the content of PromptTemplateWithHistory before returning
        logger.debug(f"PromptTemplateWithHistory template: {prompt_template.template}")

        return prompt_template


class BedrockChatNoStreamingNoSystemPromptAdapter(BedrockChatNoSystemPromptAdapter):
    """Some models do not support system streaming using the converse API"""

    def __init__(self, *args, **kwargs):
        super().__init__(disable_streaming=True, *args, **kwargs)


class PromptTemplateWithHistory(PromptTemplate):
    def format(self, **kwargs: Any) -> str:
        chat_history = kwargs.get("chat_history", "")
        if isinstance(chat_history, List):
            # RunnableWithMessageHistory is provided a list of BaseMessage as a history
            # Since this model does not support history, we format the common prompt to
            # list the history
            chat_history_str = ""
            for message in chat_history:
                if isinstance(message, BaseMessage):
                    prefix = ""
                    if isinstance(message, AIMessage):
                        prefix = "AI: "
                    elif isinstance(message, HumanMessage):
                        prefix = "Human: "
                    chat_history_str += prefix + message.content + "\n"
            kwargs["chat_history"] = chat_history_str
        return super().format(**kwargs)
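Aside, not part of the diff: a usage sketch of PromptTemplateWithHistory showing how a list of BaseMessage history is flattened into a plain string before the normal PromptTemplate substitution. The template string here is hypothetical.

history_prompt = PromptTemplateWithHistory(
    input_variables=["input", "chat_history"],
    template="{chat_history}Human: {input}\nAI:",  # hypothetical template
)
rendered = history_prompt.format(
    chat_history=[HumanMessage(content="Hi"), AIMessage(content="Hello!")],
    input="What can you do?",
)
print(rendered)
# Human: Hi
# AI: Hello!
# Human: What can you do?
# AI: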


# Register the adapters
registry.register(r"^bedrock.ai21.jamba*", BedrockChatAdapter)
registry.register(r"^bedrock.ai21.j2*", BedrockChatNoStreamingNoSystemPromptAdapter)
registry.register(r"^bedrock\.cohere\.command-(text|light-text).*", BedrockChatNoSystemPromptAdapter)
registry.register(r"^bedrock\.cohere\.command-r.*", BedrockChatAdapter)
registry.register(r"^bedrock.anthropic.claude*", BedrockChatAdapter)
registry.register(r"^bedrock.meta.llama*", BedrockChatAdapter)
registry.register(r"^bedrock.mistral.mistral-large*", BedrockChatAdapter)
registry.register(r"^bedrock.mistral.mistral-small*", BedrockChatAdapter)
registry.register(r"^bedrock.mistral.mistral-7b-*", BedrockChatNoSystemPromptAdapter)
registry.register(r"^bedrock.mistral.mixtral-*", BedrockChatNoSystemPromptAdapter)
registry.register(r"^bedrock.amazon.titan-t*", BedrockChatNoSystemPromptAdapter)
@@ -1,2 +1,2 @@
# flake8: noqa
from .base import *
from adapters.bedrock.base import *
Review comment: I would remove this because there is already a global log level setting here:
https://github.com/aws-samples/aws-genai-llm-chatbot/blob/main/lib/shared/index.ts#L52
Review comment: Is it possible to include this information in the developer's guide documentation?