@@ -10,6 +10,7 @@ import {
10
10
SupportedSageMakerModels ,
11
11
SystemConfig ,
12
12
SupportedBedrockRegion ,
13
+ ModelConfig ,
13
14
} from "../lib/shared/types" ;
14
15
import { LIB_VERSION } from "./version.js" ;
15
16
import * as fs from "fs" ;
@@ -34,7 +35,6 @@ function getTimeZonesWithCurrentTime(): { message: string; name: string }[] {
34
35
function getCountryCodesAndNames ( ) : { message: string ; name: string } [ ] {
35
36
// Use country-list to get an array of countries with their codes and names
36
37
const countries = getData ( ) ;
37
-
38
38
// Map the country data to match the desired output structure
39
39
const countryInfo = countries . map ( ( { code, name } ) => {
40
40
return { message : `${ name } (${ code } )` , name : code } ;
@@ -88,21 +88,24 @@ const secretManagerArnRegExp = RegExp(
88
88
/ a r n : a w s : s e c r e t s m a n a g e r : [ \w - _ ] + : \d + : s e c r e t : [ \w - _ ] + /
89
89
) ;
90
90
91
- const embeddingModels = [
91
+ const embeddingModels : ModelConfig [ ] = [
92
92
{
93
93
provider : "sagemaker" ,
94
94
name : "intfloat/multilingual-e5-large" ,
95
95
dimensions : 1024 ,
96
+ default : false ,
96
97
} ,
97
98
{
98
99
provider : "sagemaker" ,
99
100
name : "sentence-transformers/all-MiniLM-L6-v2" ,
100
101
dimensions : 384 ,
102
+ default : false ,
101
103
} ,
102
104
{
103
105
provider : "bedrock" ,
104
106
name : "amazon.titan-embed-text-v1" ,
105
107
dimensions : 1536 ,
108
+ default : false ,
106
109
} ,
107
110
//Support for inputImage is not yet implemented for amazon.titan-embed-image-v1
108
111
{
@@ -124,6 +127,7 @@ const embeddingModels = [
124
127
provider : "openai" ,
125
128
name : "text-embedding-ada-002" ,
126
129
dimensions : 1536 ,
130
+ default : false ,
127
131
} ,
128
132
] ;
129
133
@@ -179,6 +183,8 @@ const embeddingModels = [
179
183
options . startScheduleEndDate =
180
184
config . llms ?. sagemakerSchedule ?. startScheduleEndDate ;
181
185
options . enableRag = config . rag . enabled ;
186
+ options . deployDefaultSagemakerModels =
187
+ config . rag . deployDefaultSagemakerModels ;
182
188
options . ragsToEnable = Object . keys ( config . rag . engines ?? { } ) . filter (
183
189
( v : string ) =>
184
190
(
@@ -608,6 +614,16 @@ async function processCreateOptions(options: any): Promise<void> {
608
614
message : "Do you want to enable RAG" ,
609
615
initial : options . enableRag || false ,
610
616
} ,
617
+ {
618
+ type : "confirm" ,
619
+ name : "deployDefaultSagemakerModels" ,
620
+ message :
621
+ "Do you want to deploy the default embedding and cross-encoder models via SageMaker?" ,
622
+ initial : options . deployDefaultSagemakerModels || false ,
623
+ skip ( ) : boolean {
624
+ return ! ( this as any ) . state . answers . enableRag ;
625
+ } ,
626
+ } ,
611
627
{
612
628
type : "multiselect" ,
613
629
name : "ragsToEnable" ,
@@ -810,10 +826,17 @@ async function processCreateOptions(options: any): Promise<void> {
810
826
choices : embeddingModels . map ( ( m ) => ( { name : m . name , value : m } ) ) ,
811
827
initial : options . defaultEmbedding ,
812
828
validate(value: string) {
  // Answers collected so far by the enclosing prompt.
  const answers = (this as any).state.answers;
  const selected = embeddingModels.find((model) => model.name === value);

  // A SageMaker-hosted embedding model cannot be chosen when the user
  // explicitly declined deploying the default SageMaker models.
  const sagemakerDeclined =
    selected !== undefined &&
    answers.deployDefaultSagemakerModels === false &&
    selected.provider === "sagemaker";
  if (sagemakerDeclined) {
    return "SageMaker default models are not enabled. Please select another model.";
  }

  // Without RAG, any value (including none) is acceptable.
  if (!answers.enableRag) {
    return true;
  }

  // With RAG enabled a model must be selected.
  return value ? true : "Select a default embedding model";
},
819
842
skip ( ) {
@@ -1219,6 +1242,7 @@ async function processCreateOptions(options: any): Promise<void> {
1219
1242
}
1220
1243
: undefined ,
1221
1244
llms : {
1245
+ enableSagemakerModels : answers . enableSagemakerModels ,
1222
1246
rateLimitPerAIP : advancedSettings ?. llmRateLimitPerIP
1223
1247
? Number ( advancedSettings ?. llmRateLimitPerIP )
1224
1248
: undefined ,
@@ -1241,6 +1265,7 @@ async function processCreateOptions(options: any): Promise<void> {
1241
1265
} ,
1242
1266
rag : {
1243
1267
enabled : answers . enableRag ,
1268
+ deployDefaultSagemakerModels : answers . deployDefaultSagemakerModels ,
1244
1269
engines : {
1245
1270
aurora : {
1246
1271
enabled : answers . ragsToEnable . includes ( "aurora" ) ,
@@ -1259,28 +1284,40 @@ async function processCreateOptions(options: any): Promise<void> {
1259
1284
external : [ { } ] ,
1260
1285
} ,
1261
1286
} ,
1262
- embeddingsModels : [ { } ] ,
1263
- crossEncoderModels : [ { } ] ,
1287
+ embeddingsModels : [ ] as ModelConfig [ ] ,
1288
+ crossEncoderModels : [ ] as ModelConfig [ ] ,
1264
1289
} ,
1265
1290
} ;
1266
1291
1292
// Decide which embedding and cross-encoder models are written to the config,
// and flag the embedding model the user picked as the default.
if (config.rag.enabled && config.rag.deployDefaultSagemakerModels) {
  // RAG with the default SageMaker models: register the cross-encoder and
  // expose every known embedding model.
  config.rag.crossEncoderModels[0] = {
    provider: "sagemaker",
    name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
    default: true,
  };
  config.rag.embeddingsModels = embeddingModels;
  for (const model of config.rag.embeddingsModels) {
    if (model.name === models.defaultEmbedding) {
      model.default = true;
    }
  }
} else if (config.rag.enabled) {
  // RAG without SageMaker: drop the SageMaker-hosted embedding models and
  // mark exactly the selected one as default.
  config.rag.embeddingsModels = embeddingModels.filter(
    (model) => model.provider !== "sagemaker"
  );
  for (const model of config.rag.embeddingsModels) {
    model.default = model.name === models.defaultEmbedding;
  }
} else {
  // RAG disabled: no embedding models are configured, so there is no entry
  // to flag as default. Indexing element 0 of this empty list and assigning
  // to it would throw a TypeError.
  config.rag.embeddingsModels = [];
}
1271
1320
1272
- config . rag . crossEncoderModels [ 0 ] = {
1273
- provider : "sagemaker" ,
1274
- name : "cross-encoder/ms-marco-MiniLM-L-12-v2" ,
1275
- default : true ,
1276
- } ;
1277
- config . rag . embeddingsModels = embeddingModels ;
1278
- config . rag . embeddingsModels . forEach ( ( m : any ) => {
1279
- if ( m . name === models . defaultEmbedding ) {
1280
- m . default = true ;
1281
- }
1282
- } ) ;
1283
-
1284
1321
config . rag . engines . kendra . createIndex =
1285
1322
answers . ragsToEnable . includes ( "kendra" ) ;
1286
1323
config . rag . engines . kendra . enabled =
0 commit comments