Skip to content

Commit

Permalink
feat: support siliconflow
Browse files Browse the repository at this point in the history
  • Loading branch information
IcyKallen committed Feb 28, 2025
1 parent b0590b7 commit 4f7a531
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 56 deletions.
2 changes: 1 addition & 1 deletion source/infrastructure/lib/model/model-construct.ts
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ export class ModelConstruct extends NestedStack implements ModelConstructOutputs
executionRole.addToPolicy(this.modelIamHelper.stsStatement);
executionRole.addToPolicy(this.modelIamHelper.ecrStatement);
executionRole.addToPolicy(this.modelIamHelper.llmStatement);

executionRole.addToPolicy(this.modelIamHelper.secretsManagerStatement);
this.modelExecutionRole = executionRole;
}

Expand Down
31 changes: 19 additions & 12 deletions source/infrastructure/lib/shared/iam-helper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import { Aws } from "aws-cdk-lib";
import { Construct } from "constructs";
import {PolicyStatement, Effect} from "aws-cdk-lib/aws-iam";
import { PolicyStatement, Effect } from "aws-cdk-lib/aws-iam";

export class IAMHelper extends Construct {
public logStatement: PolicyStatement;
Expand All @@ -32,6 +32,7 @@ export class IAMHelper extends Construct {
public cfnStatement: PolicyStatement;
public serviceQuotaStatement: PolicyStatement;
public sagemakerModelManagementStatement: PolicyStatement;
public secretsManagerStatement: PolicyStatement;

public createPolicyStatement(actions: string[], resources: string[]) {
return new PolicyStatement({
Expand All @@ -51,7 +52,7 @@ export class IAMHelper extends Construct {
"logs:CreateLogStream",
"logs:PutLogEvents",
],
[ `arn:${Aws.PARTITION}:logs:${Aws.REGION}:${Aws.ACCOUNT_ID}:log-group:*:*` ],
[`arn:${Aws.PARTITION}:logs:${Aws.REGION}:${Aws.ACCOUNT_ID}:log-group:*:*`],
);
this.s3Statement = this.createPolicyStatement(
[
Expand All @@ -60,14 +61,14 @@ export class IAMHelper extends Construct {
"s3:PutObject",
"s3:GetObject",
],
[ "*" ],
["*"],
);
this.glueStatement = this.createPolicyStatement(
[
"glue:StartJobRun",
"glue:GetJobRun*",
],
[ "*" ],
["*"],
);
this.endpointStatement = this.createPolicyStatement(
[
Expand Down Expand Up @@ -106,7 +107,7 @@ export class IAMHelper extends Construct {
"iam:PutRolePolicy",
"iam:Get*",
],
[ "*" ],
["*"],
);
this.ecrStatement = this.createPolicyStatement(
[
Expand All @@ -123,7 +124,7 @@ export class IAMHelper extends Construct {
"ecr:CompleteLayerUpload",
"ecr:PutImage",
],
[ "*" ],
["*"],
);
this.llmStatement = this.createPolicyStatement(
[
Expand All @@ -135,19 +136,19 @@ export class IAMHelper extends Construct {
"cloudwatch:DeleteAlarms",
"cloudwatch:DescribeAlarms",
],
[ "*" ],
["*"],
);
this.cognitoStatement = this.createPolicyStatement(
[
"cognito-idp:ListGroups",
],
[ "*" ],
["*"],
);
this.bedrockStatement = this.createPolicyStatement(
[
"bedrock:*",
],
[ "*" ],
["*"],
);
this.esStatement = this.createPolicyStatement(
[
Expand All @@ -157,13 +158,13 @@ export class IAMHelper extends Construct {
"es:ESHttpHead",
"es:DescribeDomain"
],
[ "*" ],
["*"],
);
this.secretStatement = this.createPolicyStatement(
[
"secretsmanager:GetSecretValue",
],
[ "*" ],
["*"],
)
this.codePipelineStatement = this.createPolicyStatement(
[
Expand All @@ -175,7 +176,7 @@ export class IAMHelper extends Construct {
"codepipeline:StopPipelineExecution",
"codepipeline:GetPipelineExecution",
],
[ "*" ],
["*"],
);
this.cfnStatement = this.createPolicyStatement(
[
Expand Down Expand Up @@ -245,5 +246,11 @@ export class IAMHelper extends Construct {
],
["*"],
);
this.secretsManagerStatement = this.createPolicyStatement(
[
"secretsmanager:GetSecretValue",
],
["*"],
);
}
}
5 changes: 5 additions & 0 deletions source/lambda/job/glue-job-script.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def process_file(self, key: str, file_type: str, file_content: str):
None
"""
create_time = str(datetime.now(timezone.utc))
# TODO: make it configurable in frontend
kwargs = {
"bucket": self.bucket,
"key": key,
Expand All @@ -180,6 +181,10 @@ def process_file(self, key: str, file_type: str, file_content: str):
"create_time": create_time,
"portal_bucket_name": portal_bucket_name,
"document_language": document_language,
"model_provider": "siliconflow",
"model_id": "Qwen/Qwen2-VL-72B-Instruct",
"api_secret_name": "siliconflow-api-key",
"api_url": "https://api.siliconflow.cn/v1/chat/completions",
}

input_body = {
Expand Down
172 changes: 136 additions & 36 deletions source/model/etl/code/figure_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,85 @@
import boto3
from botocore.exceptions import ClientError

# Prompt template (written in Chinese) that asks the model to locate the chart in
# the image and convert its data into a Markdown table, returned inside
# <output></output> tags. It carries two str.format placeholders: `{tag}` (the
# marker that refers to this chart inside the context) and `{context}` (the
# document context placed between the <doc></doc> tags).
CHART_UNDERSTAND_PROMPT = """您是文档阅读专家。您的任务是将图片中的图表转换成Markdown格式。以下是说明:
1. 找到图片中的图表。
2. 仔细观察图表,了解其中包含的结构和数据。
3. 使用<doc></doc>标签中的上下文信息来帮助你更好地理解和描述这张图表。上下文中的{tag}就是指该图表。
4. 按照以下指南将图表数据转换成 Markdown 表格格式:
- 使用 | 字符分隔列
- 使用 --- 行表示标题行
- 确保表格格式正确并对齐
- 对于不确定的数字,请根据图片估算。
5. 仔细检查您的 Markdown 表格是否准确地反映了图表图像中的数据。
6. 在 <output></output>xml 标签中仅返回 Markdown,不含其他文本。
<doc>
{context}
</doc>
请将你的描述写在<output></output>xml标签之间。
"""

# Prompt template (written in Chinese) that asks the model to describe the
# contents of an illustration, including any text visible in it, and to return
# the description inside <output></output> tags. It carries two str.format
# placeholders: `{tag}` (the marker that refers to this illustration inside the
# context) and `{context}` (the document context placed between <doc></doc>).
IMAGE_DESCRIPTION_PROMPT = """
你是一位资深的图像分析专家。你的任务是仔细观察给出的插图,并按照以下步骤进行:
1. 清晰地描述这张图片中显示的内容细节。如果图片中包含任何文字,请确保在描述中准确无误地包含这些文字。
2. 使用<doc></doc>标签中的上下文信息来帮助你更好地理解和描述这张图片。上下文中的{tag}就是指该插图。
3. 将你的描述写在<output></output>标签之间。
<doc>
{context}
</doc>
请将你的描述写在<output></output>xml标签之间。
"""


class figureUnderstand:
def __init__(self):
self.bedrock_runtime = boto3.client(service_name="bedrock-runtime")
def __init__(
    self,
    model_provider="bedrock",
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    api_url=None,
    api_key=None,
):
    """
    Initialize the figureUnderstand class with a configurable model provider.

    Args:
        model_provider (str): The model provider to use ('bedrock', 'openai',
            'siliconflow', etc.). Stored lower-cased so later comparisons are
            case-insensitive.
        model_id (str): The model ID to use.
        api_url (str, optional): The API URL for non-AWS providers.
        api_key (str, optional): The API key for non-AWS providers.

    Raises:
        OSError: If "prompt/mermaid.json" cannot be opened (the path is
            relative to the current working directory).
        json.JSONDecodeError: If that file does not contain valid JSON.
    """
    self.model_provider = model_provider.lower()
    self.model_id = model_id
    self.api_url = api_url
    self.api_key = api_key

    # Only the Bedrock path needs an AWS client; other providers are reached
    # over plain HTTP using api_url / api_key.
    if self.model_provider == "bedrock":
        self.bedrock_runtime = boto3.client(service_name="bedrock-runtime")

    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open("prompt/mermaid.json", "r") as prompt_file:
        self.mermaid_prompt = json.load(prompt_file)

def invoke_llm(self, img, prompt, prefix="<output>", stop="</output>"):
    """
    Encode an image as base64 JPEG and send it with a prompt to the configured provider.

    Args:
        img: Image object exposing ``save(stream, format=...)``.
            NOTE(review): presumably a PIL.Image.Image; confirm against callers.
        prompt (str): The text prompt sent alongside the image.
        prefix (str): Text used to pre-fill the assistant turn.
        stop (str): Stop sequence passed to the provider.

    Returns:
        Whatever the provider-specific helper returns.

    Raises:
        ValueError: If ``self.model_provider`` is not a supported provider.
    """
    # JPEG cannot store alpha or palette data, so saving e.g. an RGBA or P mode
    # image raises. Convert such images to RGB first. The getattr guard leaves
    # objects without a ``mode`` attribute untouched.
    jpeg_writable_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if getattr(img, "mode", "RGB") not in jpeg_writable_modes:
        img = img.convert("RGB")

    # Convert image to base64
    image_stream = io.BytesIO()
    img.save(image_stream, format="JPEG")
    base64_encoded = base64.b64encode(image_stream.getvalue()).decode("utf-8")

    if self.model_provider == "bedrock":
        return self._invoke_bedrock(base64_encoded, prompt, prefix, stop)
    if self.model_provider in ("openai", "siliconflow"):
        return self._invoke_openai_compatible(
            base64_encoded, prompt, prefix, stop
        )
    raise ValueError(f"Unsupported model provider: {self.model_provider}")

def _invoke_bedrock(self, base64_encoded, prompt, prefix, stop):
"""Invoke Bedrock models"""
messages = [
{
"role": "user",
Expand All @@ -36,7 +103,7 @@ def invoke_llm(self, img, prompt, prefix="<output>", stop="</output>"):
},
{"role": "assistant", "content": prefix},
]
model_id = "anthropic.claude-3-sonnet-20240229-v1:0"

body = json.dumps(
{
"anthropic_version": "bedrock-2023-05-31",
Expand All @@ -45,12 +112,70 @@ def invoke_llm(self, img, prompt, prefix="<output>", stop="</output>"):
"stop_sequences": [stop],
}
)
response = self.bedrock_runtime.invoke_model(
body=body, modelId=model_id
)
response_body = json.loads(response.get("body").read())
result = prefix + response_body["content"][0]["text"] + stop
return result

try:
response = self.bedrock_runtime.invoke_model(
body=body, modelId=self.model_id
)
response_body = json.loads(response.get("body").read())
result = prefix + response_body["content"][0]["text"] + stop
return result
except ClientError as e:
logging.error(f"Error invoking Bedrock model: {e}")
raise

def _invoke_openai_compatible(self, base64_encoded, prompt, prefix, stop):
    """
    Invoke an OpenAI-compatible chat completions API (OpenAI, SiliconFlow, etc.).

    Args:
        base64_encoded (str): Base64-encoded JPEG image data.
        prompt (str): The text prompt sent alongside the image.
        prefix (str): Text sent as a trailing assistant message and prepended
            to the returned content.
        stop (str): Stop sequence for the request, appended to the returned
            content.

    Returns:
        str: ``prefix + <model content> + stop``.

    Raises:
        ValueError: If ``self.api_url`` or ``self.api_key`` is not set.
        requests.exceptions.RequestException: On transport errors, timeouts,
            or non-2xx HTTP responses (logged, then re-raised).
    """
    # Check configuration before touching the network stack so that a
    # misconfiguration is reported as such.
    if not self.api_url or not self.api_key:
        raise ValueError(
            "API URL and API key are required for OpenAI-compatible providers"
        )

    import requests

    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json",
    }

    payload = {
        "model": self.model_id,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_encoded}",
                        },
                    },
                ],
            },
            {"role": "assistant", "content": prefix},
        ],
        "max_tokens": 4096,
        "stop": [stop],
    }

    # requests applies no timeout by default, so a stalled endpoint would block
    # the caller indefinitely. (connect, read) in seconds; the read timeout is
    # generous because generating up to 4096 tokens can be slow.
    request_timeout = (10, 300)

    try:
        response = requests.post(
            self.api_url,
            json=payload,
            headers=headers,
            timeout=request_timeout,
        )
        response.raise_for_status()
        response_data = response.json()

        # Extract text from response based on OpenAI format
        result = (
            prefix
            + response_data["choices"][0]["message"]["content"]
            + stop
        )
        return result
    except requests.exceptions.RequestException as e:
        logging.error(f"Error invoking {self.model_provider} API: {e}")
        raise

def get_classification(self, img):
with open("prompt/figure_classification.txt") as f:
Expand All @@ -59,37 +184,12 @@ def get_classification(self, img):
return output

def get_chart(self, img, context, tag):
    """
    Ask the model to convert the chart in an image into a Markdown table.

    Args:
        img: Image passed through to ``invoke_llm``.
        context (str): Document context substituted for ``{context}`` in the
            prompt.
        tag (str): Marker substituted for ``{tag}`` in the prompt; it
            identifies this chart inside ``context``.

    Returns:
        The value returned by ``invoke_llm``.
    """
    # Fill the placeholders before sending. Without this the model receives
    # the literal "{tag}" and "{context}" text and never sees the document
    # context. This mirrors get_description.
    prompt = CHART_UNDERSTAND_PROMPT.strip().format(context=context, tag=tag)
    return self.invoke_llm(img, prompt)

def get_description(self, img, context, tag):
    """
    Ask the model to describe an illustration and wrap the answer as Markdown image alt text.

    Args:
        img: Image passed through to ``invoke_llm``.
        context (str): Document context substituted for ``{context}`` in the
            prompt.
        tag (str): Marker substituted for ``{tag}`` in the prompt; it
            identifies this illustration inside ``context``.

    Returns:
        str: ``![<model output>]()``, a Markdown image whose alt text is the
        model output and whose URL is empty.
    """
    # The module-level constant is the single source of the prompt text; the
    # inline copy that was assigned and then immediately overwritten is gone.
    prompt = IMAGE_DESCRIPTION_PROMPT.strip().format(context=context, tag=tag)
    output = self.invoke_llm(img, prompt)
    return f"![{output}]()"
Expand Down
Loading

0 comments on commit 4f7a531

Please sign in to comment.