Skip to content

Commit

Permalink
Convert get current weather arguments into a struct
Browse files Browse the repository at this point in the history
  • Loading branch information
atdrendel committed Feb 28, 2025
1 parent e42356d commit 2752fec
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 94 deletions.
4 changes: 2 additions & 2 deletions Package.resolved
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/shareup/mlx-swift-examples",
"state" : {
"revision" : "20701c0eeedd339ede4dd3b964152d814a3e9716",
"version" : "0.0.1"
"revision" : "0c50f71e1a3fcaaf09d1f93721a55e4f291bc442",
"version" : "0.0.2"
}
},
{
Expand Down
8 changes: 4 additions & 4 deletions Package.swift
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
// swift-tools-version: 5.9
// The swift-tools-version declares the minimum version of Swift required to build this package.

import PackageDescription

let package = Package(
name: "SHLLM",
platforms: [.iOS(.v16), .macOS(.v14)],
products: [
// Products define the executables and libraries a package produces, making them visible
// to other packages.
.library(
name: "SHLLM",
targets: ["SHLLM"]
Expand All @@ -17,7 +14,7 @@ let package = Package(
dependencies: [
.package(
url: "https://github.com/shareup/mlx-swift-examples",
from: "0.0.1"
from: "0.0.2"
),
.package(
url: "https://github.com/huggingface/swift-transformers",
Expand All @@ -39,7 +36,10 @@ let package = Package(
// resources: [
// .copy("Resources/DeepSeek-R1-Distill-Qwen-7B-4bit"),
// .copy("Resources/gemma-2-2b-it-4bit"),
// .copy("Resources/OpenELM-270M-Instruct"),
// .copy("Resources/Phi-3.5-mini-instruct-4bit"),
// .copy("Resources/Phi-3.5-MoE-instruct-4bit"),
// .copy("Resources/Qwen1.5-0.5B-Chat-4bit"),
// .copy("Resources/Qwen2.5-1.5B-Instruct-4bit"),
// .copy("Resources/Qwen2.5-7B-Instruct-4bit"),
// ],
Expand Down
38 changes: 0 additions & 38 deletions Sources/SHLLM/LLM.swift
Original file line number Diff line number Diff line change
Expand Up @@ -240,44 +240,6 @@ extension LLM {
}
}

// FROM: https://github.com/ml-explore/mlx-swift-examples/blob/20701c0eeedd339ede4dd3b964152d814a3e9716/Libraries/MLXLMCommon/Load.swift#L61
private func loadWeights(
modelDirectory: URL, model: LanguageModel,
quantization: BaseConfiguration.Quantization? = nil
) throws {
// load the weights
var weights = [String: MLXArray]()
let enumerator = FileManager.default.enumerator(
at: modelDirectory, includingPropertiesForKeys: nil
)!
for case let url as URL in enumerator {
if url.pathExtension == "safetensors" {
let w = try loadArrays(url: url)
for (key, value) in w {
weights[key] = value
}
}
}

// per-model cleanup
weights = model.sanitize(weights: weights)

// quantize if needed
if let quantization {
quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
path, _ in
weights["\(path).scales"] != nil
}
}

// apply the loaded weights
let parameters = ModuleParameters.unflattened(weights)
// NOTE: removed verify: [.all] becuase Qwen models are not ready for that verification yet
try model.update(parameters: parameters, verify: [])

eval(model)
}

private extension String {
func trimmingToolCallMarkup() -> String {
let prefix = "<tool_call>\n"
Expand Down
62 changes: 32 additions & 30 deletions Tests/SHLLMTests/Helpers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ let weatherToolFunction = ToolFunction(
)

enum WeatherTool: Codable, CustomStringConvertible, Hashable {
case getCurrentWeather(location: String, unit: WeatherUnit)
case getCurrentWeather(GetCurrentWeatherArguments)

private enum CodingKeys: String, CodingKey {
case name
Expand All @@ -50,29 +50,14 @@ enum WeatherTool: Codable, CustomStringConvertible, Hashable {
init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let toolName = try container.decode(String.self, forKey: .name)
let arguments = try container.decode([String: String].self, forKey: .arguments)

switch toolName {
case "get_current_weather":
guard let location = arguments["location"] else {
throw DecodingError.keyNotFound(
CodingKeys.arguments,
DecodingError.Context(
codingPath: [CodingKeys.arguments],
debugDescription: "Missing 'location' key"
)
)
}
guard let unitString = arguments["unit"],
let unit = WeatherUnit(rawValue: unitString)
else {
throw DecodingError.dataCorruptedError(
forKey: CodingKeys.arguments,
in: container,
debugDescription: "Missing or invalid 'unit' key"
)
}
self = .getCurrentWeather(location: location, unit: unit)
let args = try container.decode(
GetCurrentWeatherArguments.self,
forKey: .arguments
)
self = .getCurrentWeather(args)

default:
throw DecodingError.dataCorruptedError(
Expand All @@ -86,25 +71,42 @@ enum WeatherTool: Codable, CustomStringConvertible, Hashable {
func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
switch self {
case let .getCurrentWeather(location, unit):
case let .getCurrentWeather(args):
try container.encode("getCurrentWeather", forKey: .name)
let arguments: [String: String] = [
"location": location,
"unit": unit.rawValue,
]
try container.encode(arguments, forKey: .arguments)
try container.encode(args, forKey: .arguments)
}
}

var description: String {
switch self {
case let .getCurrentWeather(location, unit):
"getCurrentWeather(location: \(location), unit: \(unit))"
case let .getCurrentWeather(args):
"getCurrentWeather(\(args))"
}
}
}

enum WeatherUnit: String, Codable {
struct GetCurrentWeatherArguments: Codable, CustomStringConvertible, Hashable, Sendable {
var location: String
var unit: WeatherUnit

init(location: String, unit: WeatherUnit) {
self.location = location
self.unit = unit
}

var description: String {
"'\(location)', '\(unit)'"
}
}

enum WeatherUnit: String, Codable, CustomStringConvertible {
case celsius
case fahrenheit

var description: String {
switch self {
case .celsius: "Celsius"
case .fahrenheit: "Fahrenheit"
}
}
}
8 changes: 4 additions & 4 deletions Tests/SHLLMTests/Models/DeepSeekR1Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ func canHelpMeFetchTheWeatherWithR1() async throws {
]
)

let expectedTool1 = WeatherTool.getCurrentWeather(
let expectedTool1 = WeatherTool.getCurrentWeather(.init(
location: "Paris, France",
unit: .fahrenheit
)
))

print("\(#function) 1:", tool1)
#expect(tool1 == expectedTool1)
Expand All @@ -56,10 +56,10 @@ func canHelpMeFetchTheWeatherWithR1() async throws {
]
)

let expectedTool2 = WeatherTool.getCurrentWeather(
let expectedTool2 = WeatherTool.getCurrentWeather(.init(
location: "Paris, France",
unit: .celsius
)
))

print("\(#function) 2:", tool2)
#expect(tool2 == expectedTool2)
Expand Down
8 changes: 4 additions & 4 deletions Tests/SHLLMTests/Models/Gemma2-2BTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ func canHelpMeFetchTheWeatherWithGemma2_2B() async throws {
]
)

let expectedTool1 = WeatherTool.getCurrentWeather(
let expectedTool1 = WeatherTool.getCurrentWeather(.init(
location: "Paris, France",
unit: .fahrenheit
)
))

print("\(#function) 1:", tool1)
#expect(tool1 == expectedTool1)
Expand All @@ -55,10 +55,10 @@ func canHelpMeFetchTheWeatherWithGemma2_2B() async throws {
]
)

let expectedTool2 = WeatherTool.getCurrentWeather(
let expectedTool2 = WeatherTool.getCurrentWeather(.init(
location: "Paris, France",
unit: .celsius
)
))

print("\(#function) 2:", tool2)
#expect(tool2 == expectedTool2)
Expand Down
8 changes: 4 additions & 4 deletions Tests/SHLLMTests/Models/Phi3Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ func canHelpMeFetchTheWeatherWithPhi3() async throws {
]
)

let expectedTool1 = WeatherTool.getCurrentWeather(
let expectedTool1 = WeatherTool.getCurrentWeather(.init(
location: "Paris, France",
unit: .fahrenheit
)
))

print("\(#function) 1:", tool1)
#expect(tool1 == expectedTool1)
Expand All @@ -55,10 +55,10 @@ func canHelpMeFetchTheWeatherWithPhi3() async throws {
]
)

let expectedTool2 = WeatherTool.getCurrentWeather(
let expectedTool2 = WeatherTool.getCurrentWeather(.init(
location: "Paris, France",
unit: .celsius
)
))

print("\(#function) 2:", tool2)
#expect(tool2 == expectedTool2)
Expand Down
8 changes: 4 additions & 4 deletions Tests/SHLLMTests/Models/Qwen2_5-1_5BTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ func canHelpMeFetchTheWeatherWithQwen2_5__1_5B() async throws {
]
)

let expectedTool1 = WeatherTool.getCurrentWeather(
let expectedTool1 = WeatherTool.getCurrentWeather(.init(
location: "Paris, France",
unit: .fahrenheit
)
))

print("\(#function) 1:", tool1)
#expect(tool1 == expectedTool1)
Expand All @@ -56,10 +56,10 @@ func canHelpMeFetchTheWeatherWithQwen2_5__1_5B() async throws {
]
)

let expectedTool2 = WeatherTool.getCurrentWeather(
let expectedTool2 = WeatherTool.getCurrentWeather(.init(
location: "Paris, France",
unit: .celsius
)
))

print("\(#function) 2:", tool2)
#expect(tool2 == expectedTool2)
Expand Down
8 changes: 4 additions & 4 deletions Tests/SHLLMTests/Models/Qwen2_5-7BTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ func canHelpMeFetchTheWeatherWithQwen2_5__7B() async throws {
]
)

let expectedTool1 = WeatherTool.getCurrentWeather(
let expectedTool1 = WeatherTool.getCurrentWeather(.init(
location: "Paris, France",
unit: .fahrenheit
)
))

print("\(#function) 1:", tool1)
#expect(tool1 == expectedTool1)
Expand All @@ -55,10 +55,10 @@ func canHelpMeFetchTheWeatherWithQwen2_5__7B() async throws {
]
)

let expectedTool2 = WeatherTool.getCurrentWeather(
let expectedTool2 = WeatherTool.getCurrentWeather(.init(
location: "Paris, France",
unit: .celsius
)
))

print("\(#function) 2:", tool2)
#expect(tool2 == expectedTool2)
Expand Down

0 comments on commit 2752fec

Please sign in to comment.