Commit 34e14a5

support anthropic thinking

arvinxx committed Feb 25, 2025
1 parent 35b3307 commit 34e14a5
Showing 5 changed files with 202 additions and 19 deletions.
21 changes: 19 additions & 2 deletions src/libs/agent-runtime/anthropic/index.ts
@@ -97,12 +97,29 @@ export class LobeAnthropicAI implements LobeRuntimeAI {
}

private async buildAnthropicPayload(payload: ChatStreamPayload) {
const { messages, model, max_tokens = 4096, temperature, top_p, tools } = payload;
const { messages, model, max_tokens, temperature, top_p, tools, thinking } = payload;
const system_message = messages.find((m) => m.role === 'system');
const user_messages = messages.filter((m) => m.role !== 'system');

if (!!thinking) {
const maxTokens =
max_tokens ?? (thinking?.budget_tokens ? thinking?.budget_tokens + 4096 : 4096);

// `temperature` may only be set to 1 when thinking is enabled.
// `top_p` must be unset when thinking is enabled.
return {
max_tokens: maxTokens,
messages: await buildAnthropicMessages(user_messages),
model,
system: system_message?.content as string,

thinking,
tools: buildAnthropicTools(tools),
} satisfies Anthropic.MessageCreateParams;
}

return {
max_tokens,
max_tokens: max_tokens ?? 4096,
messages: await buildAnthropicMessages(user_messages),
model,
system: system_message?.content as string,
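Note on the change above: when thinking is set and the caller passes no max_tokens, the builder falls back to budget_tokens + 4096, and it omits temperature and top_p entirely (per the inline comments, Anthropic only accepts temperature = 1 and rejects top_p while thinking is on). A minimal sketch of the two request shapes this produces; the model name, budget, and messages are illustrative, not taken from the commit:

// Illustrative values only.
const thinking = { budget_tokens: 2048, type: 'enabled' as const };

// Thinking enabled, no explicit max_tokens → budget_tokens + 4096, no sampling params.
const thinkingRequest = {
  max_tokens: 2048 + 4096, // 6144
  messages: [{ content: 'Which is larger, 9.8 or 9.11?', role: 'user' as const }],
  model: 'claude-3-7-sonnet-20250219',
  thinking,
};

// Thinking disabled → max_tokens defaults to 4096; temperature / top_p presumably
// pass through in the truncated tail of the return statement above.
const plainRequest = {
  max_tokens: 4096,
  messages: [{ content: 'Hello', role: 'user' as const }],
  model: 'claude-3-7-sonnet-20250219',
  temperature: 0.7,
};
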
8 changes: 7 additions & 1 deletion src/libs/agent-runtime/types/chat.ts
@@ -88,8 +88,14 @@ export interface ChatStreamPayload {
* @default 1
*/
temperature: number;
/**
* used for Claude extended thinking
*/
thinking?: {
budget_tokens: number;
type: 'enabled' | 'disabled';
};
tool_choice?: string;

tools?: ChatCompletionTool[];
/**
* @title Controls the single highest-probability token in the generated text
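On the caller side, opting in is just a matter of setting the new optional field. A sketch of the relevant slice of a ChatStreamPayload (values illustrative; the other fields are unchanged):

const payload = {
  messages: [{ content: 'Which is larger, 9.8 or 9.11?', role: 'user' }],
  model: 'claude-3-7-sonnet-20250219',
  temperature: 1, // only temperature = 1 is accepted while thinking is enabled
  thinking: { budget_tokens: 2048, type: 'enabled' },
  // ...remaining ChatStreamPayload fields as usual
};
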
126 changes: 126 additions & 0 deletions src/libs/agent-runtime/utils/streams/anthropic.test.ts
@@ -384,6 +384,132 @@ describe('AnthropicStream', () => {
expect(onToolCallMock).toHaveBeenCalledTimes(6);
});

it('should handle thinking', async () => {
const streams = [
{
type: 'message_start',
message: {
id: 'msg_01MNsLe7n1uVLtu6W8rCFujD',
type: 'message',
role: 'assistant',
model: 'claude-3-7-sonnet-20250219',
content: [],
stop_reason: null,
stop_sequence: null,
usage: {
input_tokens: 46,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
output_tokens: 11,
},
},
},
{
type: 'content_block_start',
index: 0,
content_block: { type: 'thinking', thinking: '', signature: '' },
},
{
type: 'content_block_delta',
index: 0,
delta: { type: 'thinking_delta', thinking: '我需要比较两个数字的' },
},
{
type: 'content_block_delta',
index: 0,
delta: { type: 'thinking_delta', thinking: '大小:9.8和9' },
},
{
type: 'content_block_delta',
index: 0,
delta: { type: 'thinking_delta', thinking: '11\n\n所以9.8比9.11大。' },
},
{
type: 'content_block_delta',
index: 0,
delta: {
type: 'signature_delta',
signature:
'EuYBCkQYAiJAHnHRJG4nPBrdTlo6CmXoyE8WYoQeoPiLnXaeuaM8ExdiIEkVvxK1DYXOz5sCubs2s/G1NsST8A003Zb8XmuhYBIMwDGMZSZ3+gxOEBpVGgzdpOlDNBTxke31SngiMKUk6WcSiA11OSVBuInNukoAhnRd5jPAEg7e5mIoz/qJwnQHV8I+heKUreP77eJdFipQaM3FHn+avEHuLa/Z/fu0O9BftDi+caB1UWDwJakNeWX1yYTvK+N1v4gRpKbj4AhctfYHMjq8qX9XTnXme5AGzCYC6HgYw2/RfalWzwNxI6k=',
},
},
{ type: 'content_block_stop', index: 0 },
{ type: 'content_block_start', index: 1, content_block: { type: 'text', text: '' } },
{
type: 'content_block_delta',
index: 1,
delta: { type: 'text_delta', text: '9.8比9.11大。' },
},
{ type: 'content_block_stop', index: 1 },
{
type: 'message_delta',
delta: { stop_reason: 'end_turn', stop_sequence: null },
usage: { output_tokens: 354 },
},
{ type: 'message_stop' },
];

const mockReadableStream = new ReadableStream({
start(controller) {
streams.forEach((chunk) => {
controller.enqueue(chunk);
});
controller.close();
},
});

const protocolStream = AnthropicStream(mockReadableStream);

const decoder = new TextDecoder();
const chunks = [];

// @ts-ignore
for await (const chunk of protocolStream) {
chunks.push(decoder.decode(chunk, { stream: true }));
}

expect(chunks).toEqual(
[
'id: msg_01MNsLe7n1uVLtu6W8rCFujD',
'event: data',
'data: {"id":"msg_01MNsLe7n1uVLtu6W8rCFujD","type":"message","role":"assistant","model":"claude-3-7-sonnet-20250219","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":46,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":11}}\n',
'id: msg_01MNsLe7n1uVLtu6W8rCFujD',
'event: reasoning',
'data: ""\n',
'id: msg_01MNsLe7n1uVLtu6W8rCFujD',
'event: reasoning',
'data: "我需要比较两个数字的"\n',
'id: msg_01MNsLe7n1uVLtu6W8rCFujD',
'event: reasoning',
'data: "大小:9.8和9"\n',
'id: msg_01MNsLe7n1uVLtu6W8rCFujD',
'event: reasoning',
'data: "11\\n\\n所以9.8比9.11大。"\n',
// Reasoning signature
'id: msg_01MNsLe7n1uVLtu6W8rCFujD',
'event: reasoning_signature',
`data: "EuYBCkQYAiJAHnHRJG4nPBrdTlo6CmXoyE8WYoQeoPiLnXaeuaM8ExdiIEkVvxK1DYXOz5sCubs2s/G1NsST8A003Zb8XmuhYBIMwDGMZSZ3+gxOEBpVGgzdpOlDNBTxke31SngiMKUk6WcSiA11OSVBuInNukoAhnRd5jPAEg7e5mIoz/qJwnQHV8I+heKUreP77eJdFipQaM3FHn+avEHuLa/Z/fu0O9BftDi+caB1UWDwJakNeWX1yYTvK+N1v4gRpKbj4AhctfYHMjq8qX9XTnXme5AGzCYC6HgYw2/RfalWzwNxI6k="\n`,
'id: msg_01MNsLe7n1uVLtu6W8rCFujD',
'event: data',
`data: {"type":"content_block_stop","index":0}\n`,
'id: msg_01MNsLe7n1uVLtu6W8rCFujD',
'event: data',
`data: ""\n`,
'id: msg_01MNsLe7n1uVLtu6W8rCFujD',
'event: text',
`data: "9.8比9.11大。"\n`,
'id: msg_01MNsLe7n1uVLtu6W8rCFujD',
'event: data',
`data: {"type":"content_block_stop","index":1}\n`,
'id: msg_01MNsLe7n1uVLtu6W8rCFujD',
'event: stop',
'data: "end_turn"\n',
'id: msg_01MNsLe7n1uVLtu6W8rCFujD',
'event: stop',
'data: "message_stop"\n',
].map((item) => `${item}\n`),
);
});
it('should handle ReadableStream input', async () => {
const mockReadableStream = new ReadableStream({
start(controller) {
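The expectations above spell out the wire format: each protocol chunk is an id / event / data triplet, thinking deltas arrive as reasoning events, and the signature arrives as a single reasoning_signature event. A rough sketch of a consumer that separates reasoning from the answer text, assuming well-formed triplets like the ones asserted above:

// Simplified: expects the flat list of 'id: ...', 'event: ...', 'data: ...' lines shown above.
const collect = (lines: string[]) => {
  let reasoning = '';
  let text = '';

  for (let i = 0; i < lines.length; i++) {
    if (!lines[i].startsWith('event: ')) continue;

    const event = lines[i].slice('event: '.length).trim();
    const data = lines[i + 1]?.slice('data: '.length) ?? '""';

    if (event === 'reasoning') reasoning += JSON.parse(data);
    if (event === 'text') text += JSON.parse(data);
  }

  return { reasoning, text };
};

// Fed the chunks from the test above, reasoning accumulates the three thinking
// deltas and text ends up as '9.8比9.11大。'.
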
62 changes: 46 additions & 16 deletions src/libs/agent-runtime/utils/streams/anthropic.ts
@@ -14,25 +14,25 @@ import {

export const transformAnthropicStream = (
chunk: Anthropic.MessageStreamEvent,
stack: StreamContext,
context: StreamContext,
): StreamProtocolChunk => {
// maybe need another structure to add support for multiple choices
switch (chunk.type) {
case 'message_start': {
stack.id = chunk.message.id;
context.id = chunk.message.id;
return { data: chunk.message, id: chunk.message.id, type: 'data' };
}
case 'content_block_start': {
if (chunk.content_block.type === 'tool_use') {
const toolChunk = chunk.content_block;

// if toolIndex is not defined, set it to 0
if (typeof stack.toolIndex === 'undefined') {
stack.toolIndex = 0;
if (typeof context.toolIndex === 'undefined') {
context.toolIndex = 0;
}
// if toolIndex is defined, increment it
else {
stack.toolIndex += 1;
context.toolIndex += 1;
}

const toolCall: StreamToolCallChunkData = {
@@ -41,57 +41,87 @@ export const transformAnthropicStream = (
name: toolChunk.name,
},
id: toolChunk.id,
index: stack.toolIndex,
index: context.toolIndex,
type: 'function',
};

stack.tool = { id: toolChunk.id, index: stack.toolIndex, name: toolChunk.name };
context.tool = { id: toolChunk.id, index: context.toolIndex, name: toolChunk.name };

return { data: [toolCall], id: stack.id, type: 'tool_calls' };
return { data: [toolCall], id: context.id, type: 'tool_calls' };
}

return { data: chunk.content_block.text, id: stack.id, type: 'data' };
if (chunk.content_block.type === 'thinking') {
const thinkingChunk = chunk.content_block;

return { data: thinkingChunk.thinking, id: context.id, type: 'reasoning' };
}

if (chunk.content_block.type === 'redacted_thinking') {
return {
data: chunk.content_block.data,
id: context.id,
type: 'reasoning',
};
}

return { data: chunk.content_block.text, id: context.id, type: 'data' };
}

case 'content_block_delta': {
switch (chunk.delta.type) {
case 'text_delta': {
return { data: chunk.delta.text, id: stack.id, type: 'text' };
return { data: chunk.delta.text, id: context.id, type: 'text' };
}

case 'input_json_delta': {
const delta = chunk.delta.partial_json;

const toolCall: StreamToolCallChunkData = {
function: { arguments: delta },
index: stack.toolIndex || 0,
index: context.toolIndex || 0,
type: 'function',
};

return {
data: [toolCall],
id: stack.id,
id: context.id,
type: 'tool_calls',
} as StreamProtocolToolCallChunk;
}

case 'signature_delta': {
return {
data: chunk.delta.signature,
id: context.id,
type: 'reasoning_signature' as any,
};
}

case 'thinking_delta': {
return {
data: chunk.delta.thinking,
id: context.id,
type: 'reasoning',
};
}

default: {
break;
}
}
return { data: chunk, id: stack.id, type: 'data' };
return { data: chunk, id: context.id, type: 'data' };
}

case 'message_delta': {
return { data: chunk.delta.stop_reason, id: stack.id, type: 'stop' };
return { data: chunk.delta.stop_reason, id: context.id, type: 'stop' };
}

case 'message_stop': {
return { data: 'message_stop', id: stack.id, type: 'stop' };
return { data: 'message_stop', id: context.id, type: 'stop' };
}

default: {
return { data: chunk, id: stack.id, type: 'data' };
return { data: chunk, id: context.id, type: 'data' };
}
}
};
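In short, the new cases route Anthropic's extended-thinking events into the protocol like this: a content_block_start of type thinking and every thinking_delta become reasoning chunks, redacted_thinking blocks pass their opaque data through as reasoning as well, and signature_delta becomes a reasoning_signature chunk. A quick sketch of the mapping (the event objects are trimmed to the fields the transform reads, hence the as any casts):

const context: StreamContext = { id: 'msg_123' };

transformAnthropicStream(
  { delta: { thinking: 'comparing 9.8 and 9.11…', type: 'thinking_delta' }, index: 0, type: 'content_block_delta' } as any,
  context,
);
// => { data: 'comparing 9.8 and 9.11…', id: 'msg_123', type: 'reasoning' }

transformAnthropicStream(
  { delta: { signature: '<signature>', type: 'signature_delta' }, index: 0, type: 'content_block_delta' } as any,
  context,
);
// => { data: '<signature>', id: 'msg_123', type: 'reasoning_signature' }
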
4 changes: 4 additions & 0 deletions src/libs/agent-runtime/utils/streams/protocol.ts
@@ -12,6 +12,10 @@ export interface StreamContext {
* this flag is used to check whether the pplx citation has already been returned, so it is not returned again
*/
returnedPplxCitation?: boolean;
thinking?: {
id: string;
name: string;
};
tool?: {
id: string;
index: number;
