Skip to content

Commit

Permalink
Our own Tools struct (#5)
Browse files Browse the repository at this point in the history
This should make it easier to specify tools.
  • Loading branch information
myobie authored Feb 28, 2025
1 parent d8a4fdc commit c3d3717
Show file tree
Hide file tree
Showing 11 changed files with 639 additions and 118 deletions.
4 changes: 2 additions & 2 deletions Package.resolved
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/shareup/mlx-swift-examples",
"state" : {
"revision" : "20701c0eeedd339ede4dd3b964152d814a3e9716",
"version" : "0.0.1"
"revision" : "0c50f71e1a3fcaaf09d1f93721a55e4f291bc442",
"version" : "0.0.2"
}
},
{
Expand Down
11 changes: 7 additions & 4 deletions Package.swift
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
// swift-tools-version: 5.9
// The swift-tools-version declares the minimum version of Swift required to build this package.

import PackageDescription

let package = Package(
name: "SHLLM",
platforms: [.iOS(.v16), .macOS(.v14)],
products: [
// Products define the executables and libraries a package produces, making them visible
// to other packages.
.library(
name: "SHLLM",
targets: ["SHLLM"]
Expand All @@ -17,7 +14,7 @@ let package = Package(
dependencies: [
.package(
url: "https://github.com/shareup/mlx-swift-examples",
from: "0.0.1"
from: "0.0.2"
),
.package(
url: "https://github.com/huggingface/swift-transformers",
Expand All @@ -37,8 +34,14 @@ let package = Package(
),
],
// resources: [
// .copy("Resources/DeepSeek-R1-Distill-Qwen-7B-4bit"),
// .copy("Resources/gemma-2-2b-it-4bit"),
// .copy("Resources/OpenELM-270M-Instruct"),
// .copy("Resources/Phi-3.5-mini-instruct-4bit"),
// .copy("Resources/Phi-3.5-MoE-instruct-4bit"),
// .copy("Resources/Qwen1.5-0.5B-Chat-4bit"),
// .copy("Resources/Qwen2.5-1.5B-Instruct-4bit"),
// .copy("Resources/Qwen2.5-7B-Instruct-4bit"),
// ],
linkerSettings: [
.linkedFramework("CoreGraphics", .when(platforms: [.macOS])),
Expand Down
67 changes: 33 additions & 34 deletions Sources/SHLLM/LLM.swift
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,31 @@ public final class LLM {
}

extension LLM {
func request<T: Codable>(
tools: Tools,
messages: [Message],
maxTokenCount: Int = 1024 * 1024
) async throws -> T {
let result = try await request(
.init(messages: messages, tools: tools.toSpec()),
maxTokenCount: maxTokenCount
)

let decoder = JSONDecoder()

return try decoder.decode(
T.self,
from: Data(result.trimmingToolCallMarkup().utf8)
)
}

func request(
messages: [Message],
maxTokenCount: Int = 1024 * 1024
) async throws -> String {
try await request(.init(messages: messages), maxTokenCount: maxTokenCount)
}

func request(
_ input: UserInput,
maxTokenCount: Int = 1024 * 1024
Expand All @@ -215,40 +240,14 @@ extension LLM {
}
}

// FROM: https://github.com/ml-explore/mlx-swift-examples/blob/20701c0eeedd339ede4dd3b964152d814a3e9716/Libraries/MLXLMCommon/Load.swift#L61
private func loadWeights(
modelDirectory: URL, model: LanguageModel,
quantization: BaseConfiguration.Quantization? = nil
) throws {
// load the weights
var weights = [String: MLXArray]()
let enumerator = FileManager.default.enumerator(
at: modelDirectory, includingPropertiesForKeys: nil
)!
for case let url as URL in enumerator {
if url.pathExtension == "safetensors" {
let w = try loadArrays(url: url)
for (key, value) in w {
weights[key] = value
}
}
}

// per-model cleanup
weights = model.sanitize(weights: weights)
private extension String {
func trimmingToolCallMarkup() -> String {
let prefix = "<tool_call>\n"
let suffix = "\n</tool_call>"

// quantize if needed
if let quantization {
quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
path, _ in
weights["\(path).scales"] != nil
}
var copy = trimmingCharacters(in: .whitespacesAndNewlines)
copy.removeFirst(prefix.count)
copy.removeLast(suffix.count)
return copy
}

// apply the loaded weights
let parameters = ModuleParameters.unflattened(weights)
// NOTE: removed verify: [.all] becuase Qwen models are not ready for that verification yet
try model.update(parameters: parameters, verify: [])

eval(model)
}
23 changes: 23 additions & 0 deletions Sources/SHLLM/ModelProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,29 @@ public protocol ModelProtocol {
}

public extension ModelProtocol {
func request<T: Codable>(
tools: Tools,
messages: [Message],
maxTokenCount: Int = 1024 * 1024
) async throws -> T {
try await llm.withLock { llm in
try await llm.request(
tools: tools,
messages: messages,
maxTokenCount: maxTokenCount
)
}
}

func request(
messages: [Message],
maxTokenCount: Int = 1024 * 1024
) async throws -> String {
try await llm.withLock { llm in
try await llm.request(messages: messages, maxTokenCount: maxTokenCount)
}
}

func request(
_ input: UserInput,
maxTokenCount: Int = 1024 * 1024
Expand Down
Loading

0 comments on commit c3d3717

Please sign in to comment.