
Commit 4bc2ddd

[OpenAI] Add OpenAI-API-compatible function calling support (mlc-ai#321)

1 parent eaaff6a · commit 4bc2ddd

11 files changed: +456 −119 lines

examples/openai-api/package.json (+1 −1)

@@ -15,6 +15,6 @@
     "url": "^0.11.3"
   },
   "dependencies": {
-    "@mlc-ai/web-llm": "^0.2.28"
+    "@mlc-ai/web-llm": "file:../.."
   }
 }
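
Note: the dependency now points at `file:../..` (the repository root), so the example builds against the local web-llm checkout rather than the published `^0.2.28` npm release, presumably so it can exercise the new function-calling APIs before they ship.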

examples/openai-api/src/openai_api.ts (+57 −2)

@@ -137,7 +137,62 @@ async function mainStateful() {
   console.log(await chat.runtimeStatsText());
 }
 
+async function mainFunctionCalling() {
+  const chat: webllm.ChatInterface = new webllm.ChatModule();
+
+  chat.setInitProgressCallback((report: webllm.InitProgressReport) => {
+    setLabel("init-label", report.text);
+  });
+
+  const myAppConfig: webllm.AppConfig = {
+    model_list: [
+      {
+        "model_url": "https://huggingface.co/mlc-ai/gorilla-openfunctions-v2-q4f16_1-MLC/resolve/main/",
+        "local_id": "gorilla-openfunctions-v2-q4f16_1",
+        "model_lib_url": "https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/gorilla-openfunctions-v2/gorilla-openfunctions-v2-q4f16_1.wasm",
+      },
+    ]
+  }
+  const selectedModel = "gorilla-openfunctions-v2-q4f16_1"
+  await chat.reload(selectedModel, undefined, myAppConfig);
+
+  const tools: Array<webllm.ChatCompletionTool> = [
+    {
+      type: "function",
+      function: {
+        name: "get_current_weather",
+        description: "Get the current weather in a given location",
+        parameters: {
+          "type": "object",
+          "properties": {
+            "location": {
+              "type": "string",
+              "description": "The city and state, e.g. San Francisco, CA",
+            },
+            "unit": { "type": "string", "enum": ["celsius", "fahrenheit"] },
+          },
+          "required": ["location"],
+        },
+      },
+    }
+  ]
+
+  const request: webllm.ChatCompletionRequest = {
+    stream: false,
+    messages: [
+      { "role": "user", "content": "What is the current weather in celsius in Pittsburgh and Tokyo?" },
+    ],
+    tool_choice: 'auto',
+    tools: tools,
+  };
+
+  const reply0 = await chat.chatCompletion(request);
+  console.log(reply0.choices[0].message.content);
+
+  console.log(await chat.runtimeStatsText());
+}
+
 // Run one of the functions
-mainNonStreaming();
+// mainNonStreaming();
 // mainStreaming();
-// mainStateful();
+mainFunctionCalling();
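
Besides `tool_choice: 'auto'`, the validation added in src/chat_module.ts below also accepts a named tool choice that pins generation to a single function. A minimal sketch of that variant (not part of this commit), reusing the `chat` and `tools` values from mainFunctionCalling() above:

  // Sketch: force the model to call get_current_weather instead of letting
  // it choose. The named function must match an entry in `tools`.
  const forcedRequest: webllm.ChatCompletionRequest = {
    stream: false,
    messages: [
      { "role": "user", "content": "How warm is it in Tokyo right now?" },
    ],
    tool_choice: {
      type: "function",
      function: { name: "get_current_weather" },
    },
    tools: tools,
  };
  const forcedReply = await chat.chatCompletion(forcedRequest);
  console.log(forcedReply.choices[0].message.content);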

package-lock.json (+1 −1)

Some generated files are not rendered by default.

rollup.config.js (+1 −1)

@@ -1,7 +1,7 @@
 import { nodeResolve } from '@rollup/plugin-node-resolve';
 import ignore from "rollup-plugin-ignore";
 import commonjs from '@rollup/plugin-commonjs';
-import typescript from 'rollup-plugin-typescript2'
+import typescript from 'rollup-plugin-typescript2';
 
 export default {
   input: 'src/index.ts',

src/chat_module.ts (+57 −6)

@@ -7,7 +7,8 @@ import {
   prebuiltAppConfig,
   GenerationConfig,
   postInitAndCheckGenerationConfigValues,
-  ModelRecord
+  ModelRecord,
+  Role
 } from "./config";
 import { LLMChatPipeline } from "./llm_chat"
 import {
@@ -20,6 +21,7 @@ import {
   ChatCompletionRequestStreaming,
   ChatCompletionRequestBase,
   CompletionUsage,
+  ChatCompletionUserMessageParam,
 } from "./openai_api_protocols/index";
 import * as ChatCompletionAPI from "./openai_api_protocols/index";
 import {
@@ -316,6 +318,11 @@ export class ChatModule implements ChatInterface {
       top_logprobs: request.top_logprobs,
     }
 
+    const error_msg = this.checkFunctionCallUsage(request);
+    if (error_msg) {
+      throw new Error(error_msg);
+    }
+
     // 1. If request is streaming, return an AsyncIterable (an iterable version of `generate()`)
     if (request.stream) {
       return this.chatCompletionAsyncChunkGenerator(request, genConfig);
@@ -506,24 +513,65 @@
           throw new Error("Last messages should be a string from the `user`.");
         }
         this.getPipeline().appendConversationMessage(
-          message.name ? message.name : roles[0],
+          Role.User,
           message.content,
+          message.name
         );
       } else if (message.role === "assistant") {
         if (typeof message.content !== "string") {
-          // TODO(Charlie): Remove when we support function calling
           throw new Error("Assistant message should have string content.");
         }
         this.getPipeline().appendConversationMessage(
-          message.name ? message.name : roles[1],
+          Role.Assistant,
           message.content,
+          message.name
         );
       } else {
         throw new Error("Unsupported role: " + message.role);
       }
     }
   }
 
+  private checkFunctionCallUsage(request: ChatCompletionRequest): string | null {
+    if (request.tools == undefined ||
+      (typeof request.tool_choice == "string" && request.tool_choice == "none")) {
+      this.getPipeline().overrideFunctionCalling(false, "");
+      return null;
+    }
+
+    if (typeof request.tool_choice == "string" && request.tool_choice !== "auto") {
+      return `Invalid tool choice value: ${request.tool_choice}`;
+    }
+
+    if (request.tool_choice !== undefined && typeof request.tool_choice !== "string" && request.tool_choice.type !== "function") {
+      return "Only 'function' tool choice is supported";
+    }
+
+    const singleFunctionToCall = typeof request.tool_choice !== "string" && request.tool_choice?.function?.name;
+
+    if (singleFunctionToCall) {
+      for (const f of request.tools) {
+        if (singleFunctionToCall == f.function.name) {
+          this.getPipeline().overrideFunctionCalling(true, JSON.stringify([f.function]));
+          return null;
+        }
+      }
+
+      return `The tool choice function ${singleFunctionToCall} is not found in the tools list`;
+    }
+
+    let function_list = [];
+    for (const f of request.tools) {
+      if (f.type !== "function") {
+        return "Only 'function' tool type is supported";
+      }
+
+      function_list.push(f.function);
+    }
+    this.getPipeline().overrideFunctionCalling(true, JSON.stringify(function_list));
+    return null;
+  }
+
   /**
    * Run a prefill step with a given input.
    * @param input The input prompt, or `messages` in OpenAI-like APIs.
@@ -533,15 +581,18 @@
     genConfig?: GenerationConfig
   ) {
     let input_str: string;
+    let input_role_str: string | undefined;
     if (typeof input === "string") {
       input_str = input;
     } else {
       // Process ChatCompletionMessageParam
       // We treat the last message as our usual input
       this.updateConversationWithChatCompletionMessages(input);
-      input_str = input[input.length - 1].content as string;
+      const last_msg = input[input.length - 1] as ChatCompletionUserMessageParam;
+      input_str = last_msg.content as string;
+      input_role_str = last_msg.name ? last_msg.name : undefined;
     }
-    return this.getPipeline().prefillStep(input_str, genConfig);
+    return this.getPipeline().prefillStep(input_str, input_role_str, genConfig);
   }
 
   /**
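
For reference, the tool_choice rules that checkFunctionCallUsage enforces can be restated as a small standalone function. This is an illustrative paraphrase with simplified types, not code from this commit:

type ToolChoice =
  | "none"
  | "auto"
  | { type: string; function: { name: string } };

interface Tool { type: string; function: { name: string } }

// Returns the function list to expose to the model, or null when function
// calling is disabled; throws on the inputs the commit reports as errors.
function resolveFunctions(tools: Tool[] | undefined, choice?: ToolChoice): object[] | null {
  if (tools === undefined || choice === "none") {
    return null;  // no tools, or an explicit "none": function calling is off
  }
  if (typeof choice === "string" && choice !== "auto") {
    throw new Error(`Invalid tool choice value: ${choice}`);
  }
  if (typeof choice === "object") {
    if (choice.type !== "function") {
      throw new Error("Only 'function' tool choice is supported");
    }
    // A named choice narrows the exposed list to that single function.
    const match = tools.find((t) => t.function.name === choice.function.name);
    if (match === undefined) {
      throw new Error(`The tool choice function ${choice.function.name} is not found in the tools list`);
    }
    return [match.function];
  }
  // "auto" (or unset): expose every tool, all of which must be functions.
  for (const t of tools) {
    if (t.type !== "function") {
      throw new Error("Only 'function' tool type is supported");
    }
  }
  return tools.map((t) => t.function);
}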

src/config.ts (+23 −1)

@@ -4,7 +4,9 @@
  */
 export interface ConvTemplateConfig {
   system: string;
-  roles: Array<string>;
+  roles: Record<Role, string>;
+  role_templates?: Partial<Record<Role, string>>;
+  function_calling_template?: string;
   seps: Array<string>;
   separator_style: string;
   offset: number;
@@ -13,6 +15,26 @@ export interface ConvTemplateConfig {
   stop_tokens: Array<number>;
 }
 
+export enum Role {
+  User,
+  Assistant
+}
+
+/**
+ * Placeholders that can be used in role templates.
+ * For example, a role template of
+ * `<<question>> ${MessagePlaceholders.User} <<function>> ${MessagePlaceholders.Function}`
+ * will insert the user message at ${MessagePlaceholders.User}
+ * and insert the function message at ${MessagePlaceholders.Function}
+ * at run time.
+ */
+export enum MessagePlaceholders {
+  User = "{user_message}",
+  Assistant = "{assistant_message}",
+  Tool = "{tool_message}",
+  Function = "{function_string}"
+}
+
 /**
  * Config of one chat model, a data structure representing `mlc-chat-config.json`.
  * This only corresponds to the chat-related fields and `tokenizer_files` of `mlc-chat-config.json`.
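
As a sketch of how the new fields fit together (values invented for illustration, not taken from any shipped mlc-chat-config.json): a template keys its role names by the Role enum and can carry a per-role template whose placeholders are substituted at run time.

// Hypothetical conversation template using the new fields.
const exampleTemplate: Partial<ConvTemplateConfig> = {
  system: "You are an AI assistant that can call functions.",
  roles: {
    [Role.User]: "USER",
    [Role.Assistant]: "ASSISTANT",
  },
  role_templates: {
    // At run time, "{user_message}" and "{function_string}" (the enum values
    // above) are replaced by the user's text and the JSON function list.
    [Role.User]: `<<question>> ${MessagePlaceholders.User} <<function>> ${MessagePlaceholders.Function}`,
  },
};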
