
Commit 02c65f9

dhuebnersdirix authored and committed
Ollama LLM provider tools support #14610
1 parent d493ea9 commit 02c65f9

File tree

1 file changed: +134 -36 lines changed


packages/ai-ollama/src/node/ollama-language-model.ts

+134 -36
@@ -20,7 +20,9 @@ import {
     LanguageModelRequest,
     LanguageModelRequestMessage,
     LanguageModelResponse,
+    LanguageModelStreamResponse,
     LanguageModelStreamResponsePart,
+    ToolCall,
     ToolRequest
 } from '@theia/ai-core';
 import { CancellationToken } from '@theia/core';
@@ -31,7 +33,9 @@ export const OllamaModelIdentifier = Symbol('OllamaModelIdentifier');
 export class OllamaModel implements LanguageModel {
 
     protected readonly DEFAULT_REQUEST_SETTINGS: Partial<Omit<ChatRequest, 'stream' | 'model'>> = {
-        keep_alive: '15m'
+        keep_alive: '15m',
+        // options see: https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
+        options: {}
     };
 
     readonly providerId = 'ollama';
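
Note on the new options field: whatever getSettings merges from the request (falling back to the provider defaults) is forwarded to Ollama as model options. The following standalone sketch shows the request shape this produces, assuming a local Ollama at the default port; the model name and option values are illustrative, not part of the commit:

import { Ollama } from 'ollama';

async function sketchOptionsRequest(): Promise<void> {
    // Assumed local endpoint; OllamaModel resolves the host via its `host` callback.
    const ollama = new Ollama({ host: 'http://localhost:11434' });
    const response = await ollama.chat({
        model: 'llama3.1', // hypothetical model name
        keep_alive: '15m',
        // Valid keys are the modelfile parameters linked in the comment above.
        options: { temperature: 0.2, num_ctx: 4096 },
        messages: [{ role: 'user', content: 'Hello!' }],
        stream: false
    });
    console.log(response.message.content);
}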
@@ -50,62 +54,125 @@ export class OllamaModel implements LanguageModel {
         public defaultRequestSettings?: { [key: string]: unknown }
     ) { }
 
+    async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
+        const settings = this.getSettings(request);
+        const ollama = this.initializeOllama();
+
+        const ollamaRequest: ExtendedChatRequest = {
+            model: this.model,
+            ...this.DEFAULT_REQUEST_SETTINGS,
+            ...settings,
+            messages: request.messages.map(this.toOllamaMessage),
+            tools: request.tools?.map(this.toOllamaTool)
+        };
+        const structured = request.response_format?.type === 'json_schema';
+        return this.dispatchRequest(ollama, ollamaRequest, structured, cancellationToken);
+    }
+
+    /**
+     * Retrieves the settings for the chat request, merging the request-specific settings with the default settings.
+     * @param request The language model request containing specific settings.
+     * @returns A partial ChatRequest object containing the merged settings.
+     */
     protected getSettings(request: LanguageModelRequest): Partial<ChatRequest> {
         const settings = request.settings ?? this.defaultRequestSettings ?? {};
         return {
             options: settings as Partial<Options>
         };
     }
 
-    async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
-        const settings = this.getSettings(request);
-        const ollama = this.initializeOllama();
+    protected async dispatchRequest(ollama: Ollama, ollamaRequest: ExtendedChatRequest, structured: boolean, cancellation?: CancellationToken): Promise<LanguageModelResponse> {
+
+        // Handle structured output request
+        if (structured) {
+            return this.handleStructuredOutputRequest(ollama, ollamaRequest);
+        }
 
-        if (request.response_format?.type === 'json_schema') {
-            return this.handleStructuredOutputRequest(ollama, request);
+        // Handle tool request - response may call tools
+        if (ollamaRequest.tools && ollamaRequest.tools?.length > 0) {
+            return this.handleToolsRequest(ollama, ollamaRequest);
         }
+
+        // Handle standard chat request
         const response = await ollama.chat({
-            model: this.model,
-            ...this.DEFAULT_REQUEST_SETTINGS,
-            ...settings,
-            messages: request.messages.map(this.toOllamaMessage),
-            stream: true,
-            tools: request.tools?.map(this.toOllamaTool),
+            ...ollamaRequest,
+            stream: true
         });
+        return this.handleCancellationAndWrapIterator(response, cancellation);
+    }
 
-        cancellationToken?.onCancellationRequested(() => {
-            response.abort();
+    protected async handleToolsRequest(ollama: Ollama, chatRequest: ExtendedChatRequest, prevResponse?: ChatResponse): Promise<LanguageModelResponse> {
+        const response = prevResponse || await ollama.chat({
+            ...chatRequest,
+            stream: false
         });
-
-        async function* wrapAsyncIterator<T>(inputIterable: AsyncIterable<ChatResponse>): AsyncIterable<LanguageModelStreamResponsePart> {
-            for await (const item of inputIterable) {
-                // TODO handle tool calls
-                yield { content: item.message.content };
+        if (response.message.tool_calls) {
+            const tools: ToolWithHandler[] = chatRequest.tools ?? [];
+            // Add response message to chat history
+            chatRequest.messages.push(response.message);
+            const tool_calls: ToolCall[] = [];
+            for (const [idx, toolCall] of response.message.tool_calls.entries()) {
+                const functionToCall = tools.find(tool => tool.function.name === toolCall.function.name);
+                if (functionToCall) {
+                    const args = JSON.stringify(toolCall.function?.arguments);
+                    const funcResult = await functionToCall.handler(args);
+                    chatRequest.messages.push({
+                        role: 'tool',
+                        content: `Tool call ${functionToCall.function.name} returned: ${String(funcResult)}`,
+                    });
+                    let resultString = String(funcResult);
+                    if (resultString.length > 1000) {
+                        // truncate result string if it is too long
+                        resultString = resultString.substring(0, 1000) + '...';
+                    }
+                    tool_calls.push({
+                        id: `ollama_${response.created_at}_${idx}`,
+                        function: {
+                            name: functionToCall.function.name,
+                            arguments: Object.values(toolCall.function?.arguments ?? {}).join(', ')
+                        },
+                        result: resultString,
+                        finished: true
+                    });
+                }
+            }
+            // Get final response from model with function outputs
+            const finalResponse = await ollama.chat({ ...chatRequest, stream: false });
+            if (finalResponse.message.tool_calls) {
+                // If the final response also calls tools, recursively handle them
+                return this.handleToolsRequest(ollama, chatRequest, finalResponse);
             }
+            return { stream: this.createAsyncIterable([{ tool_calls }, { content: finalResponse.message.content }]) };
         }
-        return { stream: wrapAsyncIterator(response) };
+        return { text: response.message.content };
     }
 
-    protected async handleStructuredOutputRequest(ollama: Ollama, request: LanguageModelRequest): Promise<LanguageModelParsedResponse> {
-        const settings = this.getSettings(request);
-        const result = await ollama.chat({
-            ...settings,
-            ...this.DEFAULT_REQUEST_SETTINGS,
-            model: this.model,
-            messages: request.messages.map(this.toOllamaMessage),
+    protected createAsyncIterable<T>(items: T[]): AsyncIterable<T> {
+        return {
+            [Symbol.asyncIterator]: async function* (): AsyncIterableIterator<T> {
+                for (const item of items) {
+                    yield item;
+                }
+            }
+        };
+    }
+
+    protected async handleStructuredOutputRequest(ollama: Ollama, chatRequest: ChatRequest): Promise<LanguageModelParsedResponse> {
+        const response = await ollama.chat({
+            ...chatRequest,
             format: 'json',
             stream: false,
         });
         try {
             return {
-                content: result.message.content,
-                parsed: JSON.parse(result.message.content)
+                content: response.message.content,
+                parsed: JSON.parse(response.message.content)
             };
         } catch (error) {
             // TODO use ILogger
             console.log('Failed to parse structured response from the language model.', error);
             return {
-                content: result.message.content,
+                content: response.message.content,
                 parsed: {}
             };
         }
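
To see the round trip handleToolsRequest implements: the first non-streamed response may carry tool_calls; each matching handler runs, its result is appended to the chat history as a 'tool' message, and the model is queried again until it stops requesting tools. Below is a self-contained sketch of that loop against the ollama client. The get_weather tool, the model name, and the iterative loop (in place of the commit's recursion) are assumptions for illustration:

import { Ollama, Message, Tool } from 'ollama';

type ToolWithHandler = Tool & { handler: (argString: string) => Promise<unknown> };

// Hypothetical tool: answers weather questions with a stubbed result.
const weatherTool: ToolWithHandler = {
    type: 'function',
    function: {
        name: 'get_weather',
        description: 'Returns the current weather for a city',
        parameters: {
            type: 'object',
            required: ['city'],
            properties: { city: { type: 'string', description: 'City name' } }
        }
    },
    handler: async argString => `Sunny in ${JSON.parse(argString).city}`
};

async function chatWithTools(ollama: Ollama, messages: Message[]): Promise<string> {
    const tools = [weatherTool];
    let response = await ollama.chat({ model: 'llama3.1', messages, tools, stream: false });
    while (response.message.tool_calls) {
        // Keep the assistant turn that requested the tools in the history.
        messages.push(response.message);
        for (const call of response.message.tool_calls) {
            const tool = tools.find(t => t.function.name === call.function.name);
            if (tool) {
                const result = await tool.handler(JSON.stringify(call.function.arguments));
                messages.push({ role: 'tool', content: String(result) });
            }
        }
        // Ask the model again, now with the tool outputs available.
        response = await ollama.chat({ model: 'llama3.1', messages, tools, stream: false });
    }
    return response.message.content;
}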
@@ -119,11 +186,21 @@ export class OllamaModel implements LanguageModel {
         return new Ollama({ host: host });
     }
 
-    protected toOllamaTool(tool: ToolRequest): Tool {
-        const transform = (props: Record<string, {
-            [key: string]: unknown;
-            type: string;
-        }> | undefined) => {
+    protected handleCancellationAndWrapIterator(response: AbortableAsyncIterable<ChatResponse>, token?: CancellationToken): LanguageModelStreamResponse {
+        token?.onCancellationRequested(() => {
+            // maybe it is better to use ollama.abort() as we are using one client per request
+            response.abort();
+        });
+        async function* wrapAsyncIterator<T>(inputIterable: AsyncIterable<ChatResponse>): AsyncIterable<LanguageModelStreamResponsePart> {
+            for await (const item of inputIterable) {
+                yield { content: item.message.content };
+            }
+        }
+        return { stream: wrapAsyncIterator(response) };
+    }
+
+    protected toOllamaTool(tool: ToolRequest): ToolWithHandler {
+        const transform = (props: Record<string, { [key: string]: unknown; type: string; }> | undefined) => {
             if (!props) {
                 return undefined;
             }
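
On the consumer side, cancellation flows through the standard Theia CancellationToken: cancelling it triggers the onCancellationRequested handler above, which aborts the underlying Ollama stream. A hypothetical usage sketch; the relative import path and the 30-second timeout are assumptions:

import { CancellationTokenSource } from '@theia/core';
import { LanguageModelRequest } from '@theia/ai-core';
import { OllamaModel } from './ollama-language-model'; // assumed import path

async function streamWithTimeout(model: OllamaModel, request: LanguageModelRequest): Promise<void> {
    const source = new CancellationTokenSource();
    const timer = setTimeout(() => source.cancel(), 30_000); // give up after 30s
    try {
        const response = await model.request(request, source.token);
        if ('stream' in response) {
            for await (const part of response.stream) {
                if ('content' in part && part.content) {
                    process.stdout.write(part.content);
                }
            }
        }
    } finally {
        clearTimeout(timer);
    }
}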
@@ -148,7 +225,8 @@ export class OllamaModel implements LanguageModel {
                     required: Object.keys(tool.parameters?.properties ?? {}),
                     properties: transform(tool.parameters?.properties) ?? {}
                 },
-            }
+            },
+            handler: tool.handler
         };
     }
 
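For context on what flows into toOllamaTool: a Theia ToolRequest carries a JSON-schema-like parameters block plus a handler, and the conversion above reshapes it into Ollama's Tool format while passing the handler through. A hypothetical ToolRequest as a caller might define it; the exact field set is inferred from how toOllamaTool reads the request, and all values are illustrative:

import { ToolRequest } from '@theia/ai-core';

const getWeatherRequest: ToolRequest = {
    id: 'get_weather',
    name: 'get_weather',
    description: 'Returns the current weather for a city',
    parameters: {
        type: 'object',
        properties: {
            city: { type: 'string' }
        }
    },
    handler: async argString => {
        // The handler receives the stringified arguments, as in handleToolsRequest above.
        const { city } = JSON.parse(argString);
        return `Sunny in ${city}`; // stubbed result
    }
};
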
@@ -165,3 +243,23 @@ export class OllamaModel implements LanguageModel {
         return { role: 'system', content: '' };
     }
 }
+
+/**
+ * Extended Tool containing a handler
+ * @see Tool
+ */
+type ToolWithHandler = Tool & { handler: (arg_string: string) => Promise<unknown> };
+
+/**
+ * Extended chat request with mandatory messages and ToolWithHandler tools
+ *
+ * @see ChatRequest
+ * @see ToolWithHandler
+ */
+type ExtendedChatRequest = ChatRequest & {
+    messages: Message[]
+    tools?: ToolWithHandler[]
+};
+
+// Ollama doesn't export this type, so we have to define it here
+type AbortableAsyncIterable<T> = AsyncIterable<T> & { abort: () => void };
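
Regarding AbortableAsyncIterable: when called with stream: true, the ollama client returns an async iterable that also exposes abort(), but the library does not export a named type for it, hence the local definition. A small sketch of the behavior the type captures, again assuming a local endpoint and an illustrative model name; aborting typically surfaces as an error thrown out of the pending iteration:

import { Ollama } from 'ollama';

async function sketchAbort(): Promise<void> {
    const ollama = new Ollama({ host: 'http://localhost:11434' }); // assumed local endpoint
    const stream = await ollama.chat({
        model: 'llama3.1', // hypothetical model name
        messages: [{ role: 'user', content: 'Tell me a very long story.' }],
        stream: true
    });
    // `stream` matches AbortableAsyncIterable<ChatResponse> as defined above.
    setTimeout(() => stream.abort(), 5_000); // stop generation after five seconds
    try {
        for await (const part of stream) {
            process.stdout.write(part.message.content);
        }
    } catch (error) {
        // abort() rejects the in-flight iteration; expected when we cancel
    }
}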
