7
7
prebuiltAppConfig ,
8
8
GenerationConfig ,
9
9
postInitAndCheckGenerationConfigValues ,
10
- ModelRecord
10
+ ModelRecord ,
11
+ Role
11
12
} from "./config" ;
12
13
import { LLMChatPipeline } from "./llm_chat"
13
14
import {
@@ -20,6 +21,7 @@ import {
20
21
ChatCompletionRequestStreaming ,
21
22
ChatCompletionRequestBase ,
22
23
CompletionUsage ,
24
+ ChatCompletionUserMessageParam ,
23
25
} from "./openai_api_protocols/index" ;
24
26
import * as ChatCompletionAPI from "./openai_api_protocols/index" ;
25
27
import {
@@ -316,6 +318,11 @@ export class ChatModule implements ChatInterface {
316
318
top_logprobs : request . top_logprobs ,
317
319
}
318
320
321
+ const error_msg = this . checkFunctionCallUsage ( request ) ;
322
+ if ( error_msg ) {
323
+ throw new Error ( error_msg ) ;
324
+ }
325
+
319
326
// 1. If request is streaming, return an AsyncIterable (an iterable version of `generate()`)
320
327
if ( request . stream ) {
321
328
return this . chatCompletionAsyncChunkGenerator ( request , genConfig ) ;
@@ -506,24 +513,65 @@ export class ChatModule implements ChatInterface {
506
513
throw new Error ( "Last messages should be a string from the `user`." ) ;
507
514
}
508
515
this . getPipeline ( ) . appendConversationMessage (
509
- message . name ? message . name : roles [ 0 ] ,
516
+ Role . User ,
510
517
message . content ,
518
+ message . name
511
519
) ;
512
520
} else if ( message . role === "assistant" ) {
513
521
if ( typeof message . content !== "string" ) {
514
- // TODO(Charlie): Remove when we support function calling
515
522
throw new Error ( "Assistant message should have string content." ) ;
516
523
}
517
524
this . getPipeline ( ) . appendConversationMessage (
518
- message . name ? message . name : roles [ 1 ] ,
525
+ Role . Assistant ,
519
526
message . content ,
527
+ message . name
520
528
) ;
521
529
} else {
522
530
throw new Error ( "Unsupported role: " + message . role ) ;
523
531
}
524
532
}
525
533
}
526
534
535
+ private checkFunctionCallUsage ( request : ChatCompletionRequest ) : string | null {
536
+ if ( request . tools == undefined ||
537
+ ( typeof request . tool_choice == "string" && request . tool_choice == "none" ) ) {
538
+ this . getPipeline ( ) . overrideFunctionCalling ( false , "" ) ;
539
+ return null ;
540
+ }
541
+
542
+ if ( typeof request . tool_choice == "string" && request . tool_choice !== "auto" ) {
543
+ return `Invalid tool choice value: ${ request . tool_choice } ` ;
544
+ }
545
+
546
+ if ( typeof request . tool_choice !== "string" && request . tool_choice ?. type ) {
547
+ return "Only 'function' tool choice is supported" ;
548
+ }
549
+
550
+ const singleFunctionToCall = typeof request . tool_choice !== "string" && request . tool_choice ?. function ?. name ;
551
+
552
+ if ( singleFunctionToCall ) {
553
+ for ( const f of request . tools ) {
554
+ if ( singleFunctionToCall == f . function . name ) {
555
+ this . getPipeline ( ) . overrideFunctionCalling ( true , JSON . stringify ( [ f . function ] ) ) ;
556
+ return null ;
557
+ }
558
+ }
559
+
560
+ return `The tool choice function ${ singleFunctionToCall } is not found in the tools list` ;
561
+ }
562
+
563
+ let function_list = [ ] ;
564
+ for ( const f of request . tools ) {
565
+ if ( f . type !== "function" ) {
566
+ return "Only 'function' tool type is supported" ;
567
+ }
568
+
569
+ function_list . push ( f . function ) ;
570
+ }
571
+ this . getPipeline ( ) . overrideFunctionCalling ( true , JSON . stringify ( function_list ) ) ;
572
+ return null ;
573
+ }
574
+
527
575
/**
528
576
* Run a prefill step with a given input.
529
577
* @param input The input prompt, or `messages` in OpenAI-like APIs.
@@ -533,15 +581,18 @@ export class ChatModule implements ChatInterface {
533
581
genConfig ?: GenerationConfig
534
582
) {
535
583
let input_str : string ;
584
+ let input_role_str : string | undefined ;
536
585
if ( typeof input === "string" ) {
537
586
input_str = input ;
538
587
} else {
539
588
// Process ChatCompletionMessageParam
540
589
// We treat the last message as our usual input
541
590
this . updateConversationWithChatCompletionMessages ( input ) ;
542
- input_str = input [ input . length - 1 ] . content as string ;
591
+ const last_msg = input [ input . length - 1 ] as ChatCompletionUserMessageParam ;
592
+ input_str = last_msg . content as string ;
593
+ input_role_str = last_msg . name ? last_msg . name : undefined ;
543
594
}
544
- return this . getPipeline ( ) . prefillStep ( input_str , genConfig ) ;
595
+ return this . getPipeline ( ) . prefillStep ( input_str , input_role_str , genConfig ) ;
545
596
}
546
597
547
598
/**
0 commit comments