Skip to content

Commit a1d2aa9

Browse files
charles-marionazaylambabigadsoleiman
authored
feat: Disable Sagemaker endpoint (or cross-encoder per workspace) (#588)
--------- Co-authored-by: Ajay Lamba <[email protected]> Co-authored-by: Bigad Soleiman <[email protected]>
1 parent b6a5d5a commit a1d2aa9

File tree

35 files changed

+468
-289
lines changed

35 files changed

+468
-289
lines changed

Diff for: bin/config.ts

+7-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ import { existsSync, readFileSync } from "fs";
33

44
export function getConfig(): SystemConfig {
55
if (existsSync("./bin/config.json")) {
6-
return JSON.parse(readFileSync("./bin/config.json").toString("utf8"));
6+
return JSON.parse(
7+
readFileSync("./bin/config.json").toString("utf8")
8+
) as SystemConfig;
79
}
810
// Default config
911
return {
@@ -48,11 +50,13 @@ export function getConfig(): SystemConfig {
4850
provider: "sagemaker",
4951
name: "intfloat/multilingual-e5-large",
5052
dimensions: 1024,
53+
default: false,
5154
},
5255
{
5356
provider: "sagemaker",
5457
name: "sentence-transformers/all-MiniLM-L6-v2",
5558
dimensions: 384,
59+
default: false,
5660
},
5761
{
5862
provider: "bedrock",
@@ -80,8 +84,10 @@ export function getConfig(): SystemConfig {
8084
provider: "openai",
8185
name: "text-embedding-ada-002",
8286
dimensions: 1536,
87+
default: false,
8388
},
8489
],
90+
crossEncodingEnabled: false,
8591
crossEncoderModels: [
8692
{
8793
provider: "sagemaker",

Diff for: cli/magic-config.ts

+55-18
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import {
1010
SupportedSageMakerModels,
1111
SystemConfig,
1212
SupportedBedrockRegion,
13+
ModelConfig,
1314
} from "../lib/shared/types";
1415
import { LIB_VERSION } from "./version.js";
1516
import * as fs from "fs";
@@ -34,7 +35,6 @@ function getTimeZonesWithCurrentTime(): { message: string; name: string }[] {
3435
function getCountryCodesAndNames(): { message: string; name: string }[] {
3536
// Use country-list to get an array of countries with their codes and names
3637
const countries = getData();
37-
3838
// Map the country data to match the desired output structure
3939
const countryInfo = countries.map(({ code, name }) => {
4040
return { message: `${name} (${code})`, name: code };
@@ -88,21 +88,24 @@ const secretManagerArnRegExp = RegExp(
8888
/arn:aws:secretsmanager:[\w-_]+:\d+:secret:[\w-_]+/
8989
);
9090

91-
const embeddingModels = [
91+
const embeddingModels: ModelConfig[] = [
9292
{
9393
provider: "sagemaker",
9494
name: "intfloat/multilingual-e5-large",
9595
dimensions: 1024,
96+
default: false,
9697
},
9798
{
9899
provider: "sagemaker",
99100
name: "sentence-transformers/all-MiniLM-L6-v2",
100101
dimensions: 384,
102+
default: false,
101103
},
102104
{
103105
provider: "bedrock",
104106
name: "amazon.titan-embed-text-v1",
105107
dimensions: 1536,
108+
default: false,
106109
},
107110
//Support for inputImage is not yet implemented for amazon.titan-embed-image-v1
108111
{
@@ -124,6 +127,7 @@ const embeddingModels = [
124127
provider: "openai",
125128
name: "text-embedding-ada-002",
126129
dimensions: 1536,
130+
default: false,
127131
},
128132
];
129133

@@ -179,6 +183,8 @@ const embeddingModels = [
179183
options.startScheduleEndDate =
180184
config.llms?.sagemakerSchedule?.startScheduleEndDate;
181185
options.enableRag = config.rag.enabled;
186+
options.deployDefaultSagemakerModels =
187+
config.rag.deployDefaultSagemakerModels;
182188
options.ragsToEnable = Object.keys(config.rag.engines ?? {}).filter(
183189
(v: string) =>
184190
(
@@ -608,6 +614,16 @@ async function processCreateOptions(options: any): Promise<void> {
608614
message: "Do you want to enable RAG",
609615
initial: options.enableRag || false,
610616
},
617+
{
618+
type: "confirm",
619+
name: "deployDefaultSagemakerModels",
620+
message:
621+
"Do you want to deploy the default embedding and cross-encoder models via SageMaker?",
622+
initial: options.deployDefaultSagemakerModels || false,
623+
skip(): boolean {
624+
return !(this as any).state.answers.enableRag;
625+
},
626+
},
611627
{
612628
type: "multiselect",
613629
name: "ragsToEnable",
@@ -810,10 +826,17 @@ async function processCreateOptions(options: any): Promise<void> {
810826
choices: embeddingModels.map((m) => ({ name: m.name, value: m })),
811827
initial: options.defaultEmbedding,
812828
validate(value: string) {
829+
const embeding = embeddingModels.find((i) => i.name === value);
830+
if (
831+
embeding &&
832+
(this as any).state.answers.deployDefaultSagemakerModels === false &&
833+
embeding?.provider === "sagemaker"
834+
) {
835+
return "SageMaker default models are not enabled. Please select another model.";
836+
}
813837
if ((this as any).state.answers.enableRag) {
814838
return value ? true : "Select a default embedding model";
815839
}
816-
817840
return true;
818841
},
819842
skip() {
@@ -1219,6 +1242,7 @@ async function processCreateOptions(options: any): Promise<void> {
12191242
}
12201243
: undefined,
12211244
llms: {
1245+
enableSagemakerModels: answers.enableSagemakerModels,
12221246
rateLimitPerAIP: advancedSettings?.llmRateLimitPerIP
12231247
? Number(advancedSettings?.llmRateLimitPerIP)
12241248
: undefined,
@@ -1241,6 +1265,7 @@ async function processCreateOptions(options: any): Promise<void> {
12411265
},
12421266
rag: {
12431267
enabled: answers.enableRag,
1268+
deployDefaultSagemakerModels: answers.deployDefaultSagemakerModels,
12441269
engines: {
12451270
aurora: {
12461271
enabled: answers.ragsToEnable.includes("aurora"),
@@ -1259,28 +1284,40 @@ async function processCreateOptions(options: any): Promise<void> {
12591284
external: [{}],
12601285
},
12611286
},
1262-
embeddingsModels: [{}],
1263-
crossEncoderModels: [{}],
1287+
embeddingsModels: [] as ModelConfig[],
1288+
crossEncoderModels: [] as ModelConfig[],
12641289
},
12651290
};
12661291

1292+
if (config.rag.enabled && config.rag.deployDefaultSagemakerModels) {
1293+
config.rag.crossEncoderModels[0] = {
1294+
provider: "sagemaker",
1295+
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
1296+
default: true,
1297+
};
1298+
config.rag.embeddingsModels = embeddingModels;
1299+
} else if (config.rag.enabled) {
1300+
config.rag.embeddingsModels = embeddingModels.filter(
1301+
(model) => model.provider !== "sagemaker"
1302+
);
1303+
for (const model of config.rag.embeddingsModels) {
1304+
model.default = model.name === models.defaultEmbedding;
1305+
}
1306+
} else {
1307+
config.rag.embeddingsModels = [];
1308+
}
1309+
12671310
// If we have not enabled rag the default embedding is set to the first model
12681311
if (!answers.enableRag) {
1269-
models.defaultEmbedding = embeddingModels[0].name;
1312+
(config.rag.embeddingsModels[0] as any).default = true;
1313+
} else {
1314+
config.rag.embeddingsModels.forEach((m: any) => {
1315+
if (m.name === models.defaultEmbedding) {
1316+
m.default = true;
1317+
}
1318+
});
12701319
}
12711320

1272-
config.rag.crossEncoderModels[0] = {
1273-
provider: "sagemaker",
1274-
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
1275-
default: true,
1276-
};
1277-
config.rag.embeddingsModels = embeddingModels;
1278-
config.rag.embeddingsModels.forEach((m: any) => {
1279-
if (m.name === models.defaultEmbedding) {
1280-
m.default = true;
1281-
}
1282-
});
1283-
12841321
config.rag.engines.kendra.createIndex =
12851322
answers.ragsToEnable.includes("kendra");
12861323
config.rag.engines.kendra.enabled =

Diff for: integtests/chatbot-api/aurora_workspace_test.py

+56-19
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ def run_before_and_after_tests(client: AppSyncClient):
1010
for workspace in client.list_workspaces():
1111
if (
1212
workspace.get("name") == "INTEG_TEST_AURORA"
13-
and workspace.get("status") == "ready"
14-
):
13+
or workspace.get("name") == "INTEG_TEST_AURORA_WITHOUT_RERANK"
14+
) and workspace.get("status") == "ready":
1515
client.delete_workspace(workspace.get("id"))
1616

1717

@@ -22,23 +22,25 @@ def test_create(client: AppSyncClient, default_embed_model):
2222
if engine.get("enabled") == False:
2323
pytest.skip_flag = True
2424
pytest.skip("Aurora is not enabled.")
25-
pytest.workspace = client.create_aurora_workspace(
26-
input={
27-
"kind": "auro2",
28-
"name": "INTEG_TEST_AURORA",
29-
"embeddingsModelProvider": "bedrock",
30-
"embeddingsModelName": default_embed_model,
31-
"crossEncoderModelName": "cross-encoder/ms-marco-MiniLM-L-12-v2",
32-
"crossEncoderModelProvider": "sagemaker",
33-
"languages": ["english"],
34-
"index": True,
35-
"hybridSearch": True,
36-
"metric": "inner",
37-
"chunkingStrategy": "recursive",
38-
"chunkSize": 1000,
39-
"chunkOverlap": 200,
40-
}
41-
)
25+
input = {
26+
"kind": "auro2",
27+
"name": "INTEG_TEST_AURORA_WITHOUT_RERANK",
28+
"embeddingsModelProvider": "bedrock",
29+
"embeddingsModelName": default_embed_model,
30+
"languages": ["english"],
31+
"index": True,
32+
"hybridSearch": True,
33+
"metric": "inner",
34+
"chunkingStrategy": "recursive",
35+
"chunkSize": 1000,
36+
"chunkOverlap": 200,
37+
}
38+
input_with_rerank = input.copy()
39+
input_with_rerank["name"] = "INTEG_TEST_AURORA"
40+
input_with_rerank["crossEncoderModelName"] = "cross-encoder/ms-marco-MiniLM-L-12-v2"
41+
input_with_rerank["crossEncoderModelProvider"] = "sagemaker"
42+
pytest.workspace = client.create_aurora_workspace(input=input_with_rerank)
43+
pytest.workspace_no_re_rank = client.create_aurora_workspace(input=input)
4244

4345
ready = False
4446
retries = 0
@@ -56,6 +58,7 @@ def test_create(client: AppSyncClient, default_embed_model):
5658
def test_add_rss(client: AppSyncClient):
5759
if pytest.skip_flag == True:
5860
pytest.skip("Aurora is not enabled.")
61+
5962
pytest.document = client.add_rss_feed(
6063
input={
6164
"workspaceId": pytest.workspace.get("id"),
@@ -67,6 +70,17 @@ def test_add_rss(client: AppSyncClient):
6770
"limit": 2,
6871
}
6972
)
73+
client.add_rss_feed(
74+
input={
75+
"workspaceId": pytest.workspace_no_re_rank.get("id"),
76+
"title": "INTEG_TEST_AURORA_TITLE",
77+
"address": "https://github.com/aws-samples/aws-genai-llm-chatbot/"
78+
+ "releases.atom",
79+
"contentTypes": ["text/html"],
80+
"followLinks": True,
81+
"limit": 2,
82+
}
83+
)
7084

7185
ready = False
7286
retries = 0
@@ -137,6 +151,29 @@ def test_search_document(client: AppSyncClient):
137151
assert ready == True
138152

139153

154+
def test_search_document_no_reank(client: AppSyncClient):
155+
if pytest.skip_flag == True:
156+
pytest.skip("Aurora is not enabled.")
157+
ready = False
158+
retries = 0
159+
# Wait for the page to be crawled. This starts on a cron every 5 min.
160+
while not ready and retries < 50:
161+
time.sleep(15)
162+
retries += 1
163+
result = client.semantic_search(
164+
input={
165+
"workspaceId": pytest.workspace_no_re_rank.get("id"),
166+
"query": "Release github",
167+
}
168+
)
169+
if len(result.get("items")) > 1:
170+
ready = True
171+
assert result.get("engine") == "aurora"
172+
# Re-ranking score is no set but the results are ordered by Aurora.
173+
assert result.get("items")[0].get("score") is None
174+
assert ready == True
175+
176+
140177
def test_query_llm(client, default_model, default_provider):
141178
if pytest.skip_flag == True:
142179
pytest.skip("Aurora is not enabled.")

0 commit comments

Comments
 (0)