diff --git a/package.json b/package.json index 2b6eb64..5ee70a8 100644 --- a/package.json +++ b/package.json @@ -102,7 +102,7 @@ "max_tokens": 4096 } ], - "description": "The template for AI completion responses. See:" + "description": "The template for AI completion responses." }, "MarkdownPaste.openaiCompletionTemplateFile": { "type": "string", @@ -337,7 +337,7 @@ "shelljs": "^0.8.5", "turndown": "^7.1.2", "xclip": "^1.0.5", - "openai": "^4.61.0" + "openai": "^4.61.0" }, "devDependencies": { "@types/glob": "^7.1.3", diff --git a/src/ToolsManager.ts b/src/ToolsManager.ts new file mode 100644 index 0000000..f1237c0 --- /dev/null +++ b/src/ToolsManager.ts @@ -0,0 +1,74 @@ +import { ChatCompletionTool } from "openai/resources/chat/completions"; +import Logger from "./Logger"; + +type ToolFunction = (...args: any[]) => any; + +interface ToolInfo { + func: ToolFunction; + description: string; + parameters: Record; +} + +export class ToolsManager { + private tools: Map; + + constructor() { + this.tools = new Map(); + } + + public registerDefaultTools() { + this.registerTool( + "get_current_weather", + ({ city }: { city: string }) => { + return JSON.stringify({ + city: city, + temperature: "25°C", + weather: "sunny", + }); + }, + "Get the current weather for a specified city", + { + type: "object", + properties: { + city: { type: "string", description: "The name of the city" }, + }, + required: ["city"], + } + ); + } + + public registerTool( + name: string, + func: ToolFunction, + description: string, + parameters: Record + ) { + this.tools.set(name, { func, description, parameters }); + } + + public executeTool(name: string, args: any): string | null { + const toolInfo = this.tools.get(name); + if (toolInfo) { + try { + return JSON.stringify(toolInfo.func(args)); + } catch (error) { + Logger.log(`Error executing tool ${name}:`, error); + return null; + } + } else { + Logger.log(`Tool ${name} not found`); + return null; + } + } + + public getToolsForOpenAI(): ChatCompletionTool[] { + return Array.from(this.tools.entries()).map(([toolName, toolInfo]) => ({ + type: "function", + function: { + name: toolName, + description: toolInfo.description, + parameters: toolInfo.parameters, + }, + })); + } +} diff --git a/src/ai_paster.ts b/src/ai_paster.ts index 83d231c..014018f 100644 --- a/src/ai_paster.ts +++ b/src/ai_paster.ts @@ -3,17 +3,19 @@ import OpenAI from "openai"; import { ChatCompletionMessageParam, ChatCompletionTool, - ChatCompletionToolMessageParam, } from "openai/resources/chat/completions"; import { Predefine } from "./predefine"; - import Logger from "./Logger"; +import { ToolsManager } from "./ToolsManager"; export class AIPaster { private client: OpenAI; + private toolsManager: ToolsManager; constructor() { this.client = new OpenAI(this.config.openaiConnectOption); + this.toolsManager = new ToolsManager(); + this.toolsManager.registerDefaultTools(); } public destructor() { @@ -42,27 +44,19 @@ export class AIPaster { const responseMessages = chatCompletion.choices[0].message; const toolCalls = chatCompletion.choices[0].message.tool_calls; if (toolCalls) { - const availableFunctions = { - get_current_weather: function ({ city }: { city: string }) { - return JSON.stringify({ - city: city, - temperature: "25°C", - weather: "sunny", - }); - }, - }; - // messages.push(responseMessages); for (const toolCall of toolCalls) { - const functionName: keyof typeof availableFunctions = toolCall - .function.name as keyof typeof availableFunctions; - const functionToCall = availableFunctions[functionName]; - const functionArgs = JSON.parse(toolCall.function.arguments); - const functionResponse = functionToCall(functionArgs); - completion.messages.push({ - tool_call_id: toolCall.id, - role: "tool", - content: functionResponse, - }); + const functionName = toolCall.function.name; + const functionResponse = this.toolsManager.executeTool( + functionName, + JSON.parse(toolCall.function.arguments) + ); + if (functionResponse !== null) { + completion.messages.push({ + tool_call_id: toolCall.id, + role: "tool", + content: functionResponse, + }); + } } completion.messages.forEach((message: ChatCompletionMessageParam) => { Logger.log( @@ -89,6 +83,15 @@ export class AIPaster { } } + private mergeToolsByFunctionName(existingTools, newTools) { + const toolMap = new Map(); + + existingTools.forEach((tool) => toolMap.set(tool.function.name, tool)); + newTools.forEach((tool) => toolMap.set(tool.function.name, tool)); + + return Array.from(toolMap.values()); + } + public async callAI(clipboardText: string): Promise { try { let openaiCompletionTemplate = this.config.openaiCompletionTemplate; @@ -124,6 +127,14 @@ export class AIPaster { ); } }); + if (completion.tools && Array.isArray(completion.tools)) { + completion.tools = this.mergeToolsByFunctionName( + completion.tools, + this.toolsManager.getToolsForOpenAI() + ); + } else { + completion.tools = this.toolsManager.getToolsForOpenAI(); + } let content = await this.runCompletion(completion); Logger.log("content:", content); result += content; diff --git a/test/suite/ToolsManager.test.ts b/test/suite/ToolsManager.test.ts new file mode 100644 index 0000000..4ac8792 --- /dev/null +++ b/test/suite/ToolsManager.test.ts @@ -0,0 +1,173 @@ +import * as assert from "assert"; +import { ToolsManager } from "../../src/ToolsManager"; +import { ChatCompletionTool } from "openai/resources/chat/completions"; + +// Defines a Mocha test suite to group tests of similar kind together +suite("ToolsManager Tests", () => { + let toolsManager: ToolsManager; + + setup(() => { + toolsManager = new ToolsManager(); + }); + + test("registerTool should add a new tool", () => { + const toolName = "test_tool"; + const toolFunc = () => ({ result: "success" }); + const toolDescription = "A test tool"; + const toolParameters = { type: "object", properties: {} }; + + toolsManager.registerTool( + toolName, + toolFunc, + toolDescription, + toolParameters + ); + + const tools = toolsManager.getToolsForOpenAI(); + assert.strictEqual(tools.length, 1); + assert.strictEqual(tools[0].function.name, toolName); + assert.strictEqual(tools[0].function.description, toolDescription); + assert.deepStrictEqual(tools[0].function.parameters, toolParameters); + }); + + test("executeTool should call the registered tool function", () => { + const toolName = "test_tool"; + const toolFunc = (args: any) => ({ result: args.input }); + const toolDescription = "A test tool"; + const toolParameters = { type: "object", properties: {} }; + + toolsManager.registerTool( + toolName, + toolFunc, + toolDescription, + toolParameters + ); + + const result = toolsManager.executeTool(toolName, { input: "test" }); + assert.strictEqual(result, JSON.stringify({ result: "test" })); + }); + + test("executeTool should return null for unregistered tool", () => { + const result = toolsManager.executeTool("nonexistent_tool", {}); + assert.strictEqual(result, null); + }); + + test("getToolsForOpenAI should return correct format", () => { + const toolName = "test_tool"; + const toolFunc = () => ({}); + const toolDescription = "A test tool"; + const toolParameters = { + type: "object", + properties: { arg: { type: "string" } }, + }; + + toolsManager.registerTool( + toolName, + toolFunc, + toolDescription, + toolParameters + ); + + const tools = toolsManager.getToolsForOpenAI(); + assert.strictEqual(tools.length, 1); + assert.deepStrictEqual(tools[0], { + type: "function", + function: { + name: toolName, + description: toolDescription, + parameters: toolParameters, + }, + }); + }); + + test("registerDefaultTools should register the weather tool", () => { + toolsManager.registerDefaultTools(); + const tools = toolsManager.getToolsForOpenAI(); + assert.strictEqual(tools.length, 1); + assert.strictEqual(tools[0].function.name, "get_current_weather"); + }); + + test("registerTool should overwrite existing tool with same name", () => { + const toolName = "test_tool"; + const toolFunc1 = () => ({ result: "original" }); + const toolFunc2 = () => ({ result: "overwritten" }); + const toolDescription = "A test tool"; + const toolParameters = { type: "object", properties: {} }; + + toolsManager.registerTool( + toolName, + toolFunc1, + toolDescription, + toolParameters + ); + toolsManager.registerTool( + toolName, + toolFunc2, + toolDescription, + toolParameters + ); + + const result = toolsManager.executeTool(toolName, {}); + assert.strictEqual(result, JSON.stringify({ result: "overwritten" })); + }); + + test("executeTool should handle errors in tool function", () => { + const toolName = "error_tool"; + const toolFunc = () => { + throw new Error("Test error"); + }; + const toolDescription = "A tool that throws an error"; + const toolParameters = { type: "object", properties: {} }; + + toolsManager.registerTool( + toolName, + toolFunc, + toolDescription, + toolParameters + ); + + const result = toolsManager.executeTool(toolName, {}); + assert.strictEqual(result, null); + }); + + test("getToolsForOpenAI should return empty array when no tools are registered", () => { + const tools = toolsManager.getToolsForOpenAI(); + assert.strictEqual(tools.length, 0); + }); + + test("registerTool should handle complex parameter structures", () => { + const toolName = "complex_tool"; + const toolFunc = () => ({}); + const toolDescription = "A tool with complex parameters"; + const toolParameters = { + type: "object", + properties: { + stringArg: { type: "string" }, + numberArg: { type: "number" }, + booleanArg: { type: "boolean" }, + arrayArg: { + type: "array", + items: { type: "string" }, + }, + objectArg: { + type: "object", + properties: { + nestedProp: { type: "string" }, + }, + }, + }, + required: ["stringArg", "numberArg"], + }; + + toolsManager.registerTool( + toolName, + toolFunc, + toolDescription, + toolParameters + ); + + const tools = toolsManager.getToolsForOpenAI(); + assert.strictEqual(tools.length, 1); + assert.deepStrictEqual(tools[0].function.parameters, toolParameters); + }); +});