Commit 71a16c6

feat: support offline job in china region

IcyKallen committed Feb 28, 2025
1 parent df488f6 commit 71a16c6

Showing 8 changed files with 189 additions and 188 deletions.
55 changes: 30 additions & 25 deletions source/infrastructure/lib/knowledge-base/knowledge-base-stack.ts
```diff
@@ -138,6 +138,7 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac


   private createKnowledgeBaseJob(props: any) {
+    const deployRegion = props.config.deployRegion;
     const connection = new glue.Connection(this, "GlueJobConnection", {
       type: glue.ConnectionType.NETWORK,
       subnet: props.sharedConstructOutputs.vpc.privateSubnets[0],
```
```diff
@@ -159,10 +160,8 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
     notificationLambda.addToRolePolicy(this.iamHelper.logStatement);
     notificationLambda.addToRolePolicy(this.dynamodbStatement);

-    // If this.region is cn-north-1 or cn-northwest-1, use the glue-job-script-cn.py
-    const glueJobScript = "glue-job-script.py";

-    // Assemble the extra python files list using _S3Bucket.s3UrlForObject("llm_bot_dep-0.1.0-py3-none-any.whl") and _S3Bucket.s3UrlForObject("nougat_ocr-0.1.17-py3-none-any.whl") and convert to string
+    // Assemble the extra python files list using _S3Bucket.s3UrlForObject("llm_bot_dep-0.1.0-py3-none-any.whl")
     const extraPythonFilesList = [
       this.glueLibS3Bucket.s3UrlForObject("llm_bot_dep-0.1.0-py3-none-any.whl"),
     ].join(",");
```
```diff
@@ -202,36 +201,42 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
     glueRole.addToPolicy(this.dynamodbStatement);
     glueRole.addToPolicy(this.iamHelper.dynamodbStatement);

+    const glueJobDefaultArguments: { [key: string]: string } = {
+      "--AOS_ENDPOINT": this.aosDomainEndpoint,
+      "--REGION": deployRegion,
+      "--ETL_MODEL_ENDPOINT": props.modelConstructOutputs.defaultKnowledgeBaseModelName,
+      "--RES_BUCKET": this.glueResultBucket.bucketName,
+      "--ETL_OBJECT_TABLE": this.etlObjTableName || "-",
+      "--PORTAL_BUCKET": this.uiPortalBucketName,
+      "--CHATBOT_TABLE": props.sharedConstructOutputs.chatbotTable.tableName,
+      "--additional-python-modules":
+        "langchain==0.3.7,beautifulsoup4==4.12.2,requests-aws4auth==1.2.3,boto3==1.35.98,openai==0.28.1,pyOpenSSL==23.3.0,tenacity==8.2.3,markdownify==0.11.6,mammoth==1.6.0,chardet==5.2.0,python-docx==1.1.0,pdfminer.six==20221105,smart-open==7.0.4,opensearch-py==2.2.0,lxml==5.2.2,pandas==2.1.2,openpyxl==3.1.5,xlrd==2.0.1,langchain_community==0.3.5,pillow==10.0.1,tiktoken==0.8.0",
+      // Add multiple extra python files
+      "--extra-py-files": extraPythonFilesList,
+    }
+
+    // Set China-specific PyPI mirror for China regions
+    if (deployRegion === "cn-north-1" || deployRegion === "cn-northwest-1") {
+      glueJobDefaultArguments["--python-modules-installer-option"] = "-i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple";
+    }
+
     // Create glue job to process files specified in s3 bucket and prefix
-    const glueJob = new glue.Job(this, "PythonShellJob", {
-      executable: glue.JobExecutable.pythonShell({
-        glueVersion: glue.GlueVersion.V3_0,
-        pythonVersion: glue.PythonVersion.THREE_NINE,
+    const glueJob = new glue.Job(this, "PythonEtlJob", {
+      executable: glue.JobExecutable.pythonEtl({
+        glueVersion: glue.GlueVersion.V4_0,
+        pythonVersion: glue.PythonVersion.THREE,
         script: glue.Code.fromAsset(
-          join(__dirname, "../../../lambda/job", glueJobScript),
+          join(__dirname, "../../../lambda/job/glue-job-script.py"),
         ),
       }),
-      // Worker Type is not supported for Job Command pythonshell and Both workerType and workerCount must be set
-      // workerType: glue.WorkerType.G_2X,
-      // workerCount: 2,
+      workerType: glue.WorkerType.G_1X,
+      workerCount: 2,
       maxConcurrentRuns: 200,
       maxRetries: 1,
       connections: [connection],
       maxCapacity: 1,
       role: glueRole,
-      defaultArguments: {
-        "--AOS_ENDPOINT": this.aosDomainEndpoint,
-        "--REGION": process.env.CDK_DEFAULT_REGION || "-",
-        "--ETL_MODEL_ENDPOINT": props.modelConstructOutputs.defaultKnowledgeBaseModelName,
-        "--RES_BUCKET": this.glueResultBucket.bucketName,
-        "--ETL_OBJECT_TABLE": this.etlObjTableName || "-",
-        "--PORTAL_BUCKET": this.uiPortalBucketName,
-        "--CHATBOT_TABLE": props.sharedConstructOutputs.chatbotTable.tableName,
-        "--additional-python-modules":
-          "langchain==0.3.7,beautifulsoup4==4.12.2,requests-aws4auth==1.2.3,boto3==1.35.98,openai==0.28.1,pyOpenSSL==23.3.0,tenacity==8.2.3,markdownify==0.11.6,mammoth==1.6.0,chardet==5.2.0,python-docx==1.1.0,nltk==3.9.1,pdfminer.six==20221105,smart-open==7.0.4,opensearch-py==2.2.0,lxml==5.2.2,pandas==2.1.2,openpyxl==3.1.5,xlrd==2.0.1,langchain_community==0.3.5,pillow==10.0.1,tiktoken==0.8.0",
-        // Add multiple extra python files
-        "--extra-py-files": extraPythonFilesList
-      },
+      defaultArguments: glueJobDefaultArguments,
     });

     // Create SNS topic and subscription to notify when glue job is completed
```
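Two details in this hunk carry the China-region support. The default arguments move out of the `Job` construct into a mutable `glueJobDefaultArguments` map built beforehand, so region-specific options can be layered on: Glue forwards `--python-modules-installer-option` to `pip`, so in cn-north-1 and cn-northwest-1 the `--additional-python-modules` list installs from the Tsinghua PyPI mirror rather than pypi.org. `--REGION` also switches from `process.env.CDK_DEFAULT_REGION` (which can be unset at synth time) to the explicit `props.config.deployRegion`, and `nltk==3.9.1` is dropped from the module list, matching the enhance_utils.py change below. As a minimal sketch of the consuming side (glue-job-script.py is not part of this diff, so only argument names visible above are used):

```python
# Sketch only: how a Glue script typically reads the arguments wired up above,
# using the standard awsglue helper. glue-job-script.py is not shown here.
import sys

from awsglue.utils import getResolvedOptions

args = getResolvedOptions(sys.argv, ["REGION", "AOS_ENDPOINT", "ETL_OBJECT_TABLE"])
region = args["REGION"]  # the configured deploy region, e.g. "cn-north-1"
```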
```diff
@@ -308,7 +313,7 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
       "--TABLE_ITEM_ID.$": "$.tableItemId",
       "--QA_ENHANCEMENT.$": "$.qaEnhance",
       "--REGION": process.env.CDK_DEFAULT_REGION || "-",
-      "--BEDROCK_REGION": props.config.chat.bedrockRegion,
+      "--BEDROCK_REGION": props.config.chat.bedrockRegion || "-",
       "--MODEL_TABLE": props.sharedConstructOutputs.modelTable.tableName,
       "--RES_BUCKET": this.glueResultBucket.bucketName,
       "--S3_BUCKET.$": "$.s3Bucket",
```
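The `|| "-"` fallback mirrors the `--REGION` line above it: Glue job arguments must be strings, and `props.config.chat.bedrockRegion` can be undefined in China-region deployments where Bedrock is not configured. A consumer would then treat the placeholder as "no Bedrock region", along these lines (illustrative helper, not code from this repository):

```python
from typing import Optional


# Illustrative only: map the "-" placeholder back to "Bedrock not configured".
def resolve_bedrock_region(job_args: dict) -> Optional[str]:
    region = job_args.get("BEDROCK_REGION", "-")
    return None if region == "-" else region
```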
2 changes: 1 addition & 1 deletion source/infrastructure/package.json
```diff
@@ -6,7 +6,7 @@
     "clobber": "npx projen clobber",
     "compile": "npx projen compile",
     "default": "npx projen default",
-    "deploy": "npx cdk deploy --all",
+    "deploy": "npx cdk deploy --all --require-approval never",
     "destroy": "npx projen destroy",
     "diff": "npx projen diff",
     "eject": "npx projen eject",
```
Binary file modified source/lambda/job/dep/dist/llm_bot_dep-0.1.0-py3-none-any.whl
20 changes: 15 additions & 5 deletions source/lambda/job/dep/llm_bot_dep/enhance_utils.py
```diff
@@ -5,7 +5,6 @@
 from typing import Dict, List

 import boto3
-import nltk
 import openai
 from langchain.docstore.document import Document
```
```diff
@@ -127,7 +126,9 @@ def EnhanceWithClaude(
                     answer = line
                 qa_content = f"{question}\n{answer}"
                 enhanced_prompt_list.append(
-                    Document(page_content=qa_content, metadata=document.metadata)
+                    Document(
+                        page_content=qa_content, metadata=document.metadata
+                    )
                 )
                 question = ""
                 answer = ""
```
```diff
@@ -137,7 +138,11 @@
         return enhanced_prompt_list

     def EnhanceWithOpenAI(
-        self, prompt: str, solution_title: str, document: Document, zh: bool = True
+        self,
+        prompt: str,
+        solution_title: str,
+        document: Document,
+        zh: bool = True,
     ) -> List[Dict[str, str]]:
         """
         Enhances a given prompt with additional information and performs a chat completion using OpenAI's GPT-3.5 Turbo model.
```
```diff
@@ -165,7 +170,10 @@ def EnhanceWithOpenAI(
         # error and retry handling for openai api due to request cap limit
         try:
             response = openai.ChatCompletion.create(
-                model="gpt-3.5-turbo", messages=messages, temperature=0, max_tokens=2048
+                model="gpt-3.5-turbo",
+                messages=messages,
+                temperature=0,
+                max_tokens=2048,
             )
         except Exception as e:
             logger.error("OpenAI API request failed: {}".format(e))
```
```diff
@@ -208,7 +216,9 @@ def SplitDocumentByTokenNum(
         - List[Document]: A list of documents, each containing a slice of the original document.
         """
         # Get the token number of input paragraph
-        tokens = nltk.word_tokenize(document.page_content)
+        # tokens = nltk.word_tokenize(document.page_content)
+        # TODO: Currently Disable tokenization for now
+        tokens = []
         metadata = document.metadata
         if "content_type" in metadata:
             metadata["content_type"] = "qa"
```
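Dropping `import nltk` removes the dependency on nltk's separately downloaded `punkt` tokenizer data, which is awkward to fetch in offline or China-region deployments (this motivation is inferred from the commit title, not stated in the diff). The cost is that `tokens` is now always empty, so the token-count-based split in `SplitDocumentByTokenNum` is effectively a no-op until the TODO is resolved. If a dependency-free stopgap were wanted, a whitespace split would roughly approximate the old behavior (hypothetical, not part of this commit):

```python
# Hypothetical stopgap, not in this commit: approximate the token count with a
# whitespace split so SplitDocumentByTokenNum can still size its slices.
def approximate_word_tokenize(text: str) -> list:
    return text.split()


# tokens = approximate_word_tokenize(document.page_content)
```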
30 changes: 16 additions & 14 deletions source/lambda/job/dep/llm_bot_dep/sm_utils.py
```diff
@@ -11,6 +11,7 @@
     BedrockEmbeddings,
     SagemakerEndpointEmbeddings,
 )
+from langchain_community.embeddings.openai import OpenAIEmbeddings
 from langchain_community.embeddings.sagemaker_endpoint import (
     EmbeddingsContentHandler,
 )
```
```diff
@@ -21,14 +22,11 @@
 from langchain_core.messages import BaseMessage
 from langchain_core.outputs import GenerationChunk
 from langchain_core.pydantic_v1 import Extra, root_validator
-from langchain_community.embeddings.openai import OpenAIEmbeddings

 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 session = boto3.session.Session()
-secret_manager_client = session.client(
-    service_name="secretsmanager"
-)
+secret_manager_client = session.client(service_name="secretsmanager")


 def get_model_details(group_name: str, chatbot_id: str, table_name: str):
```
```diff
@@ -48,14 +46,13 @@ def get_model_details(group_name: str, chatbot_id: str, table_name: str):

     try:
         response = table.get_item(
-            Key={
-                "groupName": group_name,
-                "modelId": model_id
-            }
+            Key={"groupName": group_name, "modelId": model_id}
         )

         if "Item" not in response:
-            raise Exception(f"No model found for group {group_name} and model ID {model_id}")
+            raise Exception(
+                f"No model found for group {group_name} and model ID {model_id}"
+            )

         return response["Item"]
     except Exception as e:
```
```diff
@@ -512,14 +509,15 @@ def SagemakerEndpointVectorOrCross(
     )
     return genericModel(prompt=prompt, stop=stop, **kwargs)

+
 def getCustomEmbeddings(
     endpoint_name: str,
     region_name: str,
     bedrock_region: str,
     model_type: str,
     group_name: str,
     chatbot_id: str,
-    model_table: str
+    model_table: str,
 ) -> SagemakerEndpointEmbeddings:
     embeddings = None
     model_details = get_model_details(group_name, chatbot_id, model_table)
```
```diff
@@ -531,8 +529,10 @@ def getCustomEmbeddings(
     if model_provider not in ["Bedrock API", "OpenAI API"]:
         # Use local models
         client = boto3.client("sagemaker-runtime", region_name=region_name)
-        bedrock_client = boto3.client("bedrock-runtime", region_name=bedrock_region)
         if model_type == "bedrock":
+            bedrock_client = boto3.client(
+                "bedrock-runtime", region_name=bedrock_region
+            )
             content_handler = BedrockEmbeddings()
             embeddings = BedrockEmbeddings(
                 client=bedrock_client,
```
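Moving the `bedrock-runtime` client inside the `model_type == "bedrock"` branch makes its creation lazy: code paths that never call Bedrock (the only situation in China regions, where the service is not offered) no longer require a usable `bedrock_region`. The same pattern in isolation (illustrative names, not code from this module):

```python
import boto3


# Illustrative lazy-client pattern: only build the Bedrock client on the code
# path that actually calls Bedrock, so non-Bedrock regions never need one.
def runtime_client(model_type: str, region_name: str, bedrock_region: str):
    if model_type == "bedrock":
        return boto3.client("bedrock-runtime", region_name=bedrock_region)
    return boto3.client("sagemaker-runtime", region_name=region_name)
```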
```diff
@@ -567,15 +567,17 @@
             embeddings = OpenAIEmbeddings(
                 model=endpoint_name,
                 api_key=get_secret_value(api_key_arn),
-                base_url=base_url
+                base_url=base_url,
             )
         elif model_provider == "OpenAI API":
             embeddings = OpenAIEmbeddings(
                 model=endpoint_name,
                 api_key=get_secret_value(api_key_arn),
-                base_url=base_url
+                base_url=base_url,
             )
         else:
-            raise ValueError(f"Unsupported API inference provider: {model_provider}")
+            raise ValueError(
+                f"Unsupported API inference provider: {model_provider}"
+            )

     return embeddings
```
