Skip to content

Commit

Permalink
Use weather tool helpers in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
atdrendel committed Feb 26, 2025
1 parent f91ea6f commit 39375be
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 110 deletions.
1 change: 1 addition & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ let package = Package(
// .copy("Resources/gemma-2-2b-it-4bit"),
// .copy("Resources/Phi-3.5-mini-instruct-4bit"),
// .copy("Resources/Qwen2.5-1.5B-Instruct-4bit"),
// .copy("Resources/Qwen2.5-7B-Instruct-4bit"),
// ],
linkerSettings: [
.linkedFramework("CoreGraphics", .when(platforms: [.macOS])),
Expand Down
22 changes: 13 additions & 9 deletions Sources/SHLLM/Tools.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ public struct ToolFunction {
public let description: String?
public let parameters: [ToolFunctionParameter]

init(name: String, description: String? = nil, parameters: [ToolFunctionParameter]) {
public init(
name: String,
description: String? = nil,
parameters: [ToolFunctionParameter]
) {
self.name = name
self.description = description
self.parameters = parameters
Expand Down Expand Up @@ -54,7 +58,7 @@ public struct ToolFunctionParameter {
public let description: String?
public let required: Bool

static func string(
public static func string(
name: String,
description: String? = nil,
required: Bool = false,
Expand All @@ -78,7 +82,7 @@ public struct ToolFunctionParameter {
)
}

static func number(
public static func number(
name: String,
description: String? = nil,
required: Bool = false,
Expand All @@ -100,7 +104,7 @@ public struct ToolFunctionParameter {
)
}

static func integer(
public static func integer(
name: String,
description: String? = nil,
required: Bool = false,
Expand All @@ -122,7 +126,7 @@ public struct ToolFunctionParameter {
)
}

static func array(
public static func array(
name: String,
description: String? = nil,
required: Bool = false,
Expand All @@ -136,7 +140,7 @@ public struct ToolFunctionParameter {
)
}

static func object(
public static func object(
name: String,
description: String? = nil,
required: Bool = false,
Expand All @@ -150,23 +154,23 @@ public struct ToolFunctionParameter {
)
}

static func boolean(
public static func boolean(
name: String,
description: String? = nil,
required: Bool = false
) -> ToolFunctionParameter {
.init(name: name, type: .boolean, description: description, required: required)
}

static func null(
public static func null(
name: String,
description: String? = nil,
required: Bool = false
) -> ToolFunctionParameter {
.init(name: name, type: .null, description: description, required: required)
}

init(
public init(
name: String,
type: ToolFunctionParameterType,
description: String? = nil,
Expand Down
13 changes: 13 additions & 0 deletions Tests/SHLLMTests/Helpers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,19 @@ protocol InitializableWithDirectory {
init(directory: URL) async throws
}

let weatherToolFunction = ToolFunction(
name: "get_current_weather",
description: "Get the current weather in a given location",
parameters: [
.string(
name: "location",
description: "The city and state, e.g. San Francisco, CA",
required: true
),
.string(name: "unit", restrictTo: ["celsius", "fahrenheit"]),
]
)

enum WeatherTool: Codable, Hashable {
case getCurrentWeather(location: String, unit: WeatherUnit)

Expand Down
19 changes: 2 additions & 17 deletions Tests/SHLLMTests/Models/DeepSeekR1Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,8 @@ func canLoadAndQueryDeepSeekR1() async throws {
func canHelpMeFetchTheWeatherWithR1() async throws {
guard let llm = try await DeepSeekR1.tests else { return }

let tools = Tools([
.init(
name: "get_current_weather",
description: "Get the current weather in a given location",
parameters: [
.string(
name: "location",
description: "The city and state, e.g. San Francisco, CA",
required: true
),
.string(name: "unit", restrictTo: ["celsius", "fahrenheit"]),
]
),
])

let tool1: WeatherTool = try await llm.request(
tools: tools,
tools: Tools([weatherToolFunction]),
messages: [
[
"role": "system",
Expand All @@ -59,7 +44,7 @@ func canHelpMeFetchTheWeatherWithR1() async throws {
#expect(tool1 == expectedTool1)

let tool2: WeatherTool = try await llm.request(
tools: tools,
tools: Tools([weatherToolFunction]),
messages: [
[
"role": "system",
Expand Down
96 changes: 12 additions & 84 deletions Tests/SHLLMTests/Models/Qwen2_5-1_5BTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,86 +25,8 @@ func canLoadAndQueryQwen2_5__1_5B() async throws {
func canHelpMeFetchTheWeatherWithQwen2_5__1_5B() async throws {
guard let llm = try await Qwen2_5__1_5B.tests else { return }

enum Tool: Codable, Hashable {
case getCurrentWeather(location: String, unit: WeatherUnit)

private enum CodingKeys: String, CodingKey {
case name
case arguments
}

init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let toolName = try container.decode(String.self, forKey: .name)
let arguments = try container.decode([String: String].self, forKey: .arguments)

switch toolName {
case "get_current_weather":
guard let location = arguments["location"] else {
throw DecodingError.keyNotFound(
CodingKeys.arguments,
DecodingError.Context(
codingPath: [CodingKeys.arguments],
debugDescription: "Missing 'location' key"
)
)
}
guard let unitString = arguments["unit"],
let unit = WeatherUnit(rawValue: unitString)
else {
throw DecodingError.dataCorruptedError(
forKey: CodingKeys.arguments,
in: container,
debugDescription: "Missing or invalid 'unit' key"
)
}
self = .getCurrentWeather(location: location, unit: unit)

default:
throw DecodingError.dataCorruptedError(
forKey: CodingKeys.name,
in: container,
debugDescription: "Unrecognized tool name: \(toolName)"
)
}
}

func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
switch self {
case let .getCurrentWeather(location, unit):
try container.encode("getCurrentWeather", forKey: .name)
let arguments: [String: String] = [
"location": location,
"unit": unit.rawValue,
]
try container.encode(arguments, forKey: .arguments)
}
}
}

enum WeatherUnit: String, Codable {
case celsius
case fahrenheit
}

let tools = Tools([
.init(
name: "get_current_weather",
description: "Get the current weather in a given location",
parameters: [
.string(
name: "location",
description: "The city and state, e.g. San Francisco, CA",
required: true
),
.string(name: "unit", restrictTo: ["celsius", "fahrenheit"]),
]
),
])

let tool1: Tool = try await llm.request(
tools: tools,
let tool1: WeatherTool = try await llm.request(
tools: Tools([weatherToolFunction]),
messages: [
[
"role": "system",
Expand All @@ -114,12 +36,15 @@ func canHelpMeFetchTheWeatherWithQwen2_5__1_5B() async throws {
]
)

let expectedTool1 = Tool.getCurrentWeather(location: "Paris, France", unit: .fahrenheit)
let expectedTool1 = WeatherTool.getCurrentWeather(
location: "Paris, France",
unit: .fahrenheit
)

#expect(tool1 == expectedTool1)

let tool2: Tool = try await llm.request(
tools: tools,
let tool2: WeatherTool = try await llm.request(
tools: Tools([weatherToolFunction]),
messages: [
[
"role": "system",
Expand All @@ -130,7 +55,10 @@ func canHelpMeFetchTheWeatherWithQwen2_5__1_5B() async throws {
]
)

let expectedTool2 = Tool.getCurrentWeather(location: "Paris, France", unit: .celsius)
let expectedTool2 = WeatherTool.getCurrentWeather(
location: "Paris, France",
unit: .celsius
)

#expect(tool2 == expectedTool2)
}

0 comments on commit 39375be

Please sign in to comment.