From 39375bea5ff2df2e41d0a6d7925e861b965173fb Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Wed, 26 Feb 2025 21:28:34 +0100 Subject: [PATCH] Use weather tool helpers in tests --- Package.swift | 1 + Sources/SHLLM/Tools.swift | 22 +++-- Tests/SHLLMTests/Helpers.swift | 13 +++ Tests/SHLLMTests/Models/DeepSeekR1Tests.swift | 19 +--- .../SHLLMTests/Models/Qwen2_5-1_5BTests.swift | 96 +++---------------- 5 files changed, 41 insertions(+), 110 deletions(-) diff --git a/Package.swift b/Package.swift index 147780b..0ef67ca 100644 --- a/Package.swift +++ b/Package.swift @@ -41,6 +41,7 @@ let package = Package( // .copy("Resources/gemma-2-2b-it-4bit"), // .copy("Resources/Phi-3.5-mini-instruct-4bit"), // .copy("Resources/Qwen2.5-1.5B-Instruct-4bit"), +// .copy("Resources/Qwen2.5-7B-Instruct-4bit"), // ], linkerSettings: [ .linkedFramework("CoreGraphics", .when(platforms: [.macOS])), diff --git a/Sources/SHLLM/Tools.swift b/Sources/SHLLM/Tools.swift index 7607e61..7f77106 100644 --- a/Sources/SHLLM/Tools.swift +++ b/Sources/SHLLM/Tools.swift @@ -17,7 +17,11 @@ public struct ToolFunction { public let description: String? public let parameters: [ToolFunctionParameter] - init(name: String, description: String? = nil, parameters: [ToolFunctionParameter]) { + public init( + name: String, + description: String? = nil, + parameters: [ToolFunctionParameter] + ) { self.name = name self.description = description self.parameters = parameters @@ -54,7 +58,7 @@ public struct ToolFunctionParameter { public let description: String? public let required: Bool - static func string( + public static func string( name: String, description: String? = nil, required: Bool = false, @@ -78,7 +82,7 @@ public struct ToolFunctionParameter { ) } - static func number( + public static func number( name: String, description: String? = nil, required: Bool = false, @@ -100,7 +104,7 @@ public struct ToolFunctionParameter { ) } - static func integer( + public static func integer( name: String, description: String? = nil, required: Bool = false, @@ -122,7 +126,7 @@ public struct ToolFunctionParameter { ) } - static func array( + public static func array( name: String, description: String? = nil, required: Bool = false, @@ -136,7 +140,7 @@ public struct ToolFunctionParameter { ) } - static func object( + public static func object( name: String, description: String? = nil, required: Bool = false, @@ -150,7 +154,7 @@ public struct ToolFunctionParameter { ) } - static func boolean( + public static func boolean( name: String, description: String? = nil, required: Bool = false @@ -158,7 +162,7 @@ public struct ToolFunctionParameter { .init(name: name, type: .boolean, description: description, required: required) } - static func null( + public static func null( name: String, description: String? = nil, required: Bool = false @@ -166,7 +170,7 @@ public struct ToolFunctionParameter { .init(name: name, type: .null, description: description, required: required) } - init( + public init( name: String, type: ToolFunctionParameterType, description: String? = nil, diff --git a/Tests/SHLLMTests/Helpers.swift b/Tests/SHLLMTests/Helpers.swift index 0616117..c70239c 100644 --- a/Tests/SHLLMTests/Helpers.swift +++ b/Tests/SHLLMTests/Helpers.swift @@ -26,6 +26,19 @@ protocol InitializableWithDirectory { init(directory: URL) async throws } +let weatherToolFunction = ToolFunction( + name: "get_current_weather", + description: "Get the current weather in a given location", + parameters: [ + .string( + name: "location", + description: "The city and state, e.g. San Francisco, CA", + required: true + ), + .string(name: "unit", restrictTo: ["celsius", "fahrenheit"]), + ] +) + enum WeatherTool: Codable, Hashable { case getCurrentWeather(location: String, unit: WeatherUnit) diff --git a/Tests/SHLLMTests/Models/DeepSeekR1Tests.swift b/Tests/SHLLMTests/Models/DeepSeekR1Tests.swift index 6511fa1..c8df862 100644 --- a/Tests/SHLLMTests/Models/DeepSeekR1Tests.swift +++ b/Tests/SHLLMTests/Models/DeepSeekR1Tests.swift @@ -25,23 +25,8 @@ func canLoadAndQueryDeepSeekR1() async throws { func canHelpMeFetchTheWeatherWithR1() async throws { guard let llm = try await DeepSeekR1.tests else { return } - let tools = Tools([ - .init( - name: "get_current_weather", - description: "Get the current weather in a given location", - parameters: [ - .string( - name: "location", - description: "The city and state, e.g. San Francisco, CA", - required: true - ), - .string(name: "unit", restrictTo: ["celsius", "fahrenheit"]), - ] - ), - ]) - let tool1: WeatherTool = try await llm.request( - tools: tools, + tools: Tools([weatherToolFunction]), messages: [ [ "role": "system", @@ -59,7 +44,7 @@ func canHelpMeFetchTheWeatherWithR1() async throws { #expect(tool1 == expectedTool1) let tool2: WeatherTool = try await llm.request( - tools: tools, + tools: Tools([weatherToolFunction]), messages: [ [ "role": "system", diff --git a/Tests/SHLLMTests/Models/Qwen2_5-1_5BTests.swift b/Tests/SHLLMTests/Models/Qwen2_5-1_5BTests.swift index 3a21388..9bc85d6 100644 --- a/Tests/SHLLMTests/Models/Qwen2_5-1_5BTests.swift +++ b/Tests/SHLLMTests/Models/Qwen2_5-1_5BTests.swift @@ -25,86 +25,8 @@ func canLoadAndQueryQwen2_5__1_5B() async throws { func canHelpMeFetchTheWeatherWithQwen2_5__1_5B() async throws { guard let llm = try await Qwen2_5__1_5B.tests else { return } - enum Tool: Codable, Hashable { - case getCurrentWeather(location: String, unit: WeatherUnit) - - private enum CodingKeys: String, CodingKey { - case name - case arguments - } - - init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - let toolName = try container.decode(String.self, forKey: .name) - let arguments = try container.decode([String: String].self, forKey: .arguments) - - switch toolName { - case "get_current_weather": - guard let location = arguments["location"] else { - throw DecodingError.keyNotFound( - CodingKeys.arguments, - DecodingError.Context( - codingPath: [CodingKeys.arguments], - debugDescription: "Missing 'location' key" - ) - ) - } - guard let unitString = arguments["unit"], - let unit = WeatherUnit(rawValue: unitString) - else { - throw DecodingError.dataCorruptedError( - forKey: CodingKeys.arguments, - in: container, - debugDescription: "Missing or invalid 'unit' key" - ) - } - self = .getCurrentWeather(location: location, unit: unit) - - default: - throw DecodingError.dataCorruptedError( - forKey: CodingKeys.name, - in: container, - debugDescription: "Unrecognized tool name: \(toolName)" - ) - } - } - - func encode(to encoder: Encoder) throws { - var container = encoder.container(keyedBy: CodingKeys.self) - switch self { - case let .getCurrentWeather(location, unit): - try container.encode("getCurrentWeather", forKey: .name) - let arguments: [String: String] = [ - "location": location, - "unit": unit.rawValue, - ] - try container.encode(arguments, forKey: .arguments) - } - } - } - - enum WeatherUnit: String, Codable { - case celsius - case fahrenheit - } - - let tools = Tools([ - .init( - name: "get_current_weather", - description: "Get the current weather in a given location", - parameters: [ - .string( - name: "location", - description: "The city and state, e.g. San Francisco, CA", - required: true - ), - .string(name: "unit", restrictTo: ["celsius", "fahrenheit"]), - ] - ), - ]) - - let tool1: Tool = try await llm.request( - tools: tools, + let tool1: WeatherTool = try await llm.request( + tools: Tools([weatherToolFunction]), messages: [ [ "role": "system", @@ -114,12 +36,15 @@ func canHelpMeFetchTheWeatherWithQwen2_5__1_5B() async throws { ] ) - let expectedTool1 = Tool.getCurrentWeather(location: "Paris, France", unit: .fahrenheit) + let expectedTool1 = WeatherTool.getCurrentWeather( + location: "Paris, France", + unit: .fahrenheit + ) #expect(tool1 == expectedTool1) - let tool2: Tool = try await llm.request( - tools: tools, + let tool2: WeatherTool = try await llm.request( + tools: Tools([weatherToolFunction]), messages: [ [ "role": "system", @@ -130,7 +55,10 @@ func canHelpMeFetchTheWeatherWithQwen2_5__1_5B() async throws { ] ) - let expectedTool2 = Tool.getCurrentWeather(location: "Paris, France", unit: .celsius) + let expectedTool2 = WeatherTool.getCurrentWeather( + location: "Paris, France", + unit: .celsius + ) #expect(tool2 == expectedTool2) }