Add multi-language system prompts and BedrockChatAdapter implementation #576

Merged
merged 13 commits into from
Oct 31, 2024
Changes from 6 commits
3 changes: 3 additions & 0 deletions .gitignore
@@ -458,3 +458,6 @@ lib/user-interface/react-app/src/graphql/subscriptions.ts
# js function
!lib/authentication/lambda/updateUserPoolClient/index.js
!lib/authentication/lambda/updateOidcSecret/index.js
/.project
/.pydevproject
/outputs.json
@@ -1,37 +1,57 @@
import os
import re
from enum import Enum
from typing import Any, Dict, List

import genai_core.clients
from aws_lambda_powertools import Logger
from genai_core.langchain import WorkspaceRetriever, DynamoDBChatMessageHistory
from genai_core.registry import registry
from genai_core.types import ChatbotMode
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.conversation.base import ConversationChain
from langchain.chains.conversational_retrieval.prompts import (
    QA_PROMPT,
    CONDENSE_QUESTION_PROMPT,
)
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.prompts.prompt import PromptTemplate
from langchain_aws import ChatBedrockConverse
from langchain_core.messages import BaseMessage
from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.human import HumanMessage
from langchain_core.outputs import LLMResult, ChatGeneration
from langchain_core.runnables.history import RunnableWithMessageHistory
from adapters.shared.prompts.system_prompts import (
    prompts,
    locale,
)  # Import prompts and language

logger = Logger()
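
The adapters.shared.prompts.system_prompts module is not part of this diff; based on how it is indexed throughout this file (prompts[locale]["qa_prompt"], prompts[locale]["question_word"], and so on), it presumably looks something like this hypothetical sketch:

# Hypothetical shape of adapters.shared.prompts.system_prompts — inferred
# from the lookups in this file, not copied from the actual module.
locale = "en"  # selected language
prompts = {
    "en": {
        "qa_prompt": "Use the following pieces of context to answer the question.",
        "conversation_prompt": "You are a helpful assistant.",
        "condense_question_prompt": "Rephrase the follow-up question as a standalone question.",
        "question_word": "Question",
        "assistant_word": "Assistant",
        "chat_history_word": "Chat history",
        "follow_up_input_word": "Follow-up question",
        "standalone_question_word": "Standalone question",
        "helpful_answer_word": "Helpful answer",
    },
    # ...other locales with the same keys
}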

# Setting programmatic log level
# logger.setLevel("DEBUG")
Collaborator
Suggested change
# Setting programmatic log level
# logger.setLevel("DEBUG")

I would remove this because there is already a global log level setting here
https://github.com/aws-samples/aws-genai-llm-chatbot/blob/main/lib/shared/index.ts#L52

Contributor Author
Is it possible to include this information in the developer's guide documentation?
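
If it helps the developer's guide, here is a minimal sketch of the environment-driven alternative; the assumption (not verified here) is that the CDK setting linked above ultimately flows into this variable:

# Sketch: rely on the environment instead of logger.setLevel(...).
# POWERTOOLS_LOG_LEVEL is read by aws_lambda_powertools when the Logger
# is created; "DEBUG" here is an illustrative value.
import os

os.environ["POWERTOOLS_LOG_LEVEL"] = "DEBUG"  # normally set on the Lambda, not in code

from aws_lambda_powertools import Logger

logger = Logger()
logger.debug("Now visible because the level is DEBUG")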



class Mode(Enum):
    CHAIN = "chain"


def get_guardrails() -> dict:
Collaborator

This is only applicable for Bedrock. Why did you add it here?

if "BEDROCK_GUARDRAILS_ID" in os.environ:
logger.debug("Guardrails ID found in environment variables.")
return {
"guardrailIdentifier": os.environ["BEDROCK_GUARDRAILS_ID"],
"guardrailVersion": os.environ.get("BEDROCK_GUARDRAILS_VERSION", "DRAFT"),
}
logger.info("No guardrails ID found.")
return {}
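
A small usage sketch of get_guardrails; the environment values are hypothetical:

# Hypothetical environment, for illustration only:
os.environ["BEDROCK_GUARDRAILS_ID"] = "gr-1234567890"
os.environ["BEDROCK_GUARDRAILS_VERSION"] = "1"

guardrails = get_guardrails()
# {'guardrailIdentifier': 'gr-1234567890', 'guardrailVersion': '1'}
# get_llm() below passes this dict to ChatBedrockConverse via params["guardrails"].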


class LLMStartHandler(BaseCallbackHandler):
    prompts = []
    usage = None

@@ -60,12 +80,12 @@ def on_llm_end(
"total_tokens": 0,
}
self.usage = {
"input_tokens": self.usage.get("input_tokens")
+ generation.message.usage_metadata.get("input_tokens"),
"output_tokens": self.usage.get("output_tokens")
+ generation.message.usage_metadata.get("output_tokens"),
"total_tokens": self.usage.get("total_tokens")
+ generation.message.usage_metadata.get("total_tokens"),
"input_tokens": self.usage.get("input_tokens", 0)
+ generation.message.usage_metadata.get("input_tokens", 0),
"output_tokens": self.usage.get("output_tokens", 0)
+ generation.message.usage_metadata.get("output_tokens", 0),
"total_tokens": self.usage.get("total_tokens", 0)
+ generation.message.usage_metadata.get("total_tokens", 0),
}
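
A standalone illustration of the accumulation above, showing why the defaulted .get(..., 0) calls matter when a counter is missing on either side:

# Not the real handler — just the dict merge with defaulted lookups,
# which prevents a TypeError when a counter is absent.
usage = {"input_tokens": 12, "output_tokens": 34, "total_tokens": 46}
chunk = {"input_tokens": 5, "total_tokens": 5}  # e.g. no output tokens reported
usage = {k: usage.get(k, 0) + chunk.get(k, 0) for k in usage}
# {'input_tokens': 17, 'output_tokens': 34, 'total_tokens': 51}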


@@ -199,7 +219,7 @@ def run_with_chain_v2(self, user_prompt, workspace_id=None):
input={"input": user_prompt}, config=config
)
if "answer" in response:
answer = response.get("answer") # Rag flow
answer = response.get("answer") # RAG flow
else:
answer = response.content
except Exception as e:
@@ -239,11 +259,11 @@ def run_with_chain_v2(self, user_prompt, workspace_id=None):
        # Used by CloudWatch filters to generate a metric of token usage.
        logger.info(
            "Usage Metric",
            # Each unique value of model id will create a
            # new CloudWatch metric (each one has a cost)
            extra={
                "model": self.model_id,
                "metric_type": "token_usage",
                "value": self.callback_handler.usage.get("total_tokens"),
            },
        )

Collaborator

Changing the JSON format here would break the metric in the dashboard. Please undo:
https://github.com/aws-samples/aws-genai-llm-chatbot/blob/main/lib/monitoring/index.ts#L289
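
For comparison, the pre-change call that the dashboard's metric filter was built against, plus an assumed sketch of the emitted JSON shape (not captured from a real log):

# Pre-change version, passing the fields as top-level Powertools keys:
logger.info(
    "Usage Metric",
    model=self.model_id,
    metric_type="token_usage",
    value=self.callback_handler.usage.get("total_tokens"),
)
# Assumed emitted shape (other Powertools keys omitted):
# {"message": "Usage Metric", "model": "...", "metric_type": "token_usage", "value": 123}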

        return {
@@ -342,3 +362,245 @@ def run(self, prompt, workspace_id=None, *args, **kwargs):
            return self.run_with_chain(prompt, workspace_id)

        raise ValueError(f"unknown mode {self._mode}")


class BedrockChatAdapter(ModelAdapter):
    def __init__(self, model_id, *args, **kwargs):
        self.model_id = model_id
        logger.info(f"Initializing BedrockChatAdapter with model_id: {model_id}")
        super().__init__(*args, **kwargs)

    def get_qa_prompt(self):
        # Fetch the QA prompt based on the current language
        qa_system_prompt = prompts[locale]["qa_prompt"]
        # Append the context placeholder if needed
        qa_system_prompt_with_context = qa_system_prompt + "\n\n{context}"
        logger.info(
            f"Generating QA prompt template with: {qa_system_prompt_with_context}"
        )

        # Create the ChatPromptTemplate
        chat_prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", qa_system_prompt_with_context),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )

        # Trace the ChatPromptTemplate by logging its content
        logger.debug(f"ChatPromptTemplate messages: {chat_prompt_template.messages}")

        return chat_prompt_template
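
A hedged sketch of how this template is typically consumed; the actual wiring lives in the elided run_with_chain_v2, so llm and history_aware_retriever here are assumptions based on the imports at the top of the file:

# Sketch only — the standard LCEL retrieval wiring implied by the imports above.
combine_docs_chain = create_stuff_documents_chain(llm, self.get_qa_prompt())
retrieval_chain = create_retrieval_chain(history_aware_retriever, combine_docs_chain)
response = retrieval_chain.invoke({"input": "What is Bedrock?", "chat_history": []})
answer = response["answer"]  # the "answer" key checked in run_with_chain_v2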

    def get_prompt(self):
        # Fetch the conversation prompt based on the current language
        conversation_prompt = prompts[locale]["conversation_prompt"]
        logger.info("Generating general conversation prompt template.")
        chat_prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", conversation_prompt),
                MessagesPlaceholder(variable_name="chat_history"),
                ("human", "{input}"),
            ]
        )
        # Trace the ChatPromptTemplate by logging its content
        logger.debug(f"ChatPromptTemplate messages: {chat_prompt_template.messages}")
        return chat_prompt_template

    def get_condense_question_prompt(self):
        # Fetch the prompt based on the current language
        condense_question_prompt = prompts[locale]["condense_question_prompt"]
        logger.info("Generating condense question prompt template.")
        chat_prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", condense_question_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        # Trace the ChatPromptTemplate by logging its content
        logger.debug(f"ChatPromptTemplate messages: {chat_prompt_template.messages}")
        return chat_prompt_template

    def get_llm(self, model_kwargs={}, extra={}):
        bedrock = genai_core.clients.get_bedrock_client()
        params = {}

        # Collect temperature, topP, and maxTokens if available
        temperature = model_kwargs.get("temperature")
        top_p = model_kwargs.get("topP")
        max_tokens = model_kwargs.get("maxTokens")

        # Use explicit None checks so that legitimate zero values
        # (e.g. temperature=0) are not silently dropped.
        if temperature is not None:
            params["temperature"] = temperature
        if top_p is not None:
            params["top_p"] = top_p
        if max_tokens is not None:
            params["max_tokens"] = max_tokens

        # Fetch guardrails if any
        guardrails = get_guardrails()
        if len(guardrails.keys()) > 0:
            params["guardrails"] = guardrails

        # Log all parameters in a single log entry, including full guardrails
        logger.info(
            f"Creating LLM chain for model {self.model_id}",
            extra={
                "model_kwargs": model_kwargs,
                "temperature": temperature,
                "top_p": top_p,
                "max_tokens": max_tokens,
                "guardrails": guardrails,
            },
        )

        # Return a ChatBedrockConverse instance with the collected params
        return ChatBedrockConverse(
            client=bedrock,
            model=self.model_id,
            disable_streaming=not model_kwargs.get("streaming", True)
            or self.disable_streaming,
            callbacks=[self.callback_handler],
            **params,
            **extra,
        )
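
An illustrative call, using the camelCase model_kwargs keys read above; adapter and all values are hypothetical:

llm = adapter.get_llm(
    model_kwargs={"temperature": 0.2, "topP": 0.9, "maxTokens": 512, "streaming": True}
)
# Result: ChatBedrockConverse(model=<model_id>, temperature=0.2, top_p=0.9,
# max_tokens=512, ...), with guardrails attached when the env vars are set.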


class BedrockChatNoStreamingAdapter(BedrockChatAdapter):
    """Some models do not support streaming with the Converse API."""

    def __init__(self, *args, **kwargs):
        logger.info(
            "Initializing BedrockChatNoStreamingAdapter with disabled streaming."
        )
        super().__init__(disable_streaming=True, *args, **kwargs)


class BedrockChatNoSystemPromptAdapter(BedrockChatAdapter):
    """Some models do not support a system prompt or message history with the Converse API."""

    def get_prompt(self):
        # Fetch the conversation prompt and translated
        # words based on the current language
        conversation_prompt = prompts[locale]["conversation_prompt"]
        question_word = prompts[locale]["question_word"]
        assistant_word = prompts[locale]["assistant_word"]
        logger.info("Generating no-system-prompt template for conversation.")

        # Combine conversation prompt, chat history, and input into the template
        template = f"""{conversation_prompt}

{question_word}: {{input}}

{assistant_word}:"""

        # Create the PromptTemplateWithHistory instance
        prompt_template = PromptTemplateWithHistory(
            input_variables=["input", "chat_history"], template=template
        )

        # Log the content of PromptTemplateWithHistory before returning
        logger.debug(f"PromptTemplateWithHistory template: {prompt_template.template}")

        return prompt_template
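
A mini-demo mirroring get_prompt with the hypothetical "en" entries sketched near the imports (not the real adapter call):

# Illustrative only — same f-string construction as get_prompt above.
conversation_prompt = "You are a helpful assistant."
question_word, assistant_word = "Question", "Assistant"
template = f"""{conversation_prompt}

{question_word}: {{input}}

{assistant_word}:"""
print(template.format(input="What is Bedrock?"))
# You are a helpful assistant.
#
# Question: What is Bedrock?
#
# Assistant: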

    def get_condense_question_prompt(self):
        # Fetch the prompt and translated words based on the current language
        condense_question_prompt = prompts[locale]["condense_question_prompt"]
        logger.info(f"condense_question_prompt: {condense_question_prompt}")

        follow_up_input_word = prompts[locale]["follow_up_input_word"]
        logger.info(f"follow_up_input_word: {follow_up_input_word}")

        standalone_question_word = prompts[locale]["standalone_question_word"]
        logger.info(f"standalone_question_word: {standalone_question_word}")

        chat_history_word = prompts[locale]["chat_history_word"]
        logger.info(f"chat_history_word: {chat_history_word}")

        logger.info("Generating no-system-prompt template for condensing question.")

        # Combine the prompt with placeholders
        template = f"""{condense_question_prompt}
{chat_history_word}: {{chat_history}}
{follow_up_input_word}: {{input}}
{standalone_question_word}:"""
        # Log the content of the template
        logger.info(f"get_condense_question_prompt: Template content: {template}")
        # Create the PromptTemplateWithHistory instance
        prompt_template = PromptTemplateWithHistory(
            input_variables=["input", "chat_history"], template=template
        )

        # Log the content of PromptTemplateWithHistory before returning
        logger.debug(f"PromptTemplateWithHistory template: {prompt_template.template}")

        return prompt_template

    def get_qa_prompt(self):
        # Fetch the QA prompt and translated words based on the current language
        qa_system_prompt = prompts[locale]["qa_prompt"]
        question_word = prompts[locale]["question_word"]
        helpful_answer_word = prompts[locale]["helpful_answer_word"]
        logger.info("Generating no-system-prompt QA template.")

        # Combine the prompt with placeholders
        template = f"""{qa_system_prompt}

{{context}}

{question_word}: {{input}}
{helpful_answer_word}:"""

        # Create the PromptTemplateWithHistory instance
        prompt_template = PromptTemplateWithHistory(
            input_variables=["input", "context"], template=template
        )

        # Log the content of PromptTemplateWithHistory before returning
        logger.debug(f"PromptTemplateWithHistory template: {prompt_template.template}")

        return prompt_template


class BedrockChatNoStreamingNoSystemPromptAdapter(BedrockChatNoSystemPromptAdapter):
    """Some models support neither streaming nor a system prompt with the Converse API."""

    def __init__(self, *args, **kwargs):
        super().__init__(disable_streaming=True, *args, **kwargs)


class PromptTemplateWithHistory(PromptTemplate):
    def format(self, **kwargs: Any) -> str:
        chat_history = kwargs.get("chat_history", "")
        if isinstance(chat_history, List):
            # RunnableWithMessageHistory provides the history as a list of
            # BaseMessage. Since this model does not support native history,
            # we flatten the messages into the prompt text.
            chat_history_str = ""
            for message in chat_history:
                if isinstance(message, BaseMessage):
                    prefix = ""
                    if isinstance(message, AIMessage):
                        prefix = "AI: "
                    elif isinstance(message, HumanMessage):
                        prefix = "Human: "
                    chat_history_str += prefix + message.content + "\n"
            kwargs["chat_history"] = chat_history_str
        return super().format(**kwargs)
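
A small, self-contained usage sketch of the class above (prompt text is hypothetical):

template = PromptTemplateWithHistory(
    input_variables=["input", "chat_history"],
    template="{chat_history}Question: {input}",
)
history = [HumanMessage(content="Hi"), AIMessage(content="Hello!")]
print(template.format(input="How are you?", chat_history=history))
# Human: Hi
# AI: Hello!
# Question: How are you?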


# Register the adapters
registry.register(r"^bedrock.ai21.jamba*", BedrockChatAdapter)
registry.register(r"^bedrock.ai21.j2*", BedrockChatNoStreamingNoSystemPromptAdapter)
registry.register(r"^bedrock\.cohere\.command-(text|light-text).*", BedrockChatNoSystemPromptAdapter)
registry.register(r"^bedrock\.cohere\.command-r.*", BedrockChatAdapter)
registry.register(r"^bedrock.anthropic.claude*", BedrockChatAdapter)
registry.register(r"^bedrock.meta.llama*", BedrockChatAdapter)
registry.register(r"^bedrock.mistral.mistral-large*", BedrockChatAdapter)
registry.register(r"^bedrock.mistral.mistral-small*", BedrockChatAdapter)
registry.register(r"^bedrock.mistral.mistral-7b-*", BedrockChatNoSystemPromptAdapter)
registry.register(r"^bedrock.mistral.mixtral-*", BedrockChatNoSystemPromptAdapter)
registry.register(r"^bedrock.amazon.titan-t*", BedrockChatNoSystemPromptAdapter)
@@ -1,2 +1,2 @@
# flake8: noqa
from adapters.bedrock.base import *