
Commit 9ffc683

[OpenAI] Support seed for reproducible generation (mlc-ai#335)
We support `seed` in `ChatCompletionRequest` so that a request's results are reproducible. As stated in the docstring in `src/openai_api_protocols/chat_completion.ts`, seeding is done at the request level rather than the choice level. So if a request with `n > 1` is seeded, the choices still have different results. But two requests that both have `n > 1` and share the same seed generate identical results across all choices. This is demonstrated in `examples/openai-api/src/seed.ts`, where we rigorously compare the generated strings. Implementation-wise, this is achieved with a customized linear congruential generator in TVMjs's runtime, since JS's `Math.random()` does not support seeding.
1 parent b5edea6 commit 9ffc683
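For context, seeding in the browser needs a custom RNG because `Math.random()` cannot be seeded. Below is a minimal sketch of a seedable linear congruential generator of the kind mentioned above; the class name and constants (the classic Numerical Recipes parameters) are illustrative assumptions, not the actual TVMjs implementation, which WebLLM reaches through `LLMChatPipeline.setSeed()` -> `tvm.setSeed()`.

```typescript
// A minimal sketch, assuming (not reproducing) TVMjs's seedable RNG: a linear
// congruential generator with the classic Numerical Recipes constants.
class SeededRNG {
  private state: number;

  constructor(seed: number) {
    this.state = seed >>> 0; // keep the state within 32 bits
  }

  // Returns a pseudo-random float in [0, 1), like Math.random(), but reproducible.
  next(): number {
    // LCG step: state = (a * state + c) mod 2^32
    this.state = (Math.imul(1664525, this.state) + 1013904223) >>> 0;
    return this.state / 2 ** 32;
  }
}

// Two generators with the same seed produce identical sequences,
// which is exactly the property the `seed` field relies on.
const rngA = new SeededRNG(42);
const rngB = new SeededRNG(42);
console.log(rngA.next() === rngB.next()); // true
```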

File tree

7 files changed: +166 −10 lines

examples/openai-api/README.md (+19)

### OpenAI API Demos

Run `npm install` first, followed by `npm start`.

To run different scripts, you can modify `package.json` from the default
```json
  "scripts": {
    "start": "parcel src/openai_api.html --port 8888",
    "build": "parcel build src/openai_api.html --dist-dir lib"
  },
```
to, say
```json
  "scripts": {
    "start": "parcel src/seed.html --port 8888",
    "build": "parcel build src/seed.html --dist-dir lib"
  },
```

examples/openai-api/src/seed.html (+17)

```html
<!DOCTYPE html>
<html>
<script>
  webLLMGlobal = {};
</script>

<body>
  <h2>WebLLM OpenAI-like API Test Page</h2>
  Open the console to see the output.
  We make two generations with the same seed; we expect them to be identical.
  <br/>
  <br/>
  <label id="init-label"> </label>

  <script type="module" src="./seed.ts"></script>
</body>
</html>
```

examples/openai-api/src/seed.ts (+58)

```typescript
import * as webllm from "@mlc-ai/web-llm";

function setLabel(id: string, text: string) {
  const label = document.getElementById(id);
  if (label == null) {
    throw Error("Cannot find label " + id);
  }
  label.innerText = text;
}

/**
 * We demonstrate the effect of seeding. The prompt asks for a poem and we use a high
 * `temperature`, making the sampling distribution supposedly more random. However, we demonstrate
 * that with seeding, we should see the exact same result being generated across two trials.
 * With `n > 1`, all choices should also be exactly the same.
 */
async function demonstrateSeed() {
  const chat: webllm.ChatInterface = new webllm.ChatModule();

  chat.setInitProgressCallback((report: webllm.InitProgressReport) => {
    setLabel("init-label", report.text);
  });

  await chat.reload("Llama-2-7b-chat-hf-q4f32_1");

  const request: webllm.ChatCompletionRequest = {
    stream: false,  // works with streaming as well
    messages: [
      { "role": "user", "content": "Write a creative Haiku about Pittsburgh" }
    ],
    n: 3,
    temperature: 1.2,  // high temperature gives much more random results
    seed: 42,
    max_gen_len: 128,  // to save time; enough to demonstrate the effect
  };

  const reply0 = await chat.chatCompletion(request);
  console.log(reply0);
  console.log("First reply's last choice:\n" + await chat.getMessage());

  const reply1 = await chat.chatCompletion(request);
  console.log(reply1);
  console.log("Second reply's last choice:\n" + await chat.getMessage());

  // Rigorously check the generation results of each choice for the two requests
  for (const choice0 of reply0.choices) {
    const id = choice0.index;
    const choice1 = reply1.choices[id];
    if (choice0.message.content !== choice1.message.content) {
      throw Error("Choice " + id + " of the two generations are different despite seeding");
    }
  }

  console.log(await chat.runtimeStatsText());
}

demonstrateSeed();
```

src/chat_module.ts (+26)

```diff
@@ -208,6 +208,9 @@ export class ChatModule implements ChatInterface {
     genConfig: GenerationConfig
   ): AsyncGenerator<ChatCompletionChunk, void, void> {
     postInitAndCheckGenerationConfigValues(genConfig);
+    if (request.seed !== null && request.seed !== undefined) {
+      this.getPipeline().setSeed(request.seed);
+    }
     if (!request.stateful) {
       await this.resetChat();
     }
@@ -255,6 +258,11 @@ export class ChatModule implements ChatInterface {
       yield await _getChunk(this);
     }
 
+    // Reset seed -- we do not want this seed to affect future requests
+    if (request.seed !== null && request.seed !== undefined) {
+      this.getPipeline().setSeed(Date.now());
+    }
+
     const lastChunk: ChatCompletionChunk = {
       id: id,
       choices: [{
@@ -270,6 +278,15 @@ export class ChatModule implements ChatInterface {
     yield lastChunk;
   }
 
+  /**
+   * Completes a single ChatCompletionRequest.
+   *
+   * @param request An OpenAI-style ChatCompletion request.
+   *
+   * @note For each choice (i.e. `n`), a request is defined by a single `prefill()` and multiple
+   * `decode()`. This is important as it determines the behavior of various fields including
+   * `stateful` and `seed`.
+   */
   async chatCompletion(
     request: ChatCompletionRequestNonStreaming
   ): Promise<ChatCompletion>;
@@ -304,6 +321,10 @@ export class ChatModule implements ChatInterface {
       return this.chatCompletionAsyncChunkGenerator(request, genConfig);
     }
 
+    if (request.seed !== null && request.seed !== undefined) {
+      this.getPipeline().setSeed(request.seed);
+    }
+
     // 2. If request is non-streaming, directly reuse `generate()`
     const n = request.n ? request.n : 1;
     const choices: Array<ChatCompletion.Choice> = [];
@@ -354,6 +375,11 @@ export class ChatModule implements ChatInterface {
         total_tokens: completion_tokens + prompt_tokens,
       } as CompletionUsage,
     }
+
+    // Reset seed -- we do not want this seed to affect future requests
+    if (request.seed !== null && request.seed !== undefined) {
+      this.getPipeline().setSeed(Date.now());
+    }
     return response;
   }
```
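The diff above follows a set-then-reset pattern: the seed is applied to the pipeline right before generation and cleared with a time-based value afterwards so it cannot leak into later requests. Here is a minimal sketch of that pattern, assuming a hypothetical `SeedablePipeline` interface and helper name rather than the actual `ChatModule` internals:

```typescript
// Illustrative sketch only; the interface and helper names are assumptions,
// not part of the WebLLM API.
interface SeedablePipeline {
  setSeed(seed: number): void;
}

async function withOptionalSeed<T>(
  pipeline: SeedablePipeline,
  seed: number | null | undefined,
  generate: () => Promise<T>,
): Promise<T> {
  if (seed !== null && seed !== undefined) {
    pipeline.setSeed(seed); // make this request reproducible
  }
  try {
    return await generate();
  } finally {
    // Reset with a time-based seed so this request's seed does not affect future requests.
    if (seed !== null && seed !== undefined) {
      pipeline.setSeed(Date.now());
    }
  }
}
```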

src/llm_chat.ts (+8 −1)

```diff
@@ -343,7 +343,14 @@ export class LLMChatPipeline {
     )
   }
 
-  // Getters and writters for this.conversation.
+  /**
+   * Set the seed for the RNG `this.tvm.rng`.
+   */
+  setSeed(seed: number): void {
+    this.tvm.setSeed(seed);
+  }
+
+  // Getters and setters for this.conversation.
   /**
    * Overrides the system prompt.
    */
```

src/openai_api_protocols/chat_completion.ts (+17 −9)

```diff
@@ -134,6 +134,16 @@ export interface ChatCompletionRequestBase {
    */
   top_logprobs?: number | null;
 
+  /**
+   * If specified, our system will make a best effort to sample deterministically, such that
+   * repeated requests with the same `seed` and parameters should return the same result.
+   *
+   * @note Seeding is done on a request level rather than a choice level. That is, if `n > 1`, you
+   * would still get different content for each `Choice`. But if two requests with `n = 2` are
+   * processed with the same seed, the two results should be the same (though the two choices within each differ).
+   */
+  seed?: number | null;
+
   //////////////// BELOW FIELDS NOT SUPPORTED YET ////////////////
 
   /**
@@ -143,14 +153,6 @@ export interface ChatCompletionRequestBase {
    */
   model?: string | null;
 
-  /**
-   * If specified, our system will make a best effort to sample deterministically, such that
-   * repeated requests with the same `seed` and parameters should return the same result.
-   *
-   * @note Not supported yet.
-   */
-  seed?: number | null;
-
   /**
    * Controls which (if any) function is called by the model. `none` means the model
    * will not call a function and instead generates a message. `auto` means the model
@@ -306,7 +308,6 @@ export const ChatCompletionRequestUnsupportedFields: Array<string> = [
   "tool_choice",
   "tools",
   "response_format",
-  "seed",
 ];
 
 export function postInitAndCheckFields(request: ChatCompletionRequest): void {
@@ -363,6 +364,13 @@ export function postInitAndCheckFields(request: ChatCompletionRequest): void {
   if (request.stateful && request.n && request.n > 1) {
     throw new Error("If the request is stateful, `n` cannot be > 1.");
   }
+
+  // 6. Seed should be an integer
+  if (request.seed !== undefined && request.seed !== null) {
+    if (!Number.isInteger(request.seed)) {
+      throw new Error("`seed` should be an integer, but got " + request.seed);
+    }
+  }
 }
 
 //////////////// BELOW ARE INTERFACES THAT SUPPORT THE ONES ABOVE ////////////////
```

tests/openai_chat_completion.test.ts (+21)

```diff
@@ -68,6 +68,19 @@ describe('Check chat completion unsupported requests', () => {
     }).toThrow("If the request is stateful, `n` cannot be > 1.");
   });
 
+  test('Non-integer seed', () => {
+    expect(() => {
+      const request: ChatCompletionRequest = {
+        messages: [
+          { role: "user", content: "Hello! " },
+        ],
+        max_gen_len: 10,
+        seed: 42.2,  // Note that Number.isInteger(42.0) is true
+      };
+      postInitAndCheckFields(request)
+    }).toThrow("`seed` should be an integer, but got");
+  });
+
   // Remove when we support image input (e.g. LlaVA model)
   test('Image input is unsupported', () => {
     expect(() => {
@@ -128,6 +141,14 @@ describe('Supported requests', () => {
       temperature: 1.5,
       max_gen_len: 25,
       frequency_penalty: 0.2,
+      seed: 42,
+      logprobs: true,
+      top_logprobs: 2,
+      logit_bias: {
+        "13813": -100,
+        "10319": 5,
+        "7660": 5,
+      },
     };
     postInitAndCheckFields(request)
   });
```
